Skip to content

Commit 471657a

Browse files
Clean up some usage of the TensorType interface in Scan
1 parent 807c0c9 commit 471657a

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

aesara/scan/op.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,9 @@ def check_broadcast(v1, v2):
161161
which may wrongly be interpreted as broadcastable.
162162
163163
"""
164-
if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
164+
if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
165165
return
166+
166167
msg = (
167168
"The broadcast pattern of the output of scan (%s) is "
168169
"inconsistent with the one provided in `output_info` "
@@ -173,13 +174,13 @@ def check_broadcast(v1, v2):
173174
"them consistent, e.g. using aesara.tensor."
174175
"{unbroadcast, specify_broadcastable}."
175176
)
176-
size = min(len(v1.broadcastable), len(v2.broadcastable))
177+
size = min(v1.type.ndim, v2.type.ndim)
177178
for n, (b1, b2) in enumerate(
178-
zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
179+
zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
179180
):
180181
if b1 != b2:
181-
a1 = n + size - len(v1.broadcastable) + 1
182-
a2 = n + size - len(v2.broadcastable) + 1
182+
a1 = n + size - v1.type.ndim + 1
183+
a2 = n + size - v2.type.ndim + 1
183184
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))
184185

185186

@@ -628,6 +629,7 @@ def validate_inner_graph(self):
628629
type_input = self.inner_inputs[inner_iidx].type
629630
type_output = self.inner_outputs[inner_oidx].type
630631
if (
632+
# TODO: Use the `Type` interface for this
631633
type_input.dtype != type_output.dtype
632634
or type_input.broadcastable != type_output.broadcastable
633635
):

0 commit comments

Comments
 (0)