@@ -2559,17 +2559,17 @@ def roll(x, shift, axis=None):
2559
2559
)
2560
2560
2561
2561
2562
- def stack (tensors : Sequence [TensorVariable ], axis : int = 0 ):
2562
+ def stack (tensors : Sequence ["TensorLike" ], axis : int = 0 ):
2563
2563
"""Stack tensors in sequence on given axis (default is 0).
2564
2564
2565
- Take a sequence of tensors and stack them on given axis to make a single
2566
- tensor. The size in dimension `axis` of the result will be equal to the number
2567
- of tensors passed.
2565
+ Take a sequence of tensors or tensor-like constant and stack them on
2566
+ given axis to make a single tensor. The size in dimension `axis` of the
2567
+ result will be equal to the number of tensors passed.
2568
2568
2569
2569
Parameters
2570
2570
----------
2571
- tensors : Sequence[TensorVariable ]
2572
- A list of tensors to be stacked.
2571
+ tensors : Sequence[TensorLike ]
2572
+ A list of tensors or tensor-like constants to be stacked.
2573
2573
axis : int
2574
2574
The index of the new axis. Default value is 0.
2575
2575
@@ -2604,11 +2604,11 @@ def stack(tensors: Sequence[TensorVariable], axis: int = 0):
2604
2604
(2, 2, 2, 3, 2)
2605
2605
"""
2606
2606
if not isinstance (tensors , Sequence ):
2607
- raise TypeError ("First argument should be Sequence[TensorVariable] " )
2607
+ raise TypeError ("First argument should be a Sequence. " )
2608
2608
elif len (tensors ) == 0 :
2609
- raise ValueError ("No tensor arguments provided" )
2609
+ raise ValueError ("No tensor arguments provided. " )
2610
2610
2611
- # If all tensors are scalars of the same type , call make_vector.
2611
+ # If all tensors are scalars, call make_vector.
2612
2612
# It makes the graph simpler, by not adding DimShuffles and SpecifyShapes
2613
2613
2614
2614
# This should be an optimization!
@@ -2618,12 +2618,13 @@ def stack(tensors: Sequence[TensorVariable], axis: int = 0):
2618
2618
# optimization.
2619
2619
# See ticket #660
2620
2620
if all (
2621
- # In case there are explicit ints in tensors
2622
- isinstance (t , (np .number , float , int , builtins .complex ))
2621
+ # In case there are explicit scalars in tensors
2622
+ isinstance (t , Number )
2623
+ or (isinstance (t , np .ndarray ) and t .ndim == 0 )
2623
2624
or (isinstance (t , Variable ) and isinstance (t .type , TensorType ) and t .ndim == 0 )
2624
2625
for t in tensors
2625
2626
):
2626
- # in case there is direct int
2627
+ # In case there is direct scalar
2627
2628
tensors = list (map (as_tensor_variable , tensors ))
2628
2629
dtype = aes .upcast (* [i .dtype for i in tensors ])
2629
2630
return MakeVector (dtype )(* tensors )
0 commit comments