Skip to content

Commit bf678ca

Browse files
brandonwillardtwiecki
authored andcommitted
Add np.shape and ndim overloads for sparse Numba types
1 parent 039ed1c commit bf678ca

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

pytensor/link/numba/dispatch/sparse.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import scipy as sp
23
import scipy.sparse
34
from numba.core import cgutils, types
@@ -6,6 +7,8 @@
67
box,
78
make_attribute_wrapper,
89
models,
10+
overload,
11+
overload_attribute,
912
register_model,
1013
typeof_impl,
1114
unbox,
@@ -140,3 +143,21 @@ def box_matrix(typ, val, c):
140143
c.pyapi.decref(shape_obj)
141144

142145
return obj
146+
147+
148+
@overload(np.shape)
149+
def overload_sparse_shape(x):
150+
if isinstance(x, CSMatrixType):
151+
return lambda x: x.shape
152+
153+
154+
@overload_attribute(CSMatrixType, "ndim")
155+
def overload_sparse_ndim(inst):
156+
157+
if not isinstance(inst, CSMatrixType):
158+
return
159+
160+
def ndim(inst):
161+
return 2
162+
163+
return ndim

tests/link/numba/test_sparse.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,27 @@ def test_boxing(x, y):
3838
assert np.array_equal(res_y_val.indices, y_val.indices)
3939
assert np.array_equal(res_y_val.indptr, y_val.indptr)
4040
assert res_y_val.shape == y_val.shape
41+
42+
43+
def test_sparse_shape():
44+
@numba.njit
45+
def test_fn(x):
46+
return np.shape(x)
47+
48+
x_val = sp.sparse.csr_matrix(np.eye(100))
49+
50+
res = test_fn(x_val)
51+
52+
assert res == (100, 100)
53+
54+
55+
def test_sparse_ndim():
56+
@numba.njit
57+
def test_fn(x):
58+
return x.ndim
59+
60+
x_val = sp.sparse.csr_matrix(np.eye(100))
61+
62+
res = test_fn(x_val)
63+
64+
assert res == 2

0 commit comments

Comments
 (0)