Skip to content

Commit 4971ac3

Browse files
committed
draft commit to push on CI
1 parent 885ff0c commit 4971ac3

File tree

7 files changed

+41
-25
lines changed

7 files changed

+41
-25
lines changed

pytensor/sparse/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ class DenseFromSparse(Op):
954954
955955
"""
956956

957-
__props__ = ()
957+
__props__ = ("sparse_grad",)
958958

959959
def __init__(self, structured=True):
960960
self.sparse_grad = structured

pytensor/sparse/rewriting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,13 +1099,13 @@ def c_code_cache_version(self):
10991099
csm_grad_c = CSMGradC()
11001100

11011101

1102-
@node_rewriter([csm_grad(None)])
1102+
@node_rewriter([csm_grad()])
11031103
def local_csm_grad_c(fgraph, node):
11041104
"""
11051105
csm_grad(None) -> csm_grad_c
11061106
11071107
"""
1108-
if node.op == csm_grad(None):
1108+
if node.op == csm_grad():
11091109
return [csm_grad_c(*node.inputs)]
11101110
return False
11111111

pytensor/sparse/type.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ def __init__(
7474
):
7575
if shape is None and broadcastable is None:
7676
shape = (None, None)
77-
77+
if broadcastable is None:
78+
broadcastable = (False, False)
79+
if broadcastable != (False, False):
80+
raise ValueError("Broadcasting sparse types is not yet implemented")
7881
if format not in self.format_cls:
7982
raise ValueError(
8083
f'unsupported format "{format}" not in list',
@@ -96,7 +99,9 @@ def clone(
9699
dtype = self.dtype
97100
if shape is None:
98101
shape = self.shape
99-
return type(self)(format, dtype, shape=shape, **kwargs)
102+
return type(self)(
103+
format, dtype, shape=shape, broadcastable=broadcastable, **kwargs
104+
)
100105

101106
def filter(self, value, strict=False, allow_downcast=None):
102107
if isinstance(value, Variable):

pytensor/tensor/shape.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
22
from numbers import Number
33
from textwrap import dedent
4-
from typing import Dict, List, Tuple, Union
4+
from typing import Dict, List, Sequence, Tuple, Union
55

66
import numpy as np
77

@@ -392,10 +392,15 @@ class SpecifyShape(COp):
392392
Maybe in the future we will never do the assert!
393393
"""
394394

395+
__props__ = ("output_broadcastable",)
395396
view_map = {0: [0]}
396397
__props__ = ()
397398
_f16_ok = True
398399

400+
def __init__(self, output_broadcastable: Sequence) -> None:
401+
super().__init__()
402+
self.output_broadcastable = tuple(output_broadcastable)
403+
399404
def make_node(self, x, *shape):
400405
from pytensor.tensor.basic import get_scalar_constant_value
401406

@@ -432,7 +437,9 @@ def make_node(self, x, *shape):
432437
except NotScalarConstantError:
433438
pass
434439

435-
out_var = x.type.clone(shape=type_shape)()
440+
out_var = x.type.clone(
441+
shape=type_shape, broadcastable=self.output_broadcastable
442+
)()
436443

437444
return Apply(self, [x, *shape], [out_var])
438445

@@ -539,12 +546,10 @@ def c_code_cache_version(self):
539546
return (2,)
540547

541548

542-
_specify_shape = SpecifyShape()
543-
544-
545549
def specify_shape(
546550
x: Union[np.ndarray, Number, Variable],
547551
shape: Union[ShapeValueType, List[ShapeValueType], Tuple[ShapeValueType, ...]],
552+
broadcastable=None,
548553
):
549554
"""Specify a fixed shape for a `Variable`.
550555
@@ -574,8 +579,9 @@ def specify_shape(
574579
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
575580
if not new_shape_info and len(shape) == x.type.ndim:
576581
return x
577-
578-
return _specify_shape(x, *shape)
582+
if broadcastable is None:
583+
broadcastable = x.broadcastable
584+
return SpecifyShape(broadcastable)(x, *shape)
579585

580586

581587
@_get_vector_length.register(SpecifyShape)
@@ -934,7 +940,11 @@ def specify_broadcastable(x, *axes):
934940
raise ValueError("Trying to specify broadcastable of non-existent dimension")
935941

936942
shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)]
937-
return specify_shape(x, shape_info)
943+
broadcastable = [
944+
True if i in axes else b for i, b in enumerate(x.type.broadcastable)
945+
]
946+
947+
return specify_shape(x, shape_info, broadcastable=broadcastable)
938948

939949

940950
class Unbroadcast(COp):

pytensor/tensor/var.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,7 @@ def __init__(self, type: _TensorTypeType, data, name=None):
10181018
)
10191019

10201020
# We want all the shape information from `data`
1021-
new_type = type.clone(shape=data_shape)
1021+
new_type = type.clone(shape=data_shape, broadcastable=type.broadcastable)
10221022

10231023
assert not any(s is None for s in new_type.shape)
10241024

tests/sparse/test_basic.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,21 +1093,23 @@ def check_format_ndim(format, ndim):
10931093
s = SparseFromDense(format)(x)
10941094
s_m = -s
10951095
d = dense_from_sparse(s_m)
1096+
pytensor.grad(None, x, known_grads={d: d.type()})
10961097
c = d.sum()
10971098
g = pytensor.grad(c, x)
10981099
f = pytensor.function([x], [s, g])
10991100
f(np.array(0, dtype=config.floatX, ndmin=ndim))
11001101
f(np.array(7, dtype=config.floatX, ndmin=ndim))
11011102

1102-
def test_format_ndim(self):
1103-
for format in "csc", "csr":
1104-
for ndim in 0, 1, 2:
1105-
self.check_format_ndim(format, ndim)
1103+
@pytest.mark.parametrize("format", ["csc", "csr"])
1104+
@pytest.mark.parametrize("ndim", [0, 1, 2])
1105+
def test_format_ndim(self, format, ndim):
1106+
self.check_format_ndim(format, ndim)
11061107

1107-
with pytest.raises(TypeError):
1108-
self.check_format_ndim(format, 3)
1109-
with pytest.raises(TypeError):
1110-
self.check_format_ndim(format, 4)
1108+
@pytest.mark.parametrize("format", ["csc", "csr"])
1109+
@pytest.mark.parametrize("ndim", [3, 4])
1110+
def test_format_ndim_raises(self, format, ndim):
1111+
with pytest.raises(TypeError):
1112+
self.check_format_ndim(format, ndim)
11111113

11121114

11131115
class TestCsmProperties:

tests/tensor/test_shape.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Shape_i,
2020
SpecifyShape,
2121
Unbroadcast,
22-
_specify_shape,
2322
reshape,
2423
shape,
2524
shape_i,
@@ -350,13 +349,13 @@ def test_check_inputs(self):
350349
specify_shape([[1, 2, 3], [4, 5, 6]], (2.2, 3))
351350

352351
with pytest.raises(TypeError, match="must be integer types"):
353-
_specify_shape([[1, 2, 3], [4, 5, 6]], *(2.2, 3))
352+
SpecifyShape([False, False])([[1, 2, 3], [4, 5, 6]], *(2.2, 3))
354353

355354
with pytest.raises(ValueError, match="will never match"):
356355
specify_shape(matrix(), [4])
357356

358357
with pytest.raises(ValueError, match="will never match"):
359-
_specify_shape(matrix(), *[4])
358+
SpecifyShape([False, False])(matrix(), *[4])
360359

361360
with pytest.raises(ValueError, match="must have fixed dimensions"):
362361
specify_shape(matrix(), vector(dtype="int32"))

0 commit comments

Comments
 (0)