@@ -161,8 +161,9 @@ def check_broadcast(v1, v2):
161
161
which may wrongly be interpreted as broadcastable.
162
162
163
163
"""
164
- if not hasattr (v1 , "broadcastable" ) and not hasattr (v2 , "broadcastable" ):
164
+ if not isinstance (v1 . type , TensorType ) and not isinstance (v2 . type , TensorType ):
165
165
return
166
+
166
167
msg = (
167
168
"The broadcast pattern of the output of scan (%s) is "
168
169
"inconsistent with the one provided in `output_info` "
@@ -173,13 +174,13 @@ def check_broadcast(v1, v2):
173
174
"them consistent, e.g. using aesara.tensor."
174
175
"{unbroadcast, specify_broadcastable}."
175
176
)
176
- size = min (len ( v1 .broadcastable ), len ( v2 .broadcastable ) )
177
+ size = min (v1 .type . ndim , v2 .type . ndim )
177
178
for n , (b1 , b2 ) in enumerate (
178
- zip (v1 .broadcastable [- size :], v2 .broadcastable [- size :])
179
+ zip (v1 .type . broadcastable [- size :], v2 . type .broadcastable [- size :])
179
180
):
180
181
if b1 != b2 :
181
- a1 = n + size - len ( v1 .broadcastable ) + 1
182
- a2 = n + size - len ( v2 .broadcastable ) + 1
182
+ a1 = n + size - v1 .type . ndim + 1
183
+ a2 = n + size - v2 .type . ndim + 1
183
184
raise TypeError (msg % (v1 .type , v2 .type , a1 , b1 , b2 , a2 ))
184
185
185
186
@@ -628,6 +629,7 @@ def validate_inner_graph(self):
628
629
type_input = self .inner_inputs [inner_iidx ].type
629
630
type_output = self .inner_outputs [inner_oidx ].type
630
631
if (
632
+ # TODO: Use the `Type` interface for this
631
633
type_input .dtype != type_output .dtype
632
634
or type_input .broadcastable != type_output .broadcastable
633
635
):
0 commit comments