@@ -200,7 +200,7 @@ def copy_var_format(var, as_var):
             rval = as_var.type.filter_variable(rval)
         else:
             tmp = as_var.type.clone(
-                shape=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable))
+                shape=(tuple(var.type.shape[:1]) + tuple(as_var.type.shape))
             )
             rval = tmp.filter_variable(rval)
     return rval
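Note on the hunk above (not part of the original commit message): the change migrates from the boolean `broadcastable` pattern to the static `type.shape` tuple, in which a broadcastable dimension is encoded as `1` and a dimension of unknown length as `None`. A minimal sketch of how the two encodings correspond, assuming Aesara/PyTensor-style `TensorType` conventions:

```python
# Sketch only: relating the legacy boolean `broadcastable` pattern to the
# static `type.shape` encoding used after this change. A broadcastable
# dimension becomes a static length of 1; anything else becomes None
# (length unknown until runtime).
broadcastable = (True, False, False)

shape = tuple(1 if b else None for b in broadcastable)
assert shape == (1, None, None)

# The reverse mapping: a dimension is broadcastable exactly when its
# static length is known to be 1.
assert tuple(s == 1 for s in shape) == broadcastable
```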
@@ -805,7 +805,9 @@ def tensorConstructor(shape, dtype):
                 # output sequence
                 o = outputs[idx]
                 self.output_types.append(
-                    typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
+                    # TODO: What can we actually say about the shape of this
+                    # added dimension?
+                    typeConstructor((None,) + o.type.shape, o.type.dtype)
                 )

                 idx += len(info.mit_mot_out_slices[jdx])
@@ -816,7 +818,9 @@ def tensorConstructor(shape, dtype):

         for o in outputs[idx:end]:
             self.output_types.append(
-                typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
+                # TODO: What can we actually say about the shape of this
+                # added dimension?
+                typeConstructor((None,) + o.type.shape, o.type.dtype)
             )

         # shared outputs + possibly the ending condition
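In both hunks above, the scan output type gains a leading (time) dimension. Under the new encoding, prepending `None` states only that this added dimension has unknown static length, whereas the old `(False,)` merely marked it non-broadcastable; the TODO asks whether something stronger (e.g. the number of steps) could be said. A small illustration, where `o_shape` is a hypothetical inner output's static shape:

```python
# Hypothetical example: an inner-graph output with static shape (None, 5)
# gains a leading time dimension of unknown length, so the scan output's
# static shape becomes (None, None, 5).
o_shape = (None, 5)

out_shape = (None,) + o_shape
assert out_shape == (None, None, 5)
```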
@@ -2320,8 +2324,8 @@ def infer_shape(self, fgraph, node, input_shapes):
                 # equivalent (if False). Here, we only need the variable.
                 v_shp_i = validator.check(shp_i)
                 if v_shp_i is None:
-                    if hasattr(r, "broadcastable") and r.broadcastable[i]:
-                        shp.append(1)
+                    if r.type.shape[i] is not None:
+                        shp.append(r.type.shape[i])
                     else:
                         shp.append(Shape_i(i)(r))
                 else:
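The rewritten branch prefers a statically known length from `r.type.shape` and only falls back to building a symbolic `Shape_i` graph when the length is unknown at graph-construction time. A standalone sketch of that fallback (helper names are hypothetical, not part of the commit):

```python
from types import SimpleNamespace

# Sketch of the fallback above: return the statically known dimension when
# r.type.shape[i] is not None, otherwise defer to a symbolic lookup
# (Shape_i(i)(r) in the real code).
def static_or_symbolic_dim(r, i, symbolic_dim):
    static = r.type.shape[i]
    if static is not None:
        return static  # known at graph-construction time
    return symbolic_dim(r, i)

r = SimpleNamespace(type=SimpleNamespace(shape=(3, None)))
assert static_or_symbolic_dim(r, 0, lambda v, i: "Shape_i(1)(r)") == 3
assert static_or_symbolic_dim(r, 1, lambda v, i: "Shape_i(1)(r)") == "Shape_i(1)(r)"
```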