Skip to content

Commit 35f8906

Browse files
committed
Make sure both arrays in the assertion are masked
1 parent c083fa1 commit 35f8906

27 files changed

+180
-180
lines changed

array_api_tests/special_cases/test_abs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_abs_special_cases_one_arg_equal_1(arg1):
2424
"""
2525
res = abs(arg1)
2626
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -37,7 +37,7 @@ def test_abs_special_cases_one_arg_equal_2(arg1):
3737
"""
3838
res = abs(arg1)
3939
mask = exactly_equal(arg1, -zero(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -50,4 +50,4 @@ def test_abs_special_cases_one_arg_equal_3(arg1):
5050
"""
5151
res = abs(arg1)
5252
mask = exactly_equal(arg1, -infinity(arg1.dtype))
53-
assert_exactly_equal(res[mask], infinity(arg1.dtype))
53+
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])

array_api_tests/special_cases/test_acos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_acos_special_cases_one_arg_equal_1(arg1):
2424
"""
2525
res = acos(arg1)
2626
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -37,7 +37,7 @@ def test_acos_special_cases_one_arg_equal_2(arg1):
3737
"""
3838
res = acos(arg1)
3939
mask = exactly_equal(arg1, one(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -50,7 +50,7 @@ def test_acos_special_cases_one_arg_greater(arg1):
5050
"""
5151
res = acos(arg1)
5252
mask = greater(arg1, one(arg1.dtype))
53-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
53+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
5454

5555

5656
@given(numeric_arrays)
@@ -63,4 +63,4 @@ def test_acos_special_cases_one_arg_less(arg1):
6363
"""
6464
res = acos(arg1)
6565
mask = less(arg1, -one(arg1.dtype))
66-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
66+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])

array_api_tests/special_cases/test_acosh.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_acosh_special_cases_one_arg_equal_1(arg1):
2424
"""
2525
res = acosh(arg1)
2626
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -37,7 +37,7 @@ def test_acosh_special_cases_one_arg_equal_2(arg1):
3737
"""
3838
res = acosh(arg1)
3939
mask = exactly_equal(arg1, one(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -50,7 +50,7 @@ def test_acosh_special_cases_one_arg_equal_3(arg1):
5050
"""
5151
res = acosh(arg1)
5252
mask = exactly_equal(arg1, infinity(arg1.dtype))
53-
assert_exactly_equal(res[mask], infinity(arg1.dtype))
53+
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])
5454

5555

5656
@given(numeric_arrays)
@@ -63,4 +63,4 @@ def test_acosh_special_cases_one_arg_less(arg1):
6363
"""
6464
res = acosh(arg1)
6565
mask = less(arg1, one(arg1.dtype))
66-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
66+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])

array_api_tests/special_cases/test_add.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_add_special_cases_two_args_either(arg1, arg2):
2525
"""
2626
res = add(arg1, arg2)
2727
mask = logical_or(exactly_equal(arg1, NaN(arg1.dtype)), exactly_equal(arg2, NaN(arg1.dtype)))
28-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
28+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
2929

3030

3131
@given(numeric_arrays, numeric_arrays)
@@ -38,7 +38,7 @@ def test_add_special_cases_two_args_equal__equal_1(arg1, arg2):
3838
"""
3939
res = add(arg1, arg2)
4040
mask = logical_and(exactly_equal(arg1, infinity(arg1.dtype)), exactly_equal(arg2, -infinity(arg2.dtype)))
41-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
41+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
4242

4343

4444
@given(numeric_arrays, numeric_arrays)
@@ -51,7 +51,7 @@ def test_add_special_cases_two_args_equal__equal_2(arg1, arg2):
5151
"""
5252
res = add(arg1, arg2)
5353
mask = logical_and(exactly_equal(arg1, -infinity(arg1.dtype)), exactly_equal(arg2, infinity(arg2.dtype)))
54-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
54+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
5555

5656

5757
@given(numeric_arrays, numeric_arrays)
@@ -64,7 +64,7 @@ def test_add_special_cases_two_args_equal__equal_3(arg1, arg2):
6464
"""
6565
res = add(arg1, arg2)
6666
mask = logical_and(exactly_equal(arg1, infinity(arg1.dtype)), exactly_equal(arg2, infinity(arg2.dtype)))
67-
assert_exactly_equal(res[mask], infinity(arg1.dtype))
67+
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])
6868

6969

7070
@given(numeric_arrays, numeric_arrays)
@@ -77,7 +77,7 @@ def test_add_special_cases_two_args_equal__equal_4(arg1, arg2):
7777
"""
7878
res = add(arg1, arg2)
7979
mask = logical_and(exactly_equal(arg1, -infinity(arg1.dtype)), exactly_equal(arg2, -infinity(arg2.dtype)))
80-
assert_exactly_equal(res[mask], -infinity(arg1.dtype))
80+
assert_exactly_equal(res[mask], -infinity(arg1.dtype)[mask])
8181

8282

8383
@given(numeric_arrays, numeric_arrays)
@@ -90,7 +90,7 @@ def test_add_special_cases_two_args_equal__equal_5(arg1, arg2):
9090
"""
9191
res = add(arg1, arg2)
9292
mask = logical_and(exactly_equal(arg1, infinity(arg1.dtype)), isfinite(arg2))
93-
assert_exactly_equal(res[mask], infinity(arg1.dtype))
93+
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])
9494

9595

9696
@given(numeric_arrays, numeric_arrays)
@@ -103,7 +103,7 @@ def test_add_special_cases_two_args_equal__equal_6(arg1, arg2):
103103
"""
104104
res = add(arg1, arg2)
105105
mask = logical_and(exactly_equal(arg1, -infinity(arg1.dtype)), isfinite(arg2))
106-
assert_exactly_equal(res[mask], -infinity(arg1.dtype))
106+
assert_exactly_equal(res[mask], -infinity(arg1.dtype)[mask])
107107

108108

109109
@given(numeric_arrays, numeric_arrays)
@@ -116,7 +116,7 @@ def test_add_special_cases_two_args_equal__equal_7(arg1, arg2):
116116
"""
117117
res = add(arg1, arg2)
118118
mask = logical_and(isfinite(arg1), exactly_equal(arg2, infinity(arg2.dtype)))
119-
assert_exactly_equal(res[mask], infinity(arg1.dtype))
119+
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])
120120

121121

122122
@given(numeric_arrays, numeric_arrays)
@@ -129,7 +129,7 @@ def test_add_special_cases_two_args_equal__equal_8(arg1, arg2):
129129
"""
130130
res = add(arg1, arg2)
131131
mask = logical_and(isfinite(arg1), exactly_equal(arg2, -infinity(arg2.dtype)))
132-
assert_exactly_equal(res[mask], -infinity(arg1.dtype))
132+
assert_exactly_equal(res[mask], -infinity(arg1.dtype)[mask])
133133

134134

135135
@given(numeric_arrays, numeric_arrays)
@@ -142,7 +142,7 @@ def test_add_special_cases_two_args_equal__equal_9(arg1, arg2):
142142
"""
143143
res = add(arg1, arg2)
144144
mask = logical_and(exactly_equal(arg1, -zero(arg1.dtype)), exactly_equal(arg2, -zero(arg2.dtype)))
145-
assert_exactly_equal(res[mask], -zero(arg1.dtype))
145+
assert_exactly_equal(res[mask], -zero(arg1.dtype)[mask])
146146

147147

148148
@given(numeric_arrays, numeric_arrays)
@@ -155,7 +155,7 @@ def test_add_special_cases_two_args_equal__equal_10(arg1, arg2):
155155
"""
156156
res = add(arg1, arg2)
157157
mask = logical_and(exactly_equal(arg1, -zero(arg1.dtype)), exactly_equal(arg2, zero(arg2.dtype)))
158-
assert_exactly_equal(res[mask], zero(arg1.dtype))
158+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
159159

160160

161161
@given(numeric_arrays, numeric_arrays)
@@ -168,7 +168,7 @@ def test_add_special_cases_two_args_equal__equal_11(arg1, arg2):
168168
"""
169169
res = add(arg1, arg2)
170170
mask = logical_and(exactly_equal(arg1, zero(arg1.dtype)), exactly_equal(arg2, -zero(arg2.dtype)))
171-
assert_exactly_equal(res[mask], zero(arg1.dtype))
171+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
172172

173173

174174
@given(numeric_arrays, numeric_arrays)
@@ -181,7 +181,7 @@ def test_add_special_cases_two_args_equal__equal_12(arg1, arg2):
181181
"""
182182
res = add(arg1, arg2)
183183
mask = logical_and(exactly_equal(arg1, zero(arg1.dtype)), exactly_equal(arg2, zero(arg2.dtype)))
184-
assert_exactly_equal(res[mask], zero(arg1.dtype))
184+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
185185

186186

187187
@given(numeric_arrays, numeric_arrays)
@@ -194,7 +194,7 @@ def test_add_special_cases_two_args_equal__equal_13(arg1, arg2):
194194
"""
195195
res = add(arg1, arg2)
196196
mask = logical_and(logical_and(isfinite(arg1), nonzero(arg1)), exactly_equal(arg2, -arg1))
197-
assert_exactly_equal(res[mask], zero(arg1.dtype))
197+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
198198

199199

200200
@given(numeric_arrays, numeric_arrays)

array_api_tests/special_cases/test_asin.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_asin_special_cases_one_arg_equal_1(arg1):
2424
"""
2525
res = asin(arg1)
2626
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -37,7 +37,7 @@ def test_asin_special_cases_one_arg_equal_2(arg1):
3737
"""
3838
res = asin(arg1)
3939
mask = exactly_equal(arg1, zero(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -50,7 +50,7 @@ def test_asin_special_cases_one_arg_equal_3(arg1):
5050
"""
5151
res = asin(arg1)
5252
mask = exactly_equal(arg1, -zero(arg1.dtype))
53-
assert_exactly_equal(res[mask], -zero(arg1.dtype))
53+
assert_exactly_equal(res[mask], -zero(arg1.dtype)[mask])
5454

5555

5656
@given(numeric_arrays)
@@ -63,7 +63,7 @@ def test_asin_special_cases_one_arg_greater(arg1):
6363
"""
6464
res = asin(arg1)
6565
mask = greater(arg1, one(arg1.dtype))
66-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
66+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
6767

6868

6969
@given(numeric_arrays)
@@ -76,4 +76,4 @@ def test_asin_special_cases_one_arg_less(arg1):
7676
"""
7777
res = asin(arg1)
7878
mask = less(arg1, -one(arg1.dtype))
79-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
79+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])

array_api_tests/special_cases/test_asinh.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_asinh_special_cases_one_arg_equal_1(arg1):
2424
"""
2525
res = asinh(arg1)
2626
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -37,7 +37,7 @@ def test_asinh_special_cases_one_arg_equal_2(arg1):
3737
"""
3838
res = asinh(arg1)
3939
mask = exactly_equal(arg1, zero(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -50,7 +50,7 @@ def test_asinh_special_cases_one_arg_equal_3(arg1):
5050
"""
5151
res = asinh(arg1)
5252
mask = exactly_equal(arg1, -zero(arg1.dtype))
53-
assert_exactly_equal(res[mask], -zero(arg1.dtype))
53+
assert_exactly_equal(res[mask], -zero(arg1.dtype)[mask])
5454

5555

5656
@given(numeric_arrays)
@@ -63,7 +63,7 @@ def test_asinh_special_cases_one_arg_equal_4(arg1):
6363
"""
6464
res = asinh(arg1)
6565
mask = exactly_equal(arg1, infinity(arg1.dtype))
66-
assert_exactly_equal(res[mask], infinity(arg1.dtype))
66+
assert_exactly_equal(res[mask], infinity(arg1.dtype)[mask])
6767

6868

6969
@given(numeric_arrays)
@@ -76,4 +76,4 @@ def test_asinh_special_cases_one_arg_equal_5(arg1):
7676
"""
7777
res = asinh(arg1)
7878
mask = exactly_equal(arg1, -infinity(arg1.dtype))
79-
assert_exactly_equal(res[mask], -infinity(arg1.dtype))
79+
assert_exactly_equal(res[mask], -infinity(arg1.dtype)[mask])

array_api_tests/special_cases/test_atan.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_atan_special_cases_one_arg_equal_1(arg1):
2424
"""
2525
res = atan(arg1)
2626
mask = exactly_equal(arg1, NaN(arg1.dtype))
27-
assert_exactly_equal(res[mask], NaN(arg1.dtype))
27+
assert_exactly_equal(res[mask], NaN(arg1.dtype)[mask])
2828

2929

3030
@given(numeric_arrays)
@@ -37,7 +37,7 @@ def test_atan_special_cases_one_arg_equal_2(arg1):
3737
"""
3838
res = atan(arg1)
3939
mask = exactly_equal(arg1, zero(arg1.dtype))
40-
assert_exactly_equal(res[mask], zero(arg1.dtype))
40+
assert_exactly_equal(res[mask], zero(arg1.dtype)[mask])
4141

4242

4343
@given(numeric_arrays)
@@ -50,7 +50,7 @@ def test_atan_special_cases_one_arg_equal_3(arg1):
5050
"""
5151
res = atan(arg1)
5252
mask = exactly_equal(arg1, -zero(arg1.dtype))
53-
assert_exactly_equal(res[mask], -zero(arg1.dtype))
53+
assert_exactly_equal(res[mask], -zero(arg1.dtype)[mask])
5454

5555

5656
@given(numeric_arrays)
@@ -63,7 +63,7 @@ def test_atan_special_cases_one_arg_equal_4(arg1):
6363
"""
6464
res = atan(arg1)
6565
mask = exactly_equal(arg1, infinity(arg1.dtype))
66-
assert_exactly_equal(res[mask], +π(arg1.dtype)/2)
66+
assert_exactly_equal(res[mask], +π(arg1.dtype)/2[mask])
6767

6868

6969
@given(numeric_arrays)
@@ -76,4 +76,4 @@ def test_atan_special_cases_one_arg_equal_5(arg1):
7676
"""
7777
res = atan(arg1)
7878
mask = exactly_equal(arg1, -infinity(arg1.dtype))
79-
assert_exactly_equal(res[mask], -π(arg1.dtype)/2)
79+
assert_exactly_equal(res[mask], -π(arg1.dtype)/2[mask])

0 commit comments

Comments
 (0)