Commit 920fe40
Fix GEMV dot case with empty output and beta 0
Bug introduced in 709f745
1 parent 261aaf3 commit 920fe40
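
The bug: CGemv has a special case that dispatches a single-row GEMV to sdot/ddot, and that path updated the output with z_data[0] *= beta. When the output buffer comes from AllocEmpty and beta == 0, this multiplies uninitialized memory by zero, so any nan already sitting in the buffer survives. A minimal NumPy sketch of the arithmetic (names are illustrative, not from the codebase):

import numpy as np

z = np.empty(1)     # stand-in for an AllocEmpty output buffer
z[0] = np.nan       # worst case: the uninitialized memory holds nan
beta = 0.0

old = z[0] * beta                           # nan * 0 -> nan (the bug)
new = beta * z[0] if beta != 0 else 0.0     # fixed: beta == 0 overwrites
print(old, new)                             # nan 0.0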

File tree

4 files changed (+60, -42 lines)


pytensor/tensor/blas.py

Lines changed: 14 additions & 9 deletions
@@ -113,23 +113,22 @@
 _logger = logging.getLogger("pytensor.tensor.blas")


-# If check_init_y() == True we need to initialize y when beta == 0.
-def check_init_y():
-    # TODO: What is going on here?
+def must_initialize_y_gemv():
+    # Check whether Scipy GEMV could output nan if y is not initialized
     from scipy.linalg.blas import get_blas_funcs

-    if check_init_y._result is None:
-        y = float("NaN") * np.ones((2,))
+    if must_initialize_y_gemv._result is None:
+        y = np.full((2,), np.nan)
         x = np.ones((2,))
         A = np.ones((2, 2))
         gemv = get_blas_funcs("gemv", dtype=y.dtype)
         gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
-        check_init_y._result = np.isnan(y).any()
+        must_initialize_y_gemv._result = np.isnan(y).any()

-    return check_init_y._result
+    return must_initialize_y_gemv._result


-check_init_y._result = None  # type: ignore
+must_initialize_y_gemv._result = None  # type: ignore


 class Gemv(Op):
@@ -197,7 +196,13 @@ def perform(self, node, inputs, out_storage):
                 f"(beta * y + alpha * dot(A, x)). y: {y.shape}, A: {A.shape}, x: {x.shape}"
             )

-        if beta == 0 and check_init_y():
+        if beta == 0 and must_initialize_y_gemv():
+            # Most BLAS implementations of GEMV ignore y=nan when beta=0.
+            # PyTensor considers that the correct behavior,
+            # and even exploits it to avoid copying or initializing outputs.
+            # By deciding to exploit this, however, it becomes our responsibility
+            # to ensure the behavior even in the rare cases where BLAS deviates,
+            # or users will get errors even for graphs that had no nan to begin with.
             y.fill(0)

         # Here I suppose that A is in c order. If we don't make it
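
Outside the cached wrapper, the probe added above boils down to the following standalone check, using the same scipy calls as the diff. If it prints True, the local BLAS leaves an uninitialized y untouched when beta == 0, so PyTensor has to zero-fill y before calling GEMV:

import numpy as np
from scipy.linalg.blas import get_blas_funcs

y = np.full((2,), np.nan)   # stands in for an uninitialized output
x = np.ones((2,))
A = np.ones((2, 2))
gemv = get_blas_funcs("gemv", dtype=y.dtype)
gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
print(np.isnan(y).any())    # True -> y must be initialized first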

pytensor/tensor/blas_c.py

Lines changed: 28 additions & 19 deletions
@@ -336,11 +336,12 @@ def c_code_cache_version(self):
 # #####  ####### #######


-def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=None):
+def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=None):
     """
     z <- beta * y + alpha * dot(A, x)

     where A is a matrix, y and x are vectors (ergo z is vector)
+    z = y if inplace else y.copy()
     """
     code = """
@@ -400,17 +401,11 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=None):
         }
         if (dbeta != 0)
         {
+            // If dbeta is zero, we avoid doing the copy
             if (PyArray_CopyInto(%(z)s, %(y)s) != 0) {
                 %(fail)s
             }
         }
-        else if (%(force_init_beta)d)
-        {
-            PyObject *zero = PyFloat_FromDouble(0.);
-            if (zero == NULL) %(fail)s;
-            if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
-            Py_DECREF(zero);
-        }
     }
     else
     {
@@ -422,6 +417,20 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=None):
         }
     }

+    if (%(must_initialize_y)d && dbeta == 0)
+    {
+        // Most BLAS implementations of GEMV ignore y=nan when beta=0.
+        // PyTensor considers that the correct behavior,
+        // and even exploits it to avoid copying or initializing outputs.
+        // By deciding to exploit this, however, it becomes our responsibility
+        // to ensure the behavior even in the rare cases where BLAS deviates,
+        // or users will get errors even for graphs that had no nan to begin with.
+        PyObject *zero = PyFloat_FromDouble(0.);
+        if (zero == NULL) %(fail)s;
+        if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
+        Py_DECREF(zero);
+    }
+
     {
         int NA0 = PyArray_DIMS(%(A)s)[0];
         int NA1 = PyArray_DIMS(%(A)s)[1];
@@ -491,13 +500,13 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=None):

         if (is_float)
         {
-            z_data[0] *= fbeta;
+            z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.f;
             z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
                                        (float*)x_data, &Sx);
         }
         else
         {
-            z_data[0] *= dbeta;
+            z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.;
             z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
                                        (double*)x_data, &Sx);
         }
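
In Python terms, the dot-case update goes from an unconditional in-place scale to an explicit overwrite when beta is 0. A sketch of the semantics (plain Python, not the C template; the helper name is hypothetical):

def dot_case_update(z0, beta, alpha, a_row, x):
    # Old: z0 *= beta, which keeps nan when z0 is uninitialized and beta == 0.
    # New: beta == 0 discards whatever was in z0.
    z0 = beta * z0 if beta != 0 else 0.0
    return z0 + alpha * sum(a * b for a, b in zip(a_row, x))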
@@ -583,21 +592,21 @@ def c_code(self, node, name, inp, out, sub):
             alpha,
             beta,
             fail=sub["fail"],
-            force_init_beta=check_force_gemv_init(),
+            must_initialize_y=must_initialize_y_gemv(),
             params=sub["params"],
         )
         return code

     def c_code_cache_version(self):
-        return (16, blas_header_version(), check_force_gemv_init())
+        return (17, blas_header_version(), must_initialize_y_gemv())


 cgemv_inplace = CGemv(inplace=True)
 cgemv_no_inplace = CGemv(inplace=False)


-def check_force_gemv_init():
-    if check_force_gemv_init._force_init_beta is None:
+def must_initialize_y_gemv():
+    if must_initialize_y_gemv._force_init_beta is None:
         from pytensor.link.c.cmodule import GCC_compiler

         """
@@ -643,13 +652,13 @@ def check_force_gemv_init():
         )
         if res:
             if res[0]:
-                check_force_gemv_init._force_init_beta = res[1]
+                must_initialize_y_gemv._force_init_beta = res[1]
             else:
-                check_force_gemv_init._force_init_beta = False
+                must_initialize_y_gemv._force_init_beta = False
         else:
-            check_force_gemv_init._force_init_beta = False
+            must_initialize_y_gemv._force_init_beta = False

-    return check_force_gemv_init._force_init_beta
+    return must_initialize_y_gemv._force_init_beta


-check_force_gemv_init._force_init_beta = None
+must_initialize_y_gemv._force_init_beta = None
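
Because both the version bump (16 to 17) and the probe result are part of the cache key, kernels compiled for a BLAS that needs y initialized are kept apart from kernels for one that does not, and previously cached kernels are invalidated. A quick way to inspect both values on a given install (assumes a working C toolchain and BLAS):

from pytensor.tensor.blas_c import CGemv, must_initialize_y_gemv

print(must_initialize_y_gemv())                     # does this BLAS need y zero-filled?
print(CGemv(inplace=False).c_code_cache_version())  # e.g. (17, <header version>, False)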

pytensor/tensor/rewriting/blas.py

Lines changed: 1 addition & 1 deletion
@@ -700,7 +700,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
         new_out = [rval]
     elif xb[0] and yb[1]:
         # x and y are both vectors so this qualifies for a sdot / ddot
-        # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22
+        # PyTensor's CGemv will call sdot/ddot at runtime; the Scipy Gemv may not
         xv = x.dimshuffle(1)
         zeros = ptb.AllocEmpty(x.dtype)(1)
         rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
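
This is the branch that sets up the buggy configuration: a vector-vector dot is lowered to gemv_no_inplace with an AllocEmpty(1) output and beta == 0, so z starts uninitialized on the sdot/ddot path. A rough end-to-end repro sketch; whether the rewrite actually fires depends on the active mode and BLAS flags:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
f = pytensor.function([x, y], pt.dot(x, y), mode="FAST_RUN")
# Before this commit, a BLAS hitting the z *= beta dot path could return
# nan here, because the output buffer starts uninitialized.
print(f(np.ones(3), np.ones(3)))  # expected: 3.0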

tests/tensor/test_blas_c.py

Lines changed: 17 additions & 13 deletions
@@ -7,7 +7,7 @@
 import pytensor.tensor as pt
 from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Ger
-from pytensor.tensor.blas_c import CGemv, CGer, check_force_gemv_init
+from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
 from pytensor.tensor.blas_scipy import ScipyGer
 from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
 from tests import unittest_tools
@@ -130,31 +130,35 @@ def setup_method(self):
         self.dtype = dtype
         self.mode = pytensor.compile.get_default_mode().including("fast_run")
         # matrix
-        self.A = tensor(dtype=dtype, shape=(None, None))
+        self.A = tensor("A", dtype=dtype, shape=(None, None))
         self.Aval = np.ones((2, 3), dtype=dtype)

         # vector
-        self.x = tensor(dtype=dtype, shape=(None,))
-        self.y = tensor(dtype=dtype, shape=(None,))
+        self.x = tensor("x", dtype=dtype, shape=(None,))
+        self.y = tensor("y", dtype=dtype, shape=(None,))
         self.xval = np.asarray([1, 2], dtype=dtype)
         self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype)

         # scalar
-        self.a = tensor(dtype=dtype, shape=())
+        self.a = tensor("a", dtype=dtype, shape=())

-    def test_nan_beta_0(self):
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_nan_beta_0(self, inplace):
         mode = self.mode.including()
         mode.check_isfinite = False
         f = pytensor.function(
-            [self.A, self.x, self.y, self.a],
+            [self.A, self.x, pytensor.In(self.y, mutable=inplace), self.a],
             self.a * self.y + pt.dot(self.A, self.x),
             mode=mode,
         )
-        Aval = np.ones((3, 1), dtype=self.dtype)
-        xval = np.ones((1,), dtype=self.dtype)
-        yval = float("NaN") * np.ones((3,), dtype=self.dtype)
-        zval = f(Aval, xval, yval, 0)
-        assert not np.isnan(zval).any()
+        [node] = f.maker.fgraph.apply_nodes
+        assert isinstance(node.op, CGemv) and node.op.inplace == inplace
+        for rows in (3, 1):
+            Aval = np.ones((rows, 1), dtype=self.dtype)
+            xval = np.ones((1,), dtype=self.dtype)
+            yval = np.full((rows,), np.nan, dtype=self.dtype)
+            zval = f(Aval, xval, yval, 0)
+            assert not np.isnan(zval).any()

     def test_optimizations_vm(self):
         skip_if_blas_ldflags_empty()
@@ -192,7 +196,7 @@ def test_optimizations_mv(self):
         )

     def test_force_gemv_init(self):
-        if check_force_gemv_init():
+        if must_initialize_y_gemv():
             warn(
                 "WARNING: The current BLAS requires PyTensor to initialize"
                 " memory for some GEMV calls which will result in a minor"
