@@ -431,11 +431,13 @@ def test_gammaincc_ddk_performance(benchmark):
431
431
x = vector ("x" )
432
432
433
433
out = gammaincc (k , x )
434
- grad_fn = function ([k , x ], grad (out .sum (), wrt = [k ]), mode = "FAST_RUN" )
434
+ grad_fn = function (
435
+ [k , x ], grad (out .sum (), wrt = [k ]), mode = "FAST_RUN" , trust_input = True
436
+ )
435
437
vals = [
436
438
# Values that hit the second branch of the gradient
437
- np .full ((1000 ,), 3.2 ),
438
- np .full ((1000 ,), 0.01 ),
439
+ np .full ((1000 ,), 3.2 , dtype = k . dtype ),
440
+ np .full ((1000 ,), 0.01 , dtype = x . dtype ),
439
441
]
440
442
441
443
verify_grad (gammaincc , vals , rng = rng )
@@ -1127,9 +1129,13 @@ def test_benchmark(self, case, wrt, benchmark):
1127
1129
a1 , a2 , b1 , z = pt .scalars ("a1" , "a2" , "b1" , "z" )
1128
1130
hyp2f1_out = pt .hyp2f1 (a1 , a2 , b1 , z )
1129
1131
hyp2f1_grad = pt .grad (hyp2f1_out , wrt = a1 if wrt == "a" else [a1 , a2 , b1 , z ])
1130
- f_grad = function ([a1 , a2 , b1 , z ], hyp2f1_grad )
1132
+ f_grad = function ([a1 , a2 , b1 , z ], hyp2f1_grad , trust_input = True )
1131
1133
1132
1134
(test_a1 , test_a2 , test_b1 , test_z , * expected_dds ) = case
1135
+ test_a1 = np .array (test_a1 , dtype = a1 .dtype )
1136
+ test_a2 = np .array (test_a2 , dtype = a2 .dtype )
1137
+ test_b1 = np .array (test_b1 , dtype = b1 .dtype )
1138
+ test_z = np .array (test_z , dtype = z .dtype )
1133
1139
1134
1140
result = benchmark (f_grad , test_a1 , test_a2 , test_b1 , test_z )
1135
1141
0 commit comments