@ operator returns wrong graph for batched matrices #451

Closed
@ricardoV94

Description

import numpy as np
import pytensor.tensor as pt

x = pt.tensor("x", shape=(10, 3, 3))
x_val = np.random.normal(size=x.type.shape)

np.testing.assert_allclose(
  (x @ x).eval({x: x_val}),
  x_val @ x_val,
)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0
(shapes (10, 3, 10, 3), (10, 3, 3) mismatch)
 x: array([[[[ 3.433369e-01, -2.325831e+00, -1.113984e+00],
         [-7.958884e-02,  1.460045e+00, -1.120295e+00],
         [-4.348782e-01, -4.313117e-01, -6.228517e-02],...
 y: array([[[ 0.343337, -2.325831, -1.113984],
        [ 0.614024,  0.805719, -0.136424],
        [-1.463383, -2.395001,  4.435603]],...

This happens because the __matmul__ method is aliased to __dot__, which calls dense_dot instead of returning a pt.matmul. pt.matmul itself handles batched inputs correctly:

def __dot__(left, right):
    return at.math.dense_dot(left, right)

def __rdot__(right, left):
    return at.math.dense_dot(left, right)

dot = __dot__
__matmul__ = __dot__
__rmatmul__ = __rdot__

np.testing.assert_allclose(
  pt.matmul(x, x).eval({x: x_val}),
  x_val @ x_val,
)  # Fine
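
The (10, 3, 10, 3) shape in the error above is what dot semantics give for 3-D inputs, as opposed to matmul semantics. The following is a minimal NumPy-only sketch of that difference. It assumes dense_dot follows np.dot for N-D inputs, which the reported shape suggests but which is not confirmed here.

```python
import numpy as np

rng = np.random.default_rng(0)
x_val = rng.normal(size=(10, 3, 3))

# np.dot contracts the last axis of the first argument with the
# second-to-last axis of the second, and keeps every other axis.
dot_result = np.dot(x_val, x_val)
print(dot_result.shape)  # (10, 3, 10, 3)

# np.matmul (what @ means in NumPy) treats leading axes as batch axes.
matmul_result = np.matmul(x_val, x_val)
print(matmul_result.shape)  # (10, 3, 3)

# The batched product is the slice of the dot result where both
# batch indices coincide.
idx = np.arange(x_val.shape[0])
np.testing.assert_allclose(dot_result[idx, :, idx, :], matmul_result)
```

This would also explain why the first row of x in the failing assertion matches the first row of y: the two results agree wherever the two batch indices coincide.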
