Skip to content

Commit 807c0c9

Browse files
Use static shape information instead of broadcastable in Scan
1 parent e59cef1 commit 807c0c9

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

aesara/scan/op.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def copy_var_format(var, as_var):
200200
rval = as_var.type.filter_variable(rval)
201201
else:
202202
tmp = as_var.type.clone(
203-
shape=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable))
203+
shape=(tuple(var.type.shape[:1]) + tuple(as_var.type.shape))
204204
)
205205
rval = tmp.filter_variable(rval)
206206
return rval
@@ -805,7 +805,9 @@ def tensorConstructor(shape, dtype):
805805
# output sequence
806806
o = outputs[idx]
807807
self.output_types.append(
808-
typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
808+
# TODO: What can we actually say about the shape of this
809+
# added dimension?
810+
typeConstructor((None,) + o.type.shape, o.type.dtype)
809811
)
810812

811813
idx += len(info.mit_mot_out_slices[jdx])
@@ -816,7 +818,9 @@ def tensorConstructor(shape, dtype):
816818

817819
for o in outputs[idx:end]:
818820
self.output_types.append(
819-
typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
821+
# TODO: What can we actually say about the shape of this
822+
# added dimension?
823+
typeConstructor((None,) + o.type.shape, o.type.dtype)
820824
)
821825

822826
# shared outputs + possibly the ending condition
@@ -2320,8 +2324,8 @@ def infer_shape(self, fgraph, node, input_shapes):
23202324
# equivalent (if False). Here, we only need the variable.
23212325
v_shp_i = validator.check(shp_i)
23222326
if v_shp_i is None:
2323-
if hasattr(r, "broadcastable") and r.broadcastable[i]:
2324-
shp.append(1)
2327+
if r.type.shape[i] is not None:
2328+
shp.append(r.type.shape[i])
23252329
else:
23262330
shp.append(Shape_i(i)(r))
23272331
else:

0 commit comments

Comments
 (0)