Skip to content

Commit 0699b48

Browse files
committed
Trust input in test_math_scipy benchmark tests
1 parent e299023 commit 0699b48

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/tensor/test_math_scipy.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,11 +431,13 @@ def test_gammaincc_ddk_performance(benchmark):
431431
x = vector("x")
432432

433433
out = gammaincc(k, x)
434-
grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN")
434+
grad_fn = function(
435+
[k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN", trust_input=True
436+
)
435437
vals = [
436438
# Values that hit the second branch of the gradient
437-
np.full((1000,), 3.2),
438-
np.full((1000,), 0.01),
439+
np.full((1000,), 3.2, dtype=k.dtype),
440+
np.full((1000,), 0.01, dtype=x.dtype),
439441
]
440442

441443
verify_grad(gammaincc, vals, rng=rng)
@@ -1127,9 +1129,13 @@ def test_benchmark(self, case, wrt, benchmark):
11271129
a1, a2, b1, z = pt.scalars("a1", "a2", "b1", "z")
11281130
hyp2f1_out = pt.hyp2f1(a1, a2, b1, z)
11291131
hyp2f1_grad = pt.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z])
1130-
f_grad = function([a1, a2, b1, z], hyp2f1_grad)
1132+
f_grad = function([a1, a2, b1, z], hyp2f1_grad, trust_input=True)
11311133

11321134
(test_a1, test_a2, test_b1, test_z, *expected_dds) = case
1135+
test_a1 = np.array(test_a1, dtype=a1.dtype)
1136+
test_a2 = np.array(test_a2, dtype=a2.dtype)
1137+
test_b1 = np.array(test_b1, dtype=b1.dtype)
1138+
test_z = np.array(test_z, dtype=z.dtype)
11331139

11341140
result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z)
11351141

0 commit comments

Comments
 (0)