@@ -3207,13 +3207,14 @@ def tile(
3207
3207
return A_replicated .reshape (tiled_shape )
3208
3208
3209
3209
3210
- class ARange (Op ):
3210
+ class ARange (COp ):
3211
3211
"""Create an array containing evenly spaced values within a given interval.
3212
3212
3213
3213
Parameters and behaviour are the same as numpy.arange().
3214
3214
3215
3215
"""
3216
3216
3217
+ # TODO: Arange should work with scalars as inputs, not arrays
3217
3218
__props__ = ("dtype" ,)
3218
3219
3219
3220
def __init__ (self , dtype ):
@@ -3293,13 +3294,30 @@ def upcast(var):
3293
3294
)
3294
3295
]
3295
3296
3296
def perform(self, node, inputs, output_storage):
    """Evaluate the range in Python and store it in the output cell.

    Parameters
    ----------
    node
        The node being evaluated (not used here).
    inputs
        Exactly three array-like values ``(start, stop, step)``; each is
        reduced to a Python scalar through its ``.item()`` method.
    output_storage
        Output cells; the result is written to ``output_storage[0][0]``.
    """
    # Unpacking the generator still enforces that there are exactly three inputs.
    start, stop, step = (value.item() for value in inputs)
    output_storage[0][0] = np.arange(start, stop, step, dtype=self.dtype)
3302
+
3303
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that builds the range with ``PyArray_Arange``.

    Parameters
    ----------
    node, nodename
        Not used by this implementation.
    input_names
        C variable names of the three inputs ``(start, stop, step)``.
    output_names
        C variable name of the single output.
    sub
        Substitution dictionary; ``sub["fail"]`` is emitted when NumPy
        returns ``NULL``.

    Returns
    -------
    str
        C code that reads the first element of every input, releases the
        previous output and stores a freshly allocated range in it.

    Notes
    -----
    The three scalars are read into C ``double`` variables, so integer
    inputs whose magnitude exceeds 2**53 cannot be represented exactly.
    """
    [start_name, stop_name, step_name] = input_names
    [out_name] = output_names
    # NumPy type number of the requested output dtype.
    typenum = np.dtype(self.dtype).num
    fail = sub["fail"]
    return f"""
        double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0];
        double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0];
        double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0];
        Py_XDECREF({out_name});
        {out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum});
        if (!{out_name}) {{
            {fail}
        }}
        """
3318
+
3319
def c_code_cache_version(self):
    """Return the version tuple identifying this Op's generated C code."""
    version = (0,)
    return version
3303
3321
3304
3322
def connection_pattern (self , node ):
3305
3323
return [[True ], [False ], [True ]]
@@ -3685,8 +3703,7 @@ def inverse_permutation(perm):
3685
3703
)
3686
3704
3687
3705
3688
- # TODO: optimization to insert ExtractDiag with view=True
3689
- class ExtractDiag (Op ):
3706
+ class ExtractDiag (COp ):
3690
3707
"""
3691
3708
Return specified diagonals.
3692
3709
@@ -3742,7 +3759,7 @@ class ExtractDiag(Op):
3742
3759
3743
3760
__props__ = ("offset" , "axis1" , "axis2" , "view" )
3744
3761
3745
- def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 , view = False ):
3762
+ def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 , view = True ):
3746
3763
self .view = view
3747
3764
if self .view :
3748
3765
self .view_map = {0 : [0 ]}
@@ -3765,24 +3782,74 @@ def make_node(self, x):
3765
3782
if x .ndim < 2 :
3766
3783
raise ValueError ("ExtractDiag needs an input with 2 or more dimensions" , x )
3767
3784
3768
- out_shape = [
3769
- st_dim
3770
- for i , st_dim in enumerate (x .type .shape )
3771
- if i not in (self .axis1 , self .axis2 )
3772
- ] + [None ]
3785
+ if (dim1 := x .type .shape [self .axis1 ]) is not None and (
3786
+ dim2 := x .type .shape [self .axis2 ]
3787
+ ) is not None :
3788
+ offset = self .offset
3789
+ if offset > 0 :
3790
+ diag_size = int (np .clip (dim2 - offset , 0 , dim1 ))
3791
+ elif offset < 0 :
3792
+ diag_size = int (np .clip (dim1 + offset , 0 , dim2 ))
3793
+ else :
3794
+ diag_size = int (np .minimum (dim1 , dim2 ))
3795
+ else :
3796
+ diag_size = None
3797
+
3798
+ out_shape = (
3799
+ * (
3800
+ dim
3801
+ for i , dim in enumerate (x .type .shape )
3802
+ if i not in (self .axis1 , self .axis2 )
3803
+ ),
3804
+ diag_size ,
3805
+ )
3773
3806
3774
3807
return Apply (
3775
3808
self ,
3776
3809
[x ],
3777
- [x .type .clone (dtype = x .dtype , shape = tuple ( out_shape ) )()],
3810
+ [x .type .clone (dtype = x .dtype , shape = out_shape )()],
3778
3811
)
3779
3812
3780
- def perform (self , node , inputs , outputs ):
3813
+ def perform (self , node , inputs , output_storage ):
3781
3814
(x ,) = inputs
3782
- (z ,) = outputs
3783
- z [0 ] = x .diagonal (self .offset , self .axis1 , self .axis2 )
3784
- if not self .view :
3785
- z [0 ] = z [0 ].copy ()
3815
+ out = x .diagonal (self .offset , self .axis1 , self .axis2 )
3816
+ if self .view :
3817
+ try :
3818
+ out .flags .writeable = True
3819
+ except ValueError :
3820
+ # We can't make this array writable
3821
+ out = out .copy ()
3822
+ else :
3823
+ out = out .copy ()
3824
+ output_storage [0 ][0 ] = out
3825
+
3826
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that extracts the diagonal with ``PyArray_Diagonal``.

    Parameters
    ----------
    node, nodename
        Not used by this implementation.
    input_names
        C variable name of the single input array.
    output_names
        C variable name of the single output array.
    sub
        Substitution dictionary; ``sub["fail"]`` is emitted whenever a
        NumPy call returns ``NULL``.

    Returns
    -------
    str
        C code that stores either a writeable view of the diagonal (when
        ``self.view`` is true and the input is writeable) or a copy of it.
    """
    [x_name] = input_names
    [out_name] = output_names
    fail = sub["fail"]
    # Emitted as a literal 0/1 so the C condition is constant in the view flag.
    as_view = int(self.view)
    return f"""
        Py_XDECREF({out_name});

        {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
        if (!{out_name}) {{
            {fail}  // Error already set by Numpy
        }}

        if ({as_view} && PyArray_ISWRITEABLE({x_name})) {{
            // Make output writeable if input was writeable
            PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
        }} else {{
            // Make a copy
            PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
            Py_DECREF({out_name});
            // Rebind before checking for failure: the old pointer was just
            // released and must not stay reachable through the output variable.
            {out_name} = {out_name}_copy;
            if (!{out_name}) {{
                {fail}  // Error already set by Numpy
            }}
        }}
        """

def c_code_cache_version(self):
    """Return the version tuple identifying this Op's generated C code."""
    # Bumped from (0,) because the copy-failure path of the generated C changed.
    return (1,)
3786
3853
3787
3854
def grad (self , inputs , gout ):
3788
3855
# Avoid circular import
@@ -3829,19 +3896,6 @@ def infer_shape(self, fgraph, node, shapes):
3829
3896
out_shape .append (diag_size )
3830
3897
return [tuple (out_shape )]
3831
3898
3832
- def __setstate__ (self , state ):
3833
- self .__dict__ .update (state )
3834
-
3835
- if self .view :
3836
- self .view_map = {0 : [0 ]}
3837
-
3838
- if "offset" not in state :
3839
- self .offset = 0
3840
- if "axis1" not in state :
3841
- self .axis1 = 0
3842
- if "axis2" not in state :
3843
- self .axis2 = 1
3844
-
3845
3899
3846
3900
def extract_diag (x ):
3847
3901
warnings .warn (
0 commit comments