Commit 7b621c4

Make tensor API similar to that of other variable constructors
* Name is now the only optional non-keyword argument for all constructors
Parent: c5bfbcd

26 files changed: +166 additions, -100 deletions

pytensor/sparse/basic.py

Lines changed: 10 additions & 3 deletions
@@ -3451,7 +3451,12 @@ def make_node(self, a, b):
         return Apply(
             self,
             [a, b],
-            [tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out,
+                    shape=(None, 1 if b.type.shape[1] == 1 else None),
+                )
+            ],
         )

     def perform(self, node, inputs, outputs):

@@ -3582,7 +3587,9 @@ class StructuredDotGradCSC(COp):

     def make_node(self, a_indices, a_indptr, b, g_ab):
         return Apply(
-            self, [a_indices, a_indptr, b, g_ab], [tensor(g_ab.dtype, shape=(None,))]
+            self,
+            [a_indices, a_indptr, b, g_ab],
+            [tensor(dtype=g_ab.dtype, shape=(None,))],
         )

     def perform(self, node, inputs, outputs):

@@ -3716,7 +3723,7 @@ class StructuredDotGradCSR(COp):

     def make_node(self, a_indices, a_indptr, b, g_ab):
         return Apply(
-            self, [a_indices, a_indptr, b, g_ab], [tensor(b.dtype, shape=(None,))]
+            self, [a_indices, a_indptr, b, g_ab], [tensor(dtype=b.dtype, shape=(None,))]
         )

     def perform(self, node, inputs, outputs):

pytensor/sparse/rewriting.py

Lines changed: 28 additions & 7 deletions
@@ -270,7 +270,11 @@ def make_node(self, a_val, a_ind, a_ptr, a_nrows, b):
         r = Apply(
             self,
             [a_val, a_ind, a_ptr, a_nrows, b],
-            [tensor(dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None)
+                )
+            ],
         )
         return r

@@ -465,7 +469,12 @@ def make_node(self, a_val, a_ind, a_ptr, b):
         r = Apply(
             self,
             [a_val, a_ind, a_ptr, b],
-            [tensor(self.dtype_out, shape=(None, 1 if b.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=self.dtype_out,
+                    shape=(None, 1 if b.type.shape[1] == 1 else None),
+                )
+            ],
         )
         return r

@@ -705,7 +714,11 @@ def make_node(self, alpha, x_val, x_ind, x_ptr, x_nrows, y, z):
         r = Apply(
             self,
             [alpha, x_val, x_ind, x_ptr, x_nrows, y, z],
-            [tensor(dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None))],
+            [
+                tensor(
+                    dtype=dtype_out, shape=(None, 1 if y.type.shape[1] == 1 else None)
+                )
+            ],
         )
         return r

@@ -1142,7 +1155,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         """
         assert b.type.ndim == 2
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )

     def c_code_cache_version(self):

@@ -1280,7 +1295,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         """
         assert b.type.ndim == 2
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )

     def c_code_cache_version(self):

@@ -1470,7 +1487,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         """
         assert b.type.ndim == 1
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )

     def c_code_cache_version(self):

@@ -1642,7 +1661,9 @@ def make_node(self, a_data, a_indices, a_indptr, b):
         assert a_indptr.type.ndim == 1
         assert b.type.ndim == 1
         return Apply(
-            self, [a_data, a_indices, a_indptr, b], [tensor(b.dtype, shape=(None,))]
+            self,
+            [a_data, a_indices, a_indptr, b],
+            [tensor(dtype=b.dtype, shape=(None,))],
         )

     def c_code_cache_version(self):

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
@@ -2882,7 +2882,7 @@ def make_node(self, start, stop, step):
         assert step.ndim == 0

         inputs = [start, stop, step]
-        outputs = [tensor(self.dtype, shape=(None,))]
+        outputs = [tensor(dtype=self.dtype, shape=(None,))]

         return Apply(self, inputs, outputs)

pytensor/tensor/blas.py

Lines changed: 3 additions & 3 deletions
@@ -1680,7 +1680,7 @@ def make_node(self, x, y):
             raise TypeError(y)
         if y.type.dtype != x.type.dtype:
             raise TypeError("dtype mismatch to Dot22")
-        outputs = [tensor(x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
+        outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))]
         return Apply(self, [x, y], outputs)

     def perform(self, node, inp, out):

@@ -1985,7 +1985,7 @@ def make_node(self, x, y, a):
             raise TypeError("Dot22Scalar requires float or complex args", a.dtype)

         sz = (x.type.shape[0], y.type.shape[1])
-        outputs = [tensor(x.type.dtype, shape=sz)]
+        outputs = [tensor(dtype=x.type.dtype, shape=sz)]
         return Apply(self, [x, y, a], outputs)

     def perform(self, node, inp, out):

@@ -2221,7 +2221,7 @@ def make_node(self, *inputs):
             + inputs[1].type.shape[2:]
         )
         out_shape = tuple(1 if s == 1 else None for s in out_shape)
-        return Apply(self, upcasted_inputs, [tensor(dtype, shape=out_shape)])
+        return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)])

     def perform(self, node, inp, out):
         x, y = inp

pytensor/tensor/io.py

Lines changed: 3 additions & 3 deletions
@@ -36,7 +36,7 @@ def __init__(self, dtype, shape, mmap_mode=None):
     def make_node(self, path):
         if isinstance(path, str):
             path = Constant(Generic(), path)
-        return Apply(self, [path], [tensor(self.dtype, shape=self.shape)])
+        return Apply(self, [path], [tensor(dtype=self.dtype, shape=self.shape)])

     def perform(self, node, inp, out):
         path = inp[0]

@@ -135,7 +135,7 @@ def make_node(self):
             [],
             [
                 Variable(Generic(), None),
-                tensor(self.dtype, shape=self.static_shape),
+                tensor(dtype=self.dtype, shape=self.static_shape),
             ],
         )

@@ -180,7 +180,7 @@ def make_node(self, request, data):
         return Apply(
             self,
             [request, data],
-            [tensor(data.dtype, shape=data.type.shape)],
+            [tensor(dtype=data.dtype, shape=data.type.shape)],
         )

     def perform(self, node, inp, out):

pytensor/tensor/math.py

Lines changed: 4 additions & 4 deletions
@@ -152,8 +152,8 @@ def make_node(self, x):
             if i not in all_axes
         )
         outputs = [
-            tensor(x.type.dtype, shape=out_shape, name="max"),
-            tensor("int64", shape=out_shape, name="argmax"),
+            tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
+            tensor(dtype="int64", shape=out_shape, name="argmax"),
         ]
         return Apply(self, inputs, outputs)

@@ -370,7 +370,7 @@ def make_node(self, x, axis=None):
         # We keep the original broadcastable flags for dimensions on which
         # we do not perform the argmax.
         out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
-        outputs = [tensor("int64", shape=out_shape, name="argmax")]
+        outputs = [tensor(dtype="int64", shape=out_shape, name="argmax")]
         return Apply(self, inputs, outputs)

     def prepare_node(self, node, storage_map, compute_map, impl):

@@ -1922,7 +1922,7 @@ def make_node(self, *inputs):
         sz = sx[:-1]

         i_dtypes = [input.type.dtype for input in inputs]
-        outputs = [tensor(aes.upcast(*i_dtypes), shape=sz)]
+        outputs = [tensor(dtype=aes.upcast(*i_dtypes), shape=sz)]
         return Apply(self, inputs, outputs)

     def perform(self, node, inp, out):

pytensor/tensor/shape.py

Lines changed: 1 addition & 1 deletion
@@ -641,7 +641,7 @@ def make_node(self, x, shp):
         except NotScalarConstantError:
             pass

-        return Apply(self, [x, shp], [tensor(x.type.dtype, shape=out_shape)])
+        return Apply(self, [x, shp], [tensor(dtype=x.type.dtype, shape=out_shape)])

     def perform(self, node, inp, out_, params):
         x, shp = inp

pytensor/tensor/type.py

Lines changed: 27 additions & 6 deletions
@@ -775,18 +775,32 @@ def values_eq_approx_always_true(a, b):
     version=2,
 )

+# Valid static type entries
+ST = Union[int, None]
+

 def tensor(
+    name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
-    *args,
+    shape: Optional[Tuple[ST, ...]] = None,
     **kwargs,
 ) -> "TensorVariable":

+    if name is not None:
+        # Help catching errors with the new tensor API
+        # Many single letter strings are valid sctypes
+        if str(name) == "floatX" or (len(str(name)) > 1 and np.obj2sctype(name)):
+            np.obj2sctype(name)
+            raise ValueError(
+                f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
+                "This name looks like a dtype, which you should pass as a keyword argument only."
+            )
+
     if dtype is None:
         dtype = config.floatX

-    name = kwargs.pop("name", None)
-    return TensorType(dtype, *args, **kwargs)(name=name)
+    return TensorType(dtype=dtype, shape=shape, **kwargs)(name=name)


 cscalar = TensorType("complex64", ())
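Under the new signature an old-style positional dtype now binds to `name`, so the guard in this hunk turns the most likely misuse into an immediate error rather than a silently misnamed variable. A sketch of the behavior the hunk implies (the name `x` is illustrative):

    import pytensor.tensor as at

    # A positional dtype string is caught, because "float64" is
    # recognized as a dtype rather than a name:
    # at.tensor("float64", shape=(None,))  # raises ValueError

    # Single-letter names still pass, even though many single letters
    # are also valid NumPy sctype codes (hence the len > 1 check):
    x = at.tensor("x", shape=(None,))  # dtype falls back to config.floatX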
@@ -805,6 +819,7 @@ def tensor(

 def scalar(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
 ) -> "TensorVariable":
     """Return a symbolic scalar variable.

@@ -844,9 +859,6 @@ def scalar(
 lvector = TensorType("int64", shape=(None,))


-ST = Union[int, None]
-
-
 def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:

     if not isinstance(shape, tuple):

@@ -863,6 +875,7 @@ def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:

 def vector(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST]] = (None,),
 ) -> "TensorVariable":

@@ -908,6 +921,7 @@

 def matrix(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST]] = (None, None),
 ) -> "TensorVariable":

@@ -951,6 +965,7 @@

 def row(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[Literal[1], ST]] = (1, None),
 ) -> "TensorVariable":

@@ -994,6 +1009,7 @@

 def col(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, Literal[1]]] = (None, 1),
 ) -> "TensorVariable":

@@ -1035,6 +1051,7 @@

 def tensor3(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST]] = (None, None, None),
 ) -> "TensorVariable":

@@ -1074,6 +1091,7 @@

 def tensor4(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST]] = (None, None, None, None),
 ) -> "TensorVariable":

@@ -1113,6 +1131,7 @@

 def tensor5(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST, ST]] = (None, None, None, None, None),
 ) -> "TensorVariable":

@@ -1152,6 +1171,7 @@

 def tensor6(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST, ST, ST]] = (
         None,

@@ -1198,6 +1218,7 @@

 def tensor7(
     name: Optional[str] = None,
+    *,
     dtype: Optional["DTypeLike"] = None,
     shape: Optional[Tuple[ST, ST, ST, ST, ST, ST, ST]] = (
         None,
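Across `scalar`, `vector`, `matrix`, `row`, `col`, and `tensor3` through `tensor7`, the inserted `*` makes `dtype` (and `shape`, where present) keyword-only, so `name` is uniformly the only positional argument. A short usage sketch under the new signatures (the names `v` and `m` are illustrative):

    import pytensor.tensor as at

    v = at.vector("v", dtype="int64", shape=(10,))
    m = at.matrix("m")           # dtype defaults to config.floatX
    # at.matrix("m", "float32")  # now a TypeError: dtype is keyword-only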

tests/graph/test_basic.py

Lines changed: 3 additions & 3 deletions
@@ -556,13 +556,13 @@ def test_get_var_by_name():
 def test_clone_new_inputs():
     """Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""

-    x = at.tensor(np.float64, shape=(None,))
-    y = at.tensor(np.float64, shape=(1,))
+    x = at.tensor(dtype=np.float64, shape=(None,))
+    y = at.tensor(dtype=np.float64, shape=(1,))

     z = at.add(x, y)
     assert z.type.shape == (None,)

-    x_new = at.tensor(np.float64, shape=(1,))
+    x_new = at.tensor(dtype=np.float64, shape=(1,))

     # The output nodes should be reconstructed, because the input types' static
     # shape information increased in specificity

tests/link/numba/test_elemwise.py

Lines changed: 4 additions & 4 deletions
@@ -146,7 +146,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
         # `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1, None), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1, None), name="a"),
                 np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX),
             ),
             ("x", 2, "x", 0, "x"),

@@ -155,21 +155,21 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
         # `{'drop': [1], 'shuffle': [0], 'augment': []}`
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
                 np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
             ),
             (0,),
         ),
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(None, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(None, 1), name="a"),
                 np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
             ),
             (0,),
         ),
         (
             set_test_value(
-                at.tensor(config.floatX, shape=(1, 1, 1), name="a"),
+                at.tensor(dtype=config.floatX, shape=(1, 1, 1), name="a"),
                 np.array([[[1.0]]], dtype=config.floatX),
             ),
             (),
