@@ -72,8 +72,6 @@ def test_any(x, data):
72
72
data = st .data (),
73
73
)
74
74
def test_diff (x , data ):
75
- # TODO:
76
- # 1. append/prepend
77
75
axis = data .draw (
78
76
st .integers (- x .ndim , max (x .ndim - 1 , 0 )) | st .none (),
79
77
label = "axis"
@@ -91,6 +89,7 @@ def test_diff(x, data):
91
89
92
90
expected_shape = list (x .shape )
93
91
expected_shape [n_axis ] -= n
92
+
94
93
assert out .shape == tuple (expected_shape )
95
94
96
95
# value test
@@ -100,3 +99,43 @@ def test_diff(x, data):
100
99
l [n_axis ] += 1
101
100
assert out [idx ] == x [tuple (l )] - x [idx ], f"diff failed with { idx = } "
102
101
102
+
103
+ @pytest .mark .min_version ("2024.12" )
104
+ @pytest .mark .unvectorized
105
+ @given (
106
+ x = hh .arrays (hh .numeric_dtypes , hh .shapes (min_dims = 1 , min_side = 1 )),
107
+ data = st .data (),
108
+ )
109
+ def test_diff_append_prepend (x , data ):
110
+ axis = data .draw (
111
+ st .integers (- x .ndim , max (x .ndim - 1 , 0 )) | st .none (),
112
+ label = "axis"
113
+ )
114
+ if axis is None :
115
+ axis_kw = {"axis" : - 1 }
116
+ n_axis = x .ndim - 1
117
+ else :
118
+ axis_kw = {"axis" : axis }
119
+ n_axis = axis + x .ndim if axis < 0 else axis
120
+
121
+ n = data .draw (st .integers (1 , min (x .shape [n_axis ], 3 )))
122
+
123
+ append_shape = list (x .shape )
124
+ append_axis_len = data .draw (st .integers (1 , 2 * append_shape [n_axis ]), label = "append_axis" )
125
+ append_shape [n_axis ] = append_axis_len
126
+ append = data .draw (hh .arrays (dtype = x .dtype , shape = tuple (append_shape )), label = "append" )
127
+
128
+ prepend_shape = list (x .shape )
129
+ prepend_axis_len = data .draw (st .integers (1 , 2 * prepend_shape [n_axis ]), label = "prepend_axis" )
130
+ prepend_shape [n_axis ] = prepend_axis_len
131
+ prepend = data .draw (hh .arrays (dtype = x .dtype , shape = tuple (prepend_shape )), label = "prepend" )
132
+
133
+ out = xp .diff (x , ** axis_kw , n = n , append = append , prepend = prepend )
134
+
135
+ in_1 = xp .concat ((prepend , x , append ), ** axis_kw )
136
+ out_1 = xp .diff (in_1 , ** axis_kw , n = n )
137
+
138
+ assert out .shape == out_1 .shape
139
+ for idx in sh .ndindex (out .shape ):
140
+ assert out [idx ] == out_1 [idx ], f"{ idx = } "
141
+
0 commit comments