From 7f1417c839c98b547282012ce65fce013798bf9b Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 22 May 2025 17:02:47 +0200
Subject: [PATCH] Fix GEMV dot case with empty output and beta 0

Bug introduced in 709f745ccdfe06c1b0aad24cbdec139ec10c03ff
---
 pytensor/tensor/blas.py           | 23 +++++++++------
 pytensor/tensor/blas_c.py         | 47 ++++++++++++++++++-------------
 pytensor/tensor/rewriting/blas.py |  2 +-
 tests/tensor/test_blas_c.py       | 34 +++++++++++++---------
 4 files changed, 63 insertions(+), 43 deletions(-)

diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py
index 0977b500a6..fc8afcea50 100644
--- a/pytensor/tensor/blas.py
+++ b/pytensor/tensor/blas.py
@@ -113,23 +113,22 @@
 _logger = logging.getLogger("pytensor.tensor.blas")
 
 
-# If check_init_y() == True we need to initialize y when beta == 0.
-def check_init_y():
-    # TODO: What is going on here?
+def must_initialize_y_gemv():
+    # Check whether SciPy GEMV could output nan if y is not initialized
     from scipy.linalg.blas import get_blas_funcs
 
-    if check_init_y._result is None:
-        y = float("NaN") * np.ones((2,))
+    if must_initialize_y_gemv._result is None:
+        y = np.full((2,), np.nan)
         x = np.ones((2,))
         A = np.ones((2, 2))
         gemv = get_blas_funcs("gemv", dtype=y.dtype)
         gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
-        check_init_y._result = np.isnan(y).any()
+        must_initialize_y_gemv._result = np.isnan(y).any()
 
-    return check_init_y._result
+    return must_initialize_y_gemv._result
 
 
-check_init_y._result = None  # type: ignore
+must_initialize_y_gemv._result = None  # type: ignore
 
 
 class Gemv(Op):
@@ -197,7 +196,13 @@ def perform(self, node, inputs, out_storage):
             f"(beta * y + alpha * dot(A, x)). y: {y.shape}, A: {A.shape}, x: {x.shape}"
         )
 
-        if beta == 0 and check_init_y():
+        if beta == 0 and must_initialize_y_gemv():
+            # Most BLAS implementations of GEMV ignore y=nan when beta=0.
+            # PyTensor considers this the correct behavior,
+            # and even exploits it to avoid copying or initializing outputs.
+            # By deciding to exploit this, however, it becomes our responsibility
+            # to ensure the behavior even in the rare cases BLAS deviates,
+            # or users will get errors, even for graphs that had no nan to begin with.
             y.fill(0)
 
         # Here I suppose that A is in c order. If we don't make it
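
Note: the probe that replaces check_init_y above can also be run on its own to see
whether the local BLAS honors the beta == 0 contract (never reading the uninitialized
y). A minimal standalone sketch, assuming SciPy is installed; it mirrors the patched
must_initialize_y_gemv and is not part of the patch itself:

    import numpy as np
    from scipy.linalg.blas import get_blas_funcs

    y = np.full((2,), np.nan)  # deliberately "uninitialized" output buffer
    x = np.ones((2,))
    A = np.ones((2, 2))
    gemv = get_blas_funcs("gemv", dtype=y.dtype)
    # With beta=0 a conforming BLAS overwrites y entirely; a deviating one
    # computes 0 * nan = nan and leaks it into the result.
    gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
    print("must initialize y:", bool(np.isnan(y).any()))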
diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas_c.py
index 9513163027..0ef0a1f476 100644
--- a/pytensor/tensor/blas_c.py
+++ b/pytensor/tensor/blas_c.py
@@ -336,11 +336,12 @@ def c_code_cache_version(self):
 # ##### ####### #######
 
 
-def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=None):
+def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=None):
     """
     z <- beta * y + alpha * dot(A, x)
 
     where A is a matrix, y and x are vectors (ergo z is vector)
+    z = y if inplace else y.copy()
     """
     code = """
@@ -400,17 +401,11 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
             }
             if (dbeta != 0)
             {
+                // If dbeta is zero, we avoid doing the copy
                 if (PyArray_CopyInto(%(z)s, %(y)s) != 0)
                 {
                     %(fail)s
                 }
             }
-            else if (%(force_init_beta)d)
-            {
-                PyObject *zero = PyFloat_FromDouble(0.);
-                if (zero == NULL) %(fail)s;
-                if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
-                Py_DECREF(zero);
-            }
         }
         else
         {
@@ -422,6 +417,20 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
             }
         }
 
+        if (%(must_initialize_y)d && dbeta == 0)
+        {
+            // Most BLAS implementations of GEMV ignore y=nan when beta=0.
+            // PyTensor considers this the correct behavior,
+            // and even exploits it to avoid copying or initializing outputs.
+            // By deciding to exploit this, however, it becomes our responsibility
+            // to ensure the behavior even in the rare cases BLAS deviates,
+            // or users will get errors, even for graphs that had no nan to begin with.
+            PyObject *zero = PyFloat_FromDouble(0.);
+            if (zero == NULL) %(fail)s;
+            if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
+            Py_DECREF(zero);
+        }
+
     {
         int NA0 = PyArray_DIMS(%(A)s)[0];
         int NA1 = PyArray_DIMS(%(A)s)[1];
@@ -491,13 +500,13 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=Non
 
             if (is_float)
             {
-                z_data[0] *= fbeta;
+                z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.f;
                 z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
                                            (float*)x_data, &Sx);
             }
             else
             {
-                z_data[0] *= dbeta;
+                z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.;
                 z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
                                            (double*)x_data, &Sx);
             }
@@ -583,21 +592,21 @@ def c_code(self, node, name, inp, out, sub):
             alpha,
             beta,
             fail=sub["fail"],
-            force_init_beta=check_force_gemv_init(),
+            must_initialize_y=must_initialize_y_gemv(),
             params=sub["params"],
         )
         return code
 
     def c_code_cache_version(self):
-        return (16, blas_header_version(), check_force_gemv_init())
+        return (17, blas_header_version(), must_initialize_y_gemv())
 
 
 cgemv_inplace = CGemv(inplace=True)
 cgemv_no_inplace = CGemv(inplace=False)
 
 
-def check_force_gemv_init():
-    if check_force_gemv_init._force_init_beta is None:
+def must_initialize_y_gemv():
+    if must_initialize_y_gemv._force_init_beta is None:
         from pytensor.link.c.cmodule import GCC_compiler
 
         """
@@ -643,13 +652,13 @@ def check_force_gemv_init():
         )
         if res:
             if res[0]:
-                check_force_gemv_init._force_init_beta = res[1]
+                must_initialize_y_gemv._force_init_beta = res[1]
             else:
-                check_force_gemv_init._force_init_beta = False
+                must_initialize_y_gemv._force_init_beta = False
         else:
-            check_force_gemv_init._force_init_beta = False
+            must_initialize_y_gemv._force_init_beta = False
 
-    return check_force_gemv_init._force_init_beta
+    return must_initialize_y_gemv._force_init_beta
 
 
-check_force_gemv_init._force_init_beta = None
+must_initialize_y_gemv._force_init_beta = None
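
Note: at the C level the fix changes the dot-product fast path from scaling the output
(z_data[0] *= beta, which propagates nan from an uninitialized buffer because
0 * nan = nan) to overwriting it when beta is zero. A NumPy sketch of the corrected
semantics (illustrative names, not part of the patch):

    import numpy as np

    def gemv_reference(z, alpha, A, x, beta):
        # z may be uninitialized (e.g. nan-filled AllocEmpty storage) when
        # beta == 0, so it must be overwritten rather than scaled
        zb = beta * z if beta != 0 else np.zeros_like(z)
        return zb + alpha * (A @ x)

    z = np.full(3, np.nan)  # simulates an empty output buffer
    print(gemv_reference(z, 1.0, np.ones((3, 2)), np.ones(2), 0.0))  # [2. 2. 2.]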
diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py
index b5c2564481..e626b0720b 100644
--- a/pytensor/tensor/rewriting/blas.py
+++ b/pytensor/tensor/rewriting/blas.py
@@ -700,7 +700,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
             new_out = [rval]
         elif xb[0] and yb[1]:
             # x and y are both vectors so this qualifies for a sdot / ddot
-            # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22
+            # PyTensor's CGemv will call sdot/ddot at runtime, whereas the SciPy GEMV may not
             xv = x.dimshuffle(1)
             zeros = ptb.AllocEmpty(x.dtype)(1)
             rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
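
Note: this rewrite is where the "empty output and beta 0" case in the commit title
comes from: a vector-vector dot is lowered to GEMV whose output buffer is
AllocEmpty(1) (uninitialized) and whose beta is zero, as in
gemv_no_inplace(zeros, one, y.T, xv, zero) above. A sketch of triggering that path
end to end (whether the CGemv kernel is actually selected depends on the compiled
mode and the available BLAS):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    y = pt.vector("y")
    # dot of two vectors; the rewrite turns this into gemv(AllocEmpty(1), 1, y.T, x, 0)
    f = pytensor.function([x, y], pt.dot(x, y))
    print(f(np.ones(3), np.ones(3)))  # expected 3.0, never nan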
diff --git a/tests/tensor/test_blas_c.py b/tests/tensor/test_blas_c.py
index e46c036766..b6ba1987b9 100644
--- a/tests/tensor/test_blas_c.py
+++ b/tests/tensor/test_blas_c.py
@@ -7,7 +7,7 @@
 import pytensor.tensor as pt
 from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Ger
-from pytensor.tensor.blas_c import CGemv, CGer, check_force_gemv_init
+from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
 from pytensor.tensor.blas_scipy import ScipyGer
 from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
 from tests import unittest_tools
@@ -130,31 +130,35 @@ def setup_method(self):
         self.dtype = dtype
         self.mode = pytensor.compile.get_default_mode().including("fast_run")
         # matrix
-        self.A = tensor(dtype=dtype, shape=(None, None))
+        self.A = tensor("A", dtype=dtype, shape=(None, None))
         self.Aval = np.ones((2, 3), dtype=dtype)
         # vector
-        self.x = tensor(dtype=dtype, shape=(None,))
-        self.y = tensor(dtype=dtype, shape=(None,))
+        self.x = tensor("x", dtype=dtype, shape=(None,))
+        self.y = tensor("y", dtype=dtype, shape=(None,))
         self.xval = np.asarray([1, 2], dtype=dtype)
         self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype)
         # scalar
-        self.a = tensor(dtype=dtype, shape=())
+        self.a = tensor("a", dtype=dtype, shape=())
 
-    def test_nan_beta_0(self):
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_nan_beta_0(self, inplace):
         mode = self.mode.including()
         mode.check_isfinite = False
         f = pytensor.function(
-            [self.A, self.x, self.y, self.a],
+            [self.A, self.x, pytensor.In(self.y, mutable=inplace), self.a],
             self.a * self.y + pt.dot(self.A, self.x),
             mode=mode,
         )
-        Aval = np.ones((3, 1), dtype=self.dtype)
-        xval = np.ones((1,), dtype=self.dtype)
-        yval = float("NaN") * np.ones((3,), dtype=self.dtype)
-        zval = f(Aval, xval, yval, 0)
-        assert not np.isnan(zval).any()
+        [node] = f.maker.fgraph.apply_nodes
+        assert isinstance(node.op, CGemv) and node.op.inplace == inplace
+        for rows in (3, 1):
+            Aval = np.ones((rows, 1), dtype=self.dtype)
+            xval = np.ones((1,), dtype=self.dtype)
+            yval = np.full((rows,), np.nan, dtype=self.dtype)
+            zval = f(Aval, xval, yval, 0)
+            assert not np.isnan(zval).any()
 
     def test_optimizations_vm(self):
         skip_if_blas_ldflags_empty()
@@ -191,8 +195,10 @@ def test_optimizations_mv(self):
             np.dot(self.Aval[::-1, ::-1], self.yval),
         )
 
-    def test_force_gemv_init(self):
-        if check_force_gemv_init():
+    def test_must_initialize_y_gemv(self):
+        if must_initialize_y_gemv():
+            # FIXME: This warning should be emitted by the function itself if we find it relevant,
+            # not in a test that doesn't care about the outcome either way
             warn(
                 "WARNING: The current BLAS requires PyTensor to initialize"
                 " memory for some GEMV calls which will result in a minor"
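
Note: the parametrized test exercises both CGemv variants: pytensor.In(self.y,
mutable=inplace) lets the compiled function destroy the y buffer, which is what
allows the inplace op to be selected. A standalone sketch of the fixed scenario,
assuming a BLAS-enabled build (rows == 1 hits the sdot/ddot fast path fixed above):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    A, x, y, a = pt.matrix("A"), pt.vector("x"), pt.vector("y"), pt.scalar("a")
    f = pytensor.function(
        [A, x, pytensor.In(y, mutable=True), a],
        a * y + pt.dot(A, x),
    )
    for rows in (3, 1):
        out = f(np.ones((rows, 1)), np.ones(1), np.full(rows, np.nan), 0.0)
        assert not np.isnan(out).any()  # beta == 0 must ignore the nan y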