@@ -1081,42 +1081,22 @@ def test_banded_dot(A_shape, kl, ku):
1081
1081
res = banded_dot (A , b , kl , ku )
1082
1082
res_2 = A @ b
1083
1083
1084
- fn = function ([A , b ], [res , res_2 ])
1084
+ fn = function ([A , b ], [res , res_2 ], trust_input = True )
1085
1085
assert any (isinstance (node .op , BandedDot ) for node in fn .maker .fgraph .apply_nodes )
1086
1086
1087
1087
x_val , x2_val = fn (A_val , b_val )
1088
1088
1089
1089
np .testing .assert_allclose (x_val , x2_val )
1090
1090
1091
1091
1092
+ @pytest .mark .parametrize ("op" , ["dot" , "banded_dot" ], ids = str )
1092
1093
@pytest .mark .parametrize (
1093
1094
"A_shape" , [(10 , 10 ), (100 , 100 ), (1000 , 1000 )], ids = ["10" , "100" , "1000" ]
1094
1095
)
1095
1096
@pytest .mark .parametrize (
1096
1097
"kl, ku" , [(1 , 1 ), (0 , 1 ), (2 , 2 )], ids = ["tridiag" , "upper-only" , "banded" ]
1097
1098
)
1098
- def test_banded_dot_perf (A_shape , kl , ku , benchmark ):
1099
- rng = np .random .default_rng ()
1100
-
1101
- A_val = _make_banded_A (rng .normal (size = A_shape ), kl = kl , ku = ku )
1102
- b_val = rng .normal (size = (A_shape [- 1 ],))
1103
-
1104
- A = pt .tensor ("A" , shape = A_val .shape , dtype = A_val .dtype )
1105
- b = pt .tensor ("b" , shape = b_val .shape , dtype = b_val .dtype )
1106
-
1107
- res = banded_dot (A , b , kl , ku )
1108
- fn = function ([A , b ], res , trust_input = True )
1109
-
1110
- benchmark (fn , A_val , b_val )
1111
-
1112
-
1113
- @pytest .mark .parametrize (
1114
- "A_shape" , [(10 , 10 ), (100 , 100 ), (1000 , 1000 )], ids = ["10" , "100" , "1000" ]
1115
- )
1116
- @pytest .mark .parametrize (
1117
- "kl, ku" , [(1 , 1 ), (0 , 1 ), (2 , 2 )], ids = ["tridiag" , "upper-only" , "banded" ]
1118
- )
1119
- def test_dot_perf (A_shape , kl , ku , benchmark ):
1099
+ def test_banded_dot_perf (op , A_shape , kl , ku , benchmark ):
1120
1100
rng = np .random .default_rng ()
1121
1101
1122
1102
A_val = _make_banded_A (rng .normal (size = A_shape ), kl = kl , ku = ku )
@@ -1125,7 +1105,12 @@ def test_dot_perf(A_shape, kl, ku, benchmark):
1125
1105
A = pt .tensor ("A" , shape = A_val .shape )
1126
1106
b = pt .tensor ("b" , shape = b_val .shape )
1127
1107
1128
- res = A @ b
1129
- fn = function ([A , b ], res )
1108
+ if op == "dot" :
1109
+ f = pt .dot
1110
+ elif op == "banded_dot" :
1111
+ f = functools .partial (banded_dot , lower_diags = kl , upper_diags = ku )
1112
+
1113
+ res = f (A , b )
1114
+ fn = function ([A , b ], res , trust_input = True )
1130
1115
1131
1116
benchmark (fn , A_val , b_val )
0 commit comments