From 9aef886bf7836e6f11f83b54f9251fc2dfebe7b1 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 6 Dec 2022 19:30:11 -0600 Subject: [PATCH 1/5] Add np.shape and ndim overloads for sparse Numba types --- pytensor/link/numba/dispatch/sparse.py | 21 +++++++++++++++++++++ tests/link/numba/test_sparse.py | 24 ++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/pytensor/link/numba/dispatch/sparse.py b/pytensor/link/numba/dispatch/sparse.py index 916644991e..379c9fe887 100644 --- a/pytensor/link/numba/dispatch/sparse.py +++ b/pytensor/link/numba/dispatch/sparse.py @@ -1,3 +1,4 @@ +import numpy as np import scipy as sp import scipy.sparse from numba.core import cgutils, types @@ -6,6 +7,8 @@ box, make_attribute_wrapper, models, + overload, + overload_attribute, register_model, typeof_impl, unbox, @@ -139,3 +142,21 @@ def box_matrix(typ, val, c): c.pyapi.decref(shape_obj) return obj + + +@overload(np.shape) +def overload_sparse_shape(x): + if isinstance(x, CSMatrixType): + return lambda x: x.shape + + +@overload_attribute(CSMatrixType, "ndim") +def overload_sparse_ndim(inst): + + if not isinstance(inst, CSMatrixType): + return + + def ndim(inst): + return 2 + + return ndim diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py index 39227fb19f..069fd72cd6 100644 --- a/tests/link/numba/test_sparse.py +++ b/tests/link/numba/test_sparse.py @@ -38,3 +38,27 @@ def test_boxing(x, y): assert np.array_equal(res_y_val.indices, y_val.indices) assert np.array_equal(res_y_val.indptr, y_val.indptr) assert res_y_val.shape == y_val.shape + + +def test_sparse_shape(): + @numba.njit + def test_fn(x): + return np.shape(x) + + x_val = sp.sparse.csr_matrix(np.eye(100)) + + res = test_fn(x_val) + + assert res == (100, 100) + + +def test_sparse_ndim(): + @numba.njit + def test_fn(x): + return x.ndim + + x_val = sp.sparse.csr_matrix(np.eye(100)) + + res = test_fn(x_val) + + assert res == 2 From f204ea61b3ea0c0de7d7aca3dc92cb2c6b37b59c Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 6 Dec 2022 19:28:08 -0600 Subject: [PATCH 2/5] Fix PyTensor-to-Numba type resolution for sparse variables --- pytensor/link/numba/dispatch/basic.py | 11 +++++++++++ pytensor/link/numba/dispatch/sparse.py | 10 ++++++++-- tests/link/numba/test_sparse.py | 25 ++++++++++++++++++++++++- 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 60de0de6b6..7fe609dce2 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -25,6 +25,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.ifelse import IfElse +from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import ( compile_function_src, fgraph_to_python, @@ -32,6 +33,7 @@ ) from pytensor.scalar.basic import ScalarType from pytensor.scalar.math import Softplus +from pytensor.sparse import SparseTensorType from pytensor.tensor.blas import BatchedDot from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape @@ -105,6 +107,15 @@ def get_numba_type( dtype = np.dtype(pytensor_type.dtype) numba_dtype = numba.from_dtype(dtype) return numba_dtype + elif isinstance(pytensor_type, SparseTensorType): + dtype = pytensor_type.numpy_dtype + numba_dtype = numba.from_dtype(dtype) + if pytensor_type.format == "csr": + return CSRMatrixType(numba_dtype) + if pytensor_type.format == "csc": + return CSCMatrixType(numba_dtype) + + raise NotImplementedError() else: raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") diff --git a/pytensor/link/numba/dispatch/sparse.py b/pytensor/link/numba/dispatch/sparse.py index 379c9fe887..17826be90a 100644 --- a/pytensor/link/numba/dispatch/sparse.py +++ b/pytensor/link/numba/dispatch/sparse.py @@ -19,7 +19,10 @@ class CSMatrixType(types.Type): """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" name: str - instance_class: type + + @staticmethod + def instance_class(data, indices, indptr, shape): + raise NotImplementedError() def __init__(self, dtype): self.dtype = dtype @@ -29,6 +32,10 @@ def __init__(self, dtype): self.shape = types.UniTuple(types.int64, 2) super().__init__(self.name) + @property + def key(self): + return (self.name, self.dtype) + make_attribute_wrapper(CSMatrixType, "data", "data") make_attribute_wrapper(CSMatrixType, "indices", "indices") @@ -152,7 +159,6 @@ def overload_sparse_shape(x): @overload_attribute(CSMatrixType, "ndim") def overload_sparse_ndim(inst): - if not isinstance(inst, CSMatrixType): return diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py index 069fd72cd6..3951b8f5a1 100644 --- a/tests/link/numba/test_sparse.py +++ b/tests/link/numba/test_sparse.py @@ -1,9 +1,16 @@ import numba import numpy as np +import pytest import scipy as sp -# Load Numba customizations +# Make sure the Numba customizations are loaded import pytensor.link.numba.dispatch.sparse # noqa: F401 +from pytensor import config +from pytensor.sparse import Dot, SparseTensorType +from tests.link.numba.test_basic import compare_numba_and_py + + +pytestmark = pytest.mark.filterwarnings("error") def test_sparse_unboxing(): @@ -62,3 +69,19 @@ def test_fn(x): res = test_fn(x_val) assert res == 2 + + +def test_sparse_objmode(): + x = SparseTensorType("csc", dtype=config.floatX)() + y = SparseTensorType("csc", dtype=config.floatX)() + + out = Dot()(x, y) + + x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) + y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX) + + with pytest.warns( + UserWarning, + match="Numba will use object mode to run SparseDot's perform method", + ): + compare_numba_and_py(((x, y), (out,)), [x_val, y_val]) From 3708ac0d77847f920e69e38064a57efbfb062abc Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 7 Dec 2022 19:19:22 -0600 Subject: [PATCH 3/5] Use Type.filter in NumbaLinker.output_filter --- pytensor/link/numba/linker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 3f0e35543f..c573835292 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -17,7 +17,7 @@ def output_filter(self, var: "Variable", out: Any) -> Any: if not isinstance(var, np.ndarray) and isinstance( var.type, pytensor.tensor.TensorType ): - return np.asarray(out, dtype=var.type.dtype) + return var.type.filter(out, allow_downcast=True) return out From e9b8a037937fc2adfcb6697b3e08b22f6261b225 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 7 Dec 2022 19:21:30 -0600 Subject: [PATCH 4/5] Compare single elements directly in compare_numba_and_py --- tests/link/numba/test_basic.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index ae4ef8cd4c..313e691763 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -127,8 +127,6 @@ def set_test_value(x, v): def compare_shape_dtype(x, y): - (x,) = x - (y,) = y return x.shape == y.shape and x.dtype == y.dtype @@ -286,7 +284,7 @@ def assert_fn(x, y): for j, p in zip(numba_res, py_res): assert_fn(j, p) else: - assert_fn(numba_res, py_res) + assert_fn(numba_res[0], py_res[0]) return pytensor_numba_fn, numba_res From c39cd9ec95863d62dac45e075f47cc5cb1916a5b Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 7 Dec 2022 18:44:46 -0600 Subject: [PATCH 5/5] Implement a copy method for Numba sparse types --- pytensor/link/numba/dispatch/sparse.py | 38 ++++++++++++++++++++++++++ tests/link/numba/test_sparse.py | 13 +++++++++ 2 files changed, 51 insertions(+) diff --git a/pytensor/link/numba/dispatch/sparse.py b/pytensor/link/numba/dispatch/sparse.py index 17826be90a..e25083e92d 100644 --- a/pytensor/link/numba/dispatch/sparse.py +++ b/pytensor/link/numba/dispatch/sparse.py @@ -2,13 +2,16 @@ import scipy as sp import scipy.sparse from numba.core import cgutils, types +from numba.core.imputils import impl_ret_borrowed from numba.extending import ( NativeValue, box, + intrinsic, make_attribute_wrapper, models, overload, overload_attribute, + overload_method, register_model, typeof_impl, unbox, @@ -166,3 +169,38 @@ def ndim(inst): return 2 return ndim + + +@intrinsic +def _sparse_copy(typingctx, inst, data, indices, indptr, shape): + def _construct(context, builder, sig, args): + typ = sig.return_type + struct = cgutils.create_struct_proxy(typ)(context, builder) + _, data, indices, indptr, shape = args + struct.data = data + struct.indices = indices + struct.indptr = indptr + struct.shape = shape + return impl_ret_borrowed( + context, + builder, + sig.return_type, + struct._getvalue(), + ) + + sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape) + + return sig, _construct + + +@overload_method(CSMatrixType, "copy") +def overload_sparse_copy(inst): + if not isinstance(inst, CSMatrixType): + return + + def copy(inst): + return _sparse_copy( + inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape + ) + + return copy diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py index 3951b8f5a1..482aec9558 100644 --- a/tests/link/numba/test_sparse.py +++ b/tests/link/numba/test_sparse.py @@ -71,6 +71,19 @@ def test_fn(x): assert res == 2 +def test_sparse_copy(): + @numba.njit + def test_fn(x): + y = x.copy() + return ( + y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices) + ) + + x_val = sp.sparse.csr_matrix(np.eye(100)) + + assert test_fn(x_val) + + def test_sparse_objmode(): x = SparseTensorType("csc", dtype=config.floatX)() y = SparseTensorType("csc", dtype=config.floatX)()