File tree Expand file tree Collapse file tree 3 files changed +12
-6
lines changed Expand file tree Collapse file tree 3 files changed +12
-6
lines changed Original file line number Diff line number Diff line change @@ -1406,11 +1406,8 @@ def infer_static_shape(
1406
1406
`shape` will be validated and constant folded. As a result, this function
1407
1407
can be expensive and shouldn't be used unless absolutely necessary.
1408
1408
1409
- It mostly exists as a hold-over from pre-static shape times, when it was
1410
- required in order to produce correct broadcastable arrays and prevent
1411
- some graphs from being unusable. Now, it is no longer strictly required,
1412
- so don't use it unless you want the same shape graphs to be rewritten
1413
- multiple times during graph construction.
1409
+ It is often needed for `Op`s whose static shape and broadcastable flags
1410
+ depend on the values of their inputs, such as `Alloc` and `RandomVariable`.
1414
1411
1415
1412
Returns
1416
1413
-------
Original file line number Diff line number Diff line change @@ -992,12 +992,17 @@ def local_merge_consecutive_specify_shape(fgraph, node):
992
992
return [specify_shape (inner_obj , shape )]
993
993
994
994
995
+ _empty_shape = constant ([], dtype = "int64" )
996
+
997
+
995
998
@register_infer_shape
996
999
@node_rewriter ([Shape ])
997
1000
def local_shape_ground (fgraph , node ):
998
1001
"""Rewrite shape(x) -> make_vector(x.type.shape) when this is constant."""
999
1002
[x ] = node .inputs
1000
1003
static_shape = x .type .shape
1004
+ if len (static_shape ) == 0 :
1005
+ return [_empty_shape ]
1001
1006
if not any (dim is None for dim in static_shape ):
1002
1007
return [stack ([constant (dim , dtype = "int64" ) for dim in static_shape ])]
1003
1008
Original file line number Diff line number Diff line change @@ -908,7 +908,7 @@ def test_runtime_broadcast(self, mode):
908
908
self .check_runtime_broadcast (mode )
909
909
910
910
911
- def test_infer_shape ():
911
+ def test_infer_static_shape ():
912
912
with pytest .raises (TypeError , match = "^Shapes must be scalar integers.*" ):
913
913
infer_static_shape ([constant (1.0 )])
914
914
@@ -925,6 +925,10 @@ def test_infer_shape():
925
925
sh , static_shape = infer_static_shape (specify_size )
926
926
assert static_shape == (1 ,)
927
927
928
+ x = scalar ("x" )
929
+ sh , static_shape = infer_static_shape ([x .size ])
930
+ assert static_shape == (1 ,)
931
+
928
932
929
933
# This is slow for the ('int8', 3) version.
930
934
def test_eye ():
You can’t perform that action at this time.
0 commit comments