Fix GEMV dot case with empty output and beta 0 #1410

Merged (1 commit, May 23, 2025). Changes shown from all commits.
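In short: PyTensor lowers some dot products to GEMV calls that write into an uninitialized AllocEmpty output with beta=0, and the single-element "dot" shortcut in the C code scaled that output by beta, so 0 * nan stayed nan. A minimal reproduction sketch (not part of the PR page; assumes a PyTensor build with a C BLAS linked):

import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.matrix("A")
x = pt.vector("x")
f = pytensor.function([A, x], pt.dot(A, x), mode="FAST_RUN")
# A (1, k) @ (k,) product takes the single-element dot shortcut; before this
# fix it could return nan, depending on what the empty buffer happened to hold.
print(f(np.ones((1, 3)), np.ones(3)))  # expected: [3.]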
23 changes: 14 additions & 9 deletions pytensor/tensor/blas.py
@@ -113,23 +113,22 @@
 _logger = logging.getLogger("pytensor.tensor.blas")
 
 
-# If check_init_y() == True we need to initialize y when beta == 0.
-def check_init_y():
-    # TODO: What is going on here?
+def must_initialize_y_gemv():
+    # Check whether SciPy GEMV could output nan if y is not initialized
     from scipy.linalg.blas import get_blas_funcs
 
-    if check_init_y._result is None:
-        y = float("NaN") * np.ones((2,))
+    if must_initialize_y_gemv._result is None:
+        y = np.full((2,), np.nan)
         x = np.ones((2,))
         A = np.ones((2, 2))
         gemv = get_blas_funcs("gemv", dtype=y.dtype)
         gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
-        check_init_y._result = np.isnan(y).any()
+        must_initialize_y_gemv._result = np.isnan(y).any()
 
-    return check_init_y._result
+    return must_initialize_y_gemv._result
 
 
-check_init_y._result = None  # type: ignore
+must_initialize_y_gemv._result = None  # type: ignore
 
 
 class Gemv(Op):
@@ -197,7 +196,13 @@
f"(beta * y + alpha * dot(A, x)). y: {y.shape}, A: {A.shape}, x: {x.shape}"
)

if beta == 0 and check_init_y():
if beta == 0 and must_initialize_y_gemv():
# Most BLAS implementations of GEMV ignore y=nan when beta=0
# PyTensor considers that the correct behavior,
# and even exploits it to avoid copying or initializing outputs.
# By deciding to exploit this, however, it becomes our responsibility
# to ensure the behavior even in the rare cases BLAS deviates,
# or users will get errors, even for graphs that had no nan to begin with.
y.fill(0)

# Here I suppose that A is in c order. If we don't make it
Expand Down
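As a runnable illustration of the property the probe above checks (not part of the diff; assumes SciPy is installed):

import numpy as np
from scipy.linalg.blas import get_blas_funcs

A = np.ones((2, 2))
x = np.ones(2)
y = np.full(2, np.nan)

gemv = get_blas_funcs("gemv", dtype=y.dtype)
out = gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True)
# Most BLAS builds return [2. 2.] here; one that propagates the nan instead is
# exactly what must_initialize_y_gemv() detects, triggering the y.fill(0) above.
print(out)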
47 changes: 28 additions & 19 deletions pytensor/tensor/blas_c.py
@@ -336,11 +336,12 @@
 # ##### ####### #######
 
 
-def gemv_c_code(y, A, x, z, alpha, beta, fail, force_init_beta=False, params=None):
+def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=None):
     """
     z <- beta * y + alpha * dot(A, x)
 
     where A is a matrix, y and x are vectors (ergo z is vector)
+    z = y if inplace else y.copy()
     """
     code = """

@@ -400,17 +401,11 @@
             }
             if (dbeta != 0)
             {
+                // If dbeta is zero, we avoid doing the copy
                 if (PyArray_CopyInto(%(z)s, %(y)s) != 0) {
                     %(fail)s
                 }
             }
-            else if (%(force_init_beta)d)
-            {
-                PyObject *zero = PyFloat_FromDouble(0.);
-                if (zero == NULL) %(fail)s;
-                if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
-                Py_DECREF(zero);
-            }
         }
         else
         {
@@ -422,6 +417,20 @@
             }
         }
 
+        if (%(must_initialize_y)d && dbeta == 0)
+        {
+            // Most BLAS implementations of GEMV ignore y=nan when beta=0.
+            // PyTensor considers that the correct behavior,
+            // and even exploits it to avoid copying or initializing outputs.
+            // By deciding to exploit this, however, it becomes our responsibility
+            // to ensure that behavior even in the rare cases a BLAS deviates,
+            // or users would see spurious nan even in graphs that had none to begin with.
+            PyObject *zero = PyFloat_FromDouble(0.);
+            if (zero == NULL) %(fail)s;
+            if (PyArray_FillWithScalar(%(z)s, zero) != 0) %(fail)s;
+            Py_DECREF(zero);
+        }
+
         {
             int NA0 = PyArray_DIMS(%(A)s)[0];
             int NA1 = PyArray_DIMS(%(A)s)[1];
@@ -491,13 +500,13 @@

         if (is_float)
         {
-            z_data[0] *= fbeta;
+            z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.f;
             z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1,
                                        (float*)x_data, &Sx);
         }
         else
         {
-            z_data[0] *= dbeta;
+            z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.;
             z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1,
                                        (double*)x_data, &Sx);
         }
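An illustration of why the guarded form is needed (not part of the diff): under IEEE-754 arithmetic, 0 * nan is nan, so the old in-place scaling could never scrub an uninitialized output:

import numpy as np

z = np.float64("nan")  # stand-in for the uninitialized z_data[0]
dbeta = 0.0
print(z * dbeta)                         # nan: the old `z_data[0] *= dbeta`
print(dbeta * z if dbeta != 0 else 0.0)  # 0.0: the new guarded form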
@@ -583,21 +592,21 @@
             alpha,
             beta,
             fail=sub["fail"],
-            force_init_beta=check_force_gemv_init(),
+            must_initialize_y=must_initialize_y_gemv(),
             params=sub["params"],
         )
         return code
 
     def c_code_cache_version(self):
-        return (16, blas_header_version(), check_force_gemv_init())
+        return (17, blas_header_version(), must_initialize_y_gemv())


 cgemv_inplace = CGemv(inplace=True)
 cgemv_no_inplace = CGemv(inplace=False)
 
 
-def check_force_gemv_init():
-    if check_force_gemv_init._force_init_beta is None:
+def must_initialize_y_gemv():
+    if must_initialize_y_gemv._force_init_beta is None:
         from pytensor.link.c.cmodule import GCC_compiler
 
         """
@@ -643,13 +652,13 @@
         )
         if res:
             if res[0]:
-                check_force_gemv_init._force_init_beta = res[1]
+                must_initialize_y_gemv._force_init_beta = res[1]
             else:
-                check_force_gemv_init._force_init_beta = False
+                must_initialize_y_gemv._force_init_beta = False
         else:
-            check_force_gemv_init._force_init_beta = False
+            must_initialize_y_gemv._force_init_beta = False
 
-    return check_force_gemv_init._force_init_beta
+    return must_initialize_y_gemv._force_init_beta
 
 
-check_force_gemv_init._force_init_beta = None
+must_initialize_y_gemv._force_init_beta = None
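A usage sketch (not part of the diff) for checking whether a local setup needs the workaround; it calls the two probes this PR introduces, and the blas_c one assumes a working C toolchain:

from pytensor.tensor.blas import must_initialize_y_gemv as scipy_gemv_leaves_nan
from pytensor.tensor.blas_c import must_initialize_y_gemv as c_gemv_leaves_nan

print(scipy_gemv_leaves_nan())  # does SciPy's GEMV leave nan in y when beta == 0?
print(c_gemv_leaves_nan())      # does the linked C BLAS?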
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/blas.py
@@ -700,7 +700,7 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
         new_out = [rval]
     elif xb[0] and yb[1]:
         # x and y are both vectors so this qualifies for a sdot / ddot
-        # TODO: PyTensor doesn't have a sdot, but gemv is better than _dot22
+        # PyTensor's CGemv will call sdot/ddot at runtime; the SciPy GEMV may not
         xv = x.dimshuffle(1)
         zeros = ptb.AllocEmpty(x.dtype)(1)
         rval = gemv_no_inplace(zeros, one, y.T, xv, zero)
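A sketch of the graph this rewrite produces (not part of the diff): a vector-vector dot is lowered to GEMV writing into an AllocEmpty output with beta=0, the exact pattern the PR title refers to:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
f = pytensor.function([x, y], pt.dot(x, y), mode="FAST_RUN")
pytensor.dprint(f)  # with a C BLAS linked, shows CGemv fed by an AllocEmpty output
print(f(np.ones(3), np.ones(3)))  # 3.0, with no nan leaking from the empty buffer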
34 changes: 20 additions & 14 deletions tests/tensor/test_blas_c.py
@@ -7,7 +7,7 @@
 import pytensor.tensor as pt
 from pytensor.tensor.basic import AllocEmpty
 from pytensor.tensor.blas import Ger
-from pytensor.tensor.blas_c import CGemv, CGer, check_force_gemv_init
+from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv
 from pytensor.tensor.blas_scipy import ScipyGer
 from pytensor.tensor.type import dmatrix, dvector, matrix, scalar, tensor, vector
 from tests import unittest_tools
@@ -130,31 +130,35 @@ def setup_method(self):
         self.dtype = dtype
         self.mode = pytensor.compile.get_default_mode().including("fast_run")
         # matrix
-        self.A = tensor(dtype=dtype, shape=(None, None))
+        self.A = tensor("A", dtype=dtype, shape=(None, None))
         self.Aval = np.ones((2, 3), dtype=dtype)
 
         # vector
-        self.x = tensor(dtype=dtype, shape=(None,))
-        self.y = tensor(dtype=dtype, shape=(None,))
+        self.x = tensor("x", dtype=dtype, shape=(None,))
+        self.y = tensor("y", dtype=dtype, shape=(None,))
         self.xval = np.asarray([1, 2], dtype=dtype)
         self.yval = np.asarray([1.5, 2.7, 3.9], dtype=dtype)
 
         # scalar
-        self.a = tensor(dtype=dtype, shape=())
+        self.a = tensor("a", dtype=dtype, shape=())

-    def test_nan_beta_0(self):
+    @pytest.mark.parametrize("inplace", [True, False])
+    def test_nan_beta_0(self, inplace):
         mode = self.mode.including()
         mode.check_isfinite = False
         f = pytensor.function(
-            [self.A, self.x, self.y, self.a],
+            [self.A, self.x, pytensor.In(self.y, mutable=inplace), self.a],
             self.a * self.y + pt.dot(self.A, self.x),
             mode=mode,
         )
-        Aval = np.ones((3, 1), dtype=self.dtype)
-        xval = np.ones((1,), dtype=self.dtype)
-        yval = float("NaN") * np.ones((3,), dtype=self.dtype)
-        zval = f(Aval, xval, yval, 0)
-        assert not np.isnan(zval).any()
+        [node] = f.maker.fgraph.apply_nodes
+        assert isinstance(node.op, CGemv) and node.op.inplace == inplace
+        for rows in (3, 1):
+            Aval = np.ones((rows, 1), dtype=self.dtype)
+            xval = np.ones((1,), dtype=self.dtype)
+            yval = np.full((rows,), np.nan, dtype=self.dtype)
+            zval = f(Aval, xval, yval, 0)
+            assert not np.isnan(zval).any()
 
     def test_optimizations_vm(self):
         skip_if_blas_ldflags_empty()
@@ -191,8 +195,10 @@ def test_optimizations_mv(self):
             np.dot(self.Aval[::-1, ::-1], self.yval),
         )
 
-    def test_force_gemv_init(self):
-        if check_force_gemv_init():
+    def test_must_initialize_y_gemv(self):
+        if must_initialize_y_gemv():
+            # FIXME: This warning should be emitted by the function itself if we find it relevant,
+            # not in a test that doesn't care about the outcome either way
             warn(
                 "WARNING: The current BLAS requires PyTensor to initialize"
                 " memory for some GEMV calls which will result in a minor"