@@ -646,7 +646,12 @@ class Repeat(Op):
     __props__ = ("axis",)

-    def __init__(self, axis=None):
+    def __init__(self, axis: int | None = None):
+        if axis is not None:
+            if not isinstance(axis, int) or axis < 0:
+                raise ValueError(
+                    f"Repeat only accepts positive integer axis or None, got {axis}"
+                )
+
         self.axis = axis

     def make_node(self, x, repeats):
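
A minimal sketch of what the new validation buys (assuming the Op's module path; illustrative only):

    from pytensor.tensor.extra_ops import Repeat

    Repeat(axis=1)     # OK: non-negative integer axis
    Repeat(axis=None)  # OK: repeat over the flattened input
    Repeat(axis=-1)    # ValueError: negative axes must be normalized first
                       # (the `repeat` helper below does this via normalize_axis_index)
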
@@ -687,58 +692,64 @@ def make_node(self, x, repeats):
         out_shape = list(x.type.shape)
         out_shape[self.axis] = None

-        out_type = TensorType(
-            x.dtype, shape=tuple(1 if s == 1 else None for s in out_shape)
-        )
-
+        out_type = TensorType(x.dtype, shape=out_shape)
         return Apply(self, [x, repeats], [out_type()])

     def perform(self, node, inputs, output_storage):
-        x = inputs[0]
-        repeats = inputs[1]
-        z = output_storage[0]
-        z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
+        [x, repeats] = inputs
+        output_storage[0][0] = np.repeat(x, repeats=repeats, axis=self.axis)

     def connection_pattern(self, node):
         return [[True], [False]]
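
The connection pattern declares the output connected to `x` but not to `repeats`. A rough illustration of the consequence (assumed behavior; PyTensor raises on disconnected inputs by default):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    repeats = pt.lvector("repeats")
    out = pt.repeat(x, repeats)  # vector repeats -> uses the Repeat Op

    g = pytensor.grad(out.sum(), wrt=x)  # fine, flows through Repeat.grad
    # pytensor.grad(out.sum(), wrt=repeats)  # DisconnectedInputError
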

     def grad(self, inputs, gout):
         (x, repeats) = inputs
         (gz,) = gout
+        axis = self.axis
         if repeats.ndim == 0:
-            if self.axis is None:
-                axis = x.ndim
-            else:
-                if self.axis >= 0:
-                    axis = self.axis + 1
-                else:
-                    axis = self.axis + x.ndim + 1
-
-            shape = [x.shape[k] for k in range(x.ndim)]
-            shape.insert(axis, repeats)
+            # When repeats is a scalar (same number of reps for all elements),
+            # we can split the repetitions into their own axis with reshape
+            # and sum them back to the original element location
+            sum_axis = x.ndim if axis is None else axis + 1
+            shape = list(x.shape)
+            shape.insert(sum_axis, repeats)
+            gx = gz.reshape(shape).sum(axis=sum_axis)

-            return [
-                gz.reshape(shape, ndim=x.ndim + 1).sum(axis=axis),
-                DisconnectedType()(),
-            ]
         elif repeats.ndim == 1:
-            # For this implementation, we would need to specify the length
-            # of repeats in order to split gz in the right way to sum
-            # the good part.
-            raise NotImplementedError()
+            # To sum the gradients that belong to the same repeated x,
+            # we create a repeated eye and dot product it with the gradient.
+            axis_size = x.size if axis is None else x.shape[axis]
+            tiled_eye = repeat(
+                ptb.eye(axis_size), repeats, axis=0
+            )  # A sparse repeat would be neat
+
+            if axis is None:
+                gx = gz @ tiled_eye
+                # Undo the ravelling when axis=None
+                gx = gx.reshape(x.shape)
+            else:
+                # Place gradient axis at end for dot product
+                gx = ptb.moveaxis(gz, axis, -1)
+                gx = gx @ tiled_eye
+                # Place gradient back into the correct axis
+                gx = ptb.moveaxis(gx, -1, axis)
+
         else:
             raise ValueError()

+        return [gx, DisconnectedType()()]
+
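
Both gradient tricks can be sanity-checked in plain NumPy (a minimal sketch with illustrative shapes, not the Op itself):

    import numpy as np

    rng = np.random.default_rng(0)

    # Scalar repeats=2 along axis=1 of a (2, 3) input: gz has shape (2, 6).
    # Exposing the repetition axis and summing over it recovers gx.
    gz = rng.normal(size=(2, 6))
    gx_scalar = gz.reshape(2, 3, 2).sum(axis=2)  # shape (2, 3)

    # Vector repeats [2, 3] along axis=0 of a (2, 3) input: gz has shape (5, 3).
    # A row-repeated identity sums the rows that came from the same input row.
    repeats = np.array([2, 3])
    gz = rng.normal(size=(5, 3))
    tiled_eye = np.repeat(np.eye(2), repeats, axis=0)  # shape (5, 2)
    gx_vector = (gz.T @ tiled_eye).T  # move the repeated axis to the end and back
    assert gx_vector.shape == (2, 3)
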

     def infer_shape(self, fgraph, node, ins_shapes):
         i0_shapes = ins_shapes[0]
         repeats = node.inputs[1]
         out_shape = list(i0_shapes)
+        axis = self.axis

         # uint64 shapes are not supported.
         dtype = None
         if repeats.dtype in ("uint8", "uint16", "uint32"):
             dtype = "int64"
-        if self.axis is None:
+        if axis is None:
             if repeats.ndim == 0:
                 if len(i0_shapes) == 0:
                     out_shape = [repeats]
@@ -751,82 +762,97 @@ def infer_shape(self, fgraph, node, ins_shapes):
                 out_shape = [pt_sum(repeats, dtype=dtype)]
         else:
             if repeats.ndim == 0:
-                out_shape[self.axis] = out_shape[self.axis] * repeats
+                out_shape[axis] = out_shape[axis] * repeats
             else:
-                out_shape[self.axis] = pt_sum(repeats, dtype=dtype)
+                out_shape[axis] = pt_sum(repeats, dtype=dtype)
         return [out_shape]
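
The inferred shapes can be checked against the rules above (illustrative usage of the public `pt.repeat` wrapper):

    import numpy as np
    import pytensor.tensor as pt

    x = pt.matrix("x")

    # Scalar repeats: the repeated axis is scaled by the repeat count.
    out = pt.repeat(x, 3, axis=1)
    print(out.shape.eval({x: np.zeros((2, 4))}))  # [ 2 12]

    # Vector repeats: the repeated axis becomes the sum of the repeats.
    out = pt.repeat(x, np.array([1, 2]), axis=0)
    print(out.shape.eval({x: np.zeros((2, 4))}))  # [3 4]
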


-def repeat(x, repeats, axis=None):
-    """Repeat elements of an array.
+def repeat(
+    a: TensorLike, repeats: TensorLike, axis: int | None = None
+) -> TensorVariable:
+    """Repeat elements of a tensor.

-    It returns an array which has the same shape as `x`, except along the given
-    `axis`. The `axis` parameter is used to specify the axis along which values
-    are repeated. By default, a flattened version of `x` is used.
+    See `numpy.repeat` for more information.

-    The number of repetitions for each element is `repeats`. `repeats` is
-    broadcasted to fit the length of the given `axis`.

     Parameters
     ----------
-    x
-        Input data, tensor variable.
-    repeats
-        int, scalar or tensor variable
+    a: tensor_like
+        Input tensor.
+    repeats: tensor_like
+        The number of repetitions for each element. `repeats` is broadcasted
+        to fit the shape of the given axis.
     axis : int, optional
+        The axis along which to repeat values. By default, use the flattened
+        input array, and return a flat output array.

-    See Also
+    Returns
+    -------
+    repeated_tensor: TensorVariable
+        Output tensor which has the same shape as `a`, except along the given axis.
+
+    Examples
     --------
-    tensor.tile
+
+    .. testcode::
+
+        import pytensor.tensor as pt
+
+        a = pt.arange(4).reshape((2, 2))
+        out = pt.repeat(a, repeats=[2, 3], axis=0)
+        print(out.eval())
+
+    .. testoutput::
+
+        [[0 1]
+         [0 1]
+         [2 3]
+         [2 3]
+         [2 3]]
+

     .. versionadded:: 0.6

     """
+    a = ptb.as_tensor_variable(a)
+
+    if axis is not None:
+        axis = normalize_axis_index(axis, a.ndim)
+
     repeats = ptb.as_tensor_variable(repeats, dtype=np.int64)

     if repeats.ndim > 1:
         raise ValueError("The dimension of repeats should not exceed 1.")

     if repeats.ndim == 1 and not repeats.broadcastable[0]:
-        return Repeat(axis=axis)(x, repeats)
+        # We only use the Repeat Op for vector repeats
+        return Repeat(axis=axis)(a, repeats)
     else:
         if repeats.ndim == 1:
             repeats = repeats[0]

-        if x.dtype == "uint64":
+        if a.dtype == "uint64":
             raise TypeError("repeat doesn't support dtype uint64")

         if axis is None:
             axis = 0
-            x = x.flatten()
-        else:
-            if axis >= x.ndim:
-                raise ValueError("Axis should not exceed x.ndim-1.")
-            if axis < 0:
-                axis = x.ndim + axis
+            a = a.flatten()

-        shape = [x.shape[i] for i in range(x.ndim)]
+        repeat_shape = list(a.shape)

-        # shape_ is the shape of the intermediate tensor which has
+        # alloc_shape is the shape of the intermediate tensor which has
         # an additional dimension compared to `a`. We use alloc to
         # allocate space for this intermediate tensor to replicate `a`
         # along that additional dimension.
-        shape_ = shape[:]
-        shape_.insert(axis + 1, repeats)
+        alloc_shape = repeat_shape[:]
+        alloc_shape.insert(axis + 1, repeats)

-        # shape is now the shape of output, where shape[axis] becomes
+        # repeat_shape is now the shape of the output, where shape[axis] becomes
         # shape[axis]*repeats.
-        shape[axis] = shape[axis] * repeats
-
-        # dims_ is the dimension of that intermediate tensor.
-        dims_ = list(np.arange(x.ndim))
-        dims_.insert(axis + 1, "x")
+        repeat_shape[axis] = repeat_shape[axis] * repeats

         # After the original tensor is duplicated along the additional
-        # dimension, we reshape it to the expected output shape, and
-        # return the output z.
-        z = ptb.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
-        return z
+        # dimension, we reshape it to the expected output shape
+        return ptb.alloc(ptb.expand_dims(a, axis + 1), *alloc_shape).reshape(
+            repeat_shape
+        )
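
The scalar-repeats fast path above avoids the Repeat Op entirely. A rough NumPy equivalent of the expand-dims/alloc/reshape trick (illustrative only):

    import numpy as np

    a = np.arange(6).reshape(2, 3)
    repeats, axis = 4, 1

    # Insert a length-`repeats` axis right after `axis` by broadcasting,
    # then merge it into `axis` with a reshape.
    alloc_shape = list(a.shape)
    alloc_shape.insert(axis + 1, repeats)
    out = np.broadcast_to(np.expand_dims(a, axis + 1), alloc_shape)
    out = out.reshape(a.shape[0], a.shape[1] * repeats)

    assert np.array_equal(out, np.repeat(a, repeats, axis=axis))
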


class Bartlett(Op):