Commit 4452581

Make tensor API similar to that of other variable constructors
* Name is now the only optional non-keyword argument for all constructors
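
A minimal before/after sketch of the change (illustrative calls only, assuming `pytensor.tensor` is imported as `at`, as it is in the test files below):

    import pytensor.tensor as at

    # Before: dtype was the first positional argument
    #     x = at.tensor("float64", shape=(None, 3))
    # After: `name` is the only positional argument; dtype and shape are
    # keyword-only, and dtype still defaults to config.floatX
    x = at.tensor("x", dtype="float64", shape=(None, 3))
    v = at.vector("v", dtype="int64")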
Parent: c5bfbcd

24 files changed: +155 −93 lines

pytensor/sparse/basic.py

Lines changed: 10 additions & 3 deletions

@@ -3451,7 +3451,12 @@ def make_node(self, a, b):
         return Apply(
             self,
             [a, b],
-            [tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out,
+                    shape=(None, 1 if b.type.shape[1] == 1 else None),
+                )
+            ],
         )

     def perform(self, node, inputs, outputs):
@@ -3582,7 +3587,9 @@ class StructuredDotGradCSC(COp):

     def make_node(self, a_indices, a_indptr, b, g_ab):
         return Apply(
-            self, [a_indices, a_indptr, b, g_ab], [tensor(g_ab.dtype, shape=(None,))]
+            self,
+            [a_indices, a_indptr, b, g_ab],
+            [tensor(dtype=g_ab.dtype, shape=(None,))],
         )

     def perform(self, node, inputs, outputs):
@@ -3716,7 +3723,7 @@ class StructuredDotGradCSR(COp):

     def make_node(self, a_indices, a_indptr, b, g_ab):
         return Apply(
-            self, [a_indices, a_indptr, b, g_ab], [tensor(b.dtype, shape=(None,))]
+            self, [a_indices, a_indptr, b, g_ab], [tensor(dtype=b.dtype, shape=(None,))]
         )

     def perform(self, node, inputs, outputs):

pytensor/sparse/rewriting.py

Lines changed: 28 additions & 7 deletions

@@ -270,7 +270,11 @@ def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
         r = Apply(
             self,
             [a_val, a_ind, a_ptr, a_nrows, b],
-            [tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None)
+                )
+            ],
         )
         return r
@@ -465,7 +469,12 @@ def make_node(self, a_val, a_ind, a_ptr, b):
         r = Apply(
             self,
             [a_val, a_ind, a_ptr, b],
-            [tensor(self.dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=self.dtype_out,
+                    shape=(None, 1 if b.type.shape[1] == 1 else None),
+                )
+            ],
         )
         return r
@@ -705,7 +714,11 @@ def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
         r = Apply(
             self,
             [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
-            [tensor(dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None)
+                )
+            ],
         )
         return r
@@ -1142,7 +1155,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         """
         assert b.type.ndim == 2
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )

     def c_code_cache_version(self):
@@ -1280,7 +1295,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         """
         assert b.type.ndim == 2
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )

     def c_code_cache_version(self):
@@ -1470,7 +1487,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         """
         assert b.type.ndim == 1
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )

     def c_code_cache_version(self):
@@ -1642,7 +1661,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         assert a_indptr.type.ndim == 1
         assert b.type.ndim == 1
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
        )

     def c_code_cache_version(self):

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion

@@ -2882,7 +2882,7 @@ def make_node(self, start, stop, step):
         assert step.ndim == 0

         inputs = [start, stop, step]
-        outputs = [tensor(self.dtype, shape=(None,))]
+        outputs = [tensor(dtype=self.dtype, shape=(None,))]

         return Apply(self, inputs, outputs)

pytensor/tensor/blas.py

Lines changed: 3 additions & 3 deletions

@@ -1680,7 +1680,7 @@ def make_node(self, x, y):
             raise TypeError(y)
         if y.type.dtype != x.type.dtype:
             raise TypeError("dtype mismatch to Dot22")
-        outputs = [tensor(x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
+        outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
         return Apply(self, [x, y], outputs)

     def perform(self, node, inp, out):
@@ -1985,7 +1985,7 @@ def make_node(self, x, y, a):
             raise TypeError("Dot22Scalar requires float or complex args", a.dtype)

         sz = (x.type.shape[0], y.type.shape[1])
-        outputs = [tensor(x.type.dtype, shape=sz)]
+        outputs = [tensor(dtype=x.type.dtype, shape=sz)]
         return Apply(self, [x, y, a], outputs)

     def perform(self, node, inp, out):
@@ -2221,7 +2221,7 @@ def make_node(self, *inputs):
             + inputs[1].type.shape[2:]
         )
         out_shape = tuple(1 if s == 1 else None for s in out_shape)
-        return Apply(self, upcasted_inputs, [tensor(dtype, shape=out_shape)])
+        return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])

     def perform(self, node, inp, out):
         x, y = inp

pytensor/tensor/io.py

Lines changed: 3 additions & 3 deletions

@@ -36,7 +36,7 @@ def __init__(self, dtype, shape, mmap_mode=None):
     def make_node(self, path):
         if isinstance(path, str):
             path = Constant(Generic(), path)
-        return Apply(self, [path], [tensor(self.dtype, shape=self.shape)])
+        return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)])

     def perform(self, node, inp, out):
         path = inp[0]
@@ -135,7 +135,7 @@ def make_node(self):
             [],
             [
                 Variable(Generic(), None),
-                tensor(self.dtype, shape=self.static_shape),
+                tensor(dtype=self.dtype, shape=self.static_shape),
             ],
         )

@@ -180,7 +180,7 @@ def make_node(self, request, data):
         return Apply(
             self,
             [request, data],
-            [tensor(data.dtype, shape=data.type.shape)],
+            [tensor(dtype=data.dtype, shape=data.type.shape)],
         )

     def perform(self, node, inp, out):

pytensor/tensor/math.py

Lines changed: 4 additions & 4 deletions

@@ -152,8 +152,8 @@ def make_node(self, x):
             if i not in all_axes
         )
         outputs = [
-            tensor(x.type.dtype, shape=out_shape, name="max"),
-            tensor("int64", shape=out_shape, name="argmax"),
+            tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
+            tensor(dtype="int64", shape=out_shape, name="argmax"),
         ]
         return Apply(self, inputs, outputs)

@@ -370,7 +370,7 @@ def make_node(self, x, axis=None):
         # We keep the original broadcastable flags for dimensions on which
         # we do not perform the argmax.
         out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
-        outputs = [tensor("int64", shape=out_shape, name="argmax")]
+        outputs = [tensor(dtype="int64", shape=out_shape, name="argmax")]
         return Apply(self, inputs, outputs)

     def prepare_node(self, node, storage_map, compute_map, impl):
@@ -1922,7 +1922,7 @@ def make_node(self, *inputs):
         sz = sx[:-1]

         i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(aes.upcast(*i_dtypes), shape=sz)]
+        outputs = [tensor(dtype=aes.upcast(*i_dtypes), shape=sz)]
         return Apply(self, inputs, outputs)

     def perform(self, node, inp, out):

pytensor/tensor/type.py

Lines changed: 31 additions & 6 deletions

@@ -92,6 +92,12 @@ def __init__(

         """

+        if shape is None and broadcastable is None:
+            raise ValueError("Must pass shape or broadcastable")
+        elif shape is not None and broadcastable is not None:
+            # TODO: Lift this constraint when we reintroduce broadcastable flags
+            raise ValueError("Cannot pass both shape and broadcastable")
+
         if broadcastable is not None:
             warnings.warn(
                 "The `broadcastable` keyword is deprecated; use `shape`.",
@@ -775,18 +781,30 @@ def values_eq_approx_always_true(a, b):
     version=2,
 )

+# Valid static type entries
+ST = Union[int, None]
+

 def tensor(
+    name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
-    *args,
+    shape: Optional[Tuple[ST, ...]] = None,
     **kwargs,
 ) -> "TensorVariable":

+    if name is not None:
+        # Help catch errors with the new tensor API
+        if str(name) == "floatX" or np.obj2sctype(name):
+            raise ValueError(
+                f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
+                "This name looks like a dtype, which you should pass as a keyword argument only."
+            )
+
     if dtype is None:
         dtype = config.floatX

-    name = kwargs.pop("name", None)
-    return TensorType(dtype, *args, **kwargs)(name=name)
+    return TensorType(dtype=dtype, shape=shape, **kwargs)(name=name)


 cscalar = TensorType("complex64", ())
@@ -805,6 +823,7 @@ def tensor(

 def scalar(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
 ) -> "TensorVariable":
     """Return a symbolic scalar variable.
@@ -844,9 +863,6 @@ def scalar(
 lvector = TensorType("int64", shape=(None,))


-ST = Union[int, None]
-
-
 def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:

     if not isinstance(shape, tuple):
@@ -863,6 +879,7 @@ def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:

 def vector(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST]] = (None,),
 ) -> "TensorVariable":
@@ -908,6 +925,7 @@

 def matrix(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST]] = (None, None),
 ) -> "TensorVariable":
@@ -951,6 +969,7 @@

 def row(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[Literal[1], ST]] = (1, None),
 ) -> "TensorVariable":
@@ -994,6 +1013,7 @@

 def col(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, Literal[1]]] = (None, 1),
 ) -> "TensorVariable":
@@ -1035,6 +1055,7 @@

 def tensor3(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST]] = (None, None, None),
 ) -> "TensorVariable":
@@ -1074,6 +1095,7 @@

 def tensor4(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST]] = (None, None, None, None),
 ) -> "TensorVariable":
@@ -1113,6 +1135,7 @@

 def tensor5(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST, ST]] = (None, None, None, None, None),
 ) -> "TensorVariable":
@@ -1152,6 +1175,7 @@

 def tensor6(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST, ST, ST]] = (
         None,
@@ -1198,6 +1222,7 @@

 def tensor7(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST, ST, ST, ST]] = (
         None,
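
Taken together, the reworked constructors behave as sketched below (illustrative calls, not part of the diff; assumes `pytensor.tensor` is importable as `at`):

    import pytensor.tensor as at
    from pytensor.tensor.type import TensorType

    # `name` is the sole positional argument; dtype/shape are keyword-only.
    x = at.tensor("x", dtype="float64", shape=(None, None))

    # An old-style positional dtype now fails loudly instead of silently
    # creating a variable named "float64".
    at.tensor("float64", shape=(None, None))  # raises ValueError

    # TensorType.__init__ now requires exactly one of shape/broadcastable.
    TensorType("float64", shape=(None, 1))  # ok
    TensorType("float64")  # raises ValueError (must pass shape or broadcastable)
    TensorType("float64", shape=(None,), broadcastable=(False,))  # raises ValueError (cannot pass both)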

tests/graph/test_basic.py

Lines changed: 3 additions & 3 deletions

@@ -556,13 +556,13 @@ def test_get_var_by_name():
 def test_clone_new_inputs():
     """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""

-    x = at.tensor(np.float64, shape=(None,))
-    y = at.tensor(np.float64, shape=(1,))
+    x = at.tensor(dtype=np.float64, shape=(None,))
+    y = at.tensor(dtype=np.float64, shape=(1,))

     z = at.add(x, y)
     assert z.type.shape == (None,)

-    x_new = at.tensor(np.float64, shape=(1,))
+    x_new = at.tensor(dtype=np.float64, shape=(1,))

     # The output nodes should be reconstructed, because the input types' static
     # shape information increased in specificity

tests/link/numba/test_elemwise.py

Lines changed: 4 additions & 4 deletions

@@ -146,7 +146,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
         # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1, None), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"),
                 np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX),
             ),
             ("x", 2, "x", 0, "x"),
@@ -155,21 +155,21 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
         # `{'drop': [1], 'shuffle': [0], 'augment': []}`
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
                 np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
             ),
             (0,),
         ),
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
                 np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
             ),
             (0,),
         ),
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(1, 1, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"),
                 np.array([[[1.0]]], dtype=config.floatX),
             ),
             (),
