File tree Expand file tree Collapse file tree 1 file changed +19
-2
lines changed Expand file tree Collapse file tree 1 file changed +19
-2
lines changed Original file line number Diff line number Diff line change @@ -165,10 +165,27 @@ def matrix_shapes(draw, stack_shapes=shapes()):
165
165
allow_infinity = False ))
166
166
167
167
def mutually_broadcastable_shapes (
168
- num_shapes : int , ** kw
168
+ num_shapes : int ,
169
+ * ,
170
+ base_shape : Shape = (),
171
+ min_dims : int = 0 ,
172
+ max_dims : Optional [int ] = None ,
173
+ min_side : int = 0 ,
174
+ max_side : Optional [int ] = None ,
169
175
) -> SearchStrategy [Tuple [Shape , ...]]:
176
+ if max_dims is None :
177
+ max_dims = min (max (len (base_shape ), min_dims ) + 5 , 32 )
178
+ if max_side is None :
179
+ max_side = max (base_shape [- max_dims :] + (min_side ,)) + 5
170
180
return (
171
- xps .mutually_broadcastable_shapes (num_shapes , ** kw )
181
+ xps .mutually_broadcastable_shapes (
182
+ num_shapes ,
183
+ base_shape = base_shape ,
184
+ min_dims = min_dims ,
185
+ max_dims = max_dims ,
186
+ min_side = min_side ,
187
+ max_side = max_side ,
188
+ )
172
189
.map (lambda BS : BS .input_shapes )
173
190
.filter (lambda shapes : all (
174
191
prod (i for i in s if i > 0 ) < MAX_ARRAY_SIZE for s in shapes
You can’t perform that action at this time.
0 commit comments