3
3
from pytensor .configdefaults import config
4
4
from pytensor .graph .fg import FunctionGraph
5
5
from pytensor .graph .op import get_test_value
6
- from pytensor .tensor .type import matrix , scalar , vector
6
+ from pytensor .tensor .type import matrix , scalar , tensor3 , vector
7
7
from tests .link .pytorch .test_basic import compare_pytorch_and_py
8
8
9
9
10
- def test_tensor_basics ():
10
+ def test_pytorch_dot ():
11
+ a = tensor3 ("a" )
12
+ a .tag .test_value = np .zeros ((3 , 2 , 4 )).astype (config .floatX )
13
+ b = tensor3 ("b" )
14
+ b .tag .test_value = np .zeros ((3 , 4 , 1 )).astype (config .floatX )
11
15
y = vector ("y" )
12
16
y .tag .test_value = np .r_ [1.0 , 2.0 ].astype (config .floatX )
13
17
x = vector ("x" )
@@ -19,12 +23,17 @@ def test_tensor_basics():
19
23
beta = scalar ("beta" )
20
24
beta .tag .test_value = np .array (5.0 , dtype = config .floatX )
21
25
22
- # 1D * 2D * 1D
23
- out = y .dot (alpha * A ). dot ( x ) + beta * y
24
- fgraph = FunctionGraph ([y , x , A , alpha , beta ], [out ])
26
+ # 3D * 3D
27
+ out = a .dot (b * alpha ) + beta * b
28
+ fgraph = FunctionGraph ([a , b , alpha , beta ], [out ])
25
29
compare_pytorch_and_py (fgraph , [get_test_value (i ) for i in fgraph .inputs ])
26
30
27
31
# 2D * 2D
28
32
out = A .dot (A * alpha ) + beta * A
29
33
fgraph = FunctionGraph ([A , alpha , beta ], [out ])
30
34
compare_pytorch_and_py (fgraph , [get_test_value (i ) for i in fgraph .inputs ])
35
+
36
+ # 1D * 2D and 1D * 1D
37
+ out = y .dot (alpha * A ).dot (x ) + beta * y
38
+ fgraph = FunctionGraph ([y , x , A , alpha , beta ], [out ])
39
+ compare_pytorch_and_py (fgraph , [get_test_value (i ) for i in fgraph .inputs ])
0 commit comments