diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 06d023d780..ba971c78da 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -190,6 +190,47 @@ def local_0_dot_x(fgraph, node):
     return [constant_zero]
 
 
+@register_canonicalize
+@register_stabilize
+@node_rewriter([Dot])
+def local_1_dot_x(fgraph, node):
+    """Rewrite ``dot(1, y) -> y`` and ``dot(x, 1) -> x``.
+
+    A dot product with a scalar-constant one operand is the identity on
+    the other operand, so the ``Dot`` node can be dropped, adding a cast
+    and/or reshape only when needed to preserve the output dtype and type.
+    """
+    x, y = node.inputs
+
+    # Find a scalar-constant-1 operand; the other operand replaces the dot.
+    var = None
+    for candidate, other in ((x, y), (y, x)):
+        try:
+            constant = get_underlying_scalar_constant_value(
+                candidate, only_process_constants=True
+            )
+        except NotScalarConstantError:
+            continue
+        if constant == 1:
+            var = other
+            break
+
+    if var is None:
+        return None
+
+    old_out = node.outputs[0]
+    new_out = var
+
+    # Dot may have upcast its inputs; preserve the original output dtype.
+    if new_out.dtype != old_out.dtype:
+        new_out = cast(new_out, old_out.dtype)
+    # Preserve the static shape/broadcast pattern of the original output.
+    if not old_out.type.is_super(new_out.type):
+        new_out = new_out.reshape(old_out.shape)
+
+    return [new_out]
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def local_lift_transpose_through_dot(fgraph, node):
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 84322989bf..a9b9c7138c 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -4475,3 +4475,27 @@ def test_local_batched_matmul_to_core_matmul():
     x_test = rng.normal(size=(5, 3, 2))
     y_test = rng.normal(size=(5, 2, 2))
     np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test)
+
+
+@pytest.mark.parametrize(
+    "x",
+    (
+        pt.col("x"),
+        fmatrix("x"),
+        vector("x"),
+        pt.tensor("x", shape=(1, 3, 1), dtype="float64"),
+    ),
+)
+def test_dot_with_1(x):
+    out = x @ [[1.0]]
+    fn = pytensor.function([x], out, mode=get_default_mode().excluding("BlasOpt"))
+
+    # `local_1_dot_x` should have removed the `Dot` from the compiled graph.
+    assert not any(isinstance(node.op, Dot) for node in fn.maker.fgraph.toposort())
+
+    # The rewritten function must still compute the same values.
+    rng = np.random.default_rng(seed=102)
+    x_test = rng.normal(size=[1 if s is None else s for s in x.type.shape]).astype(
+        x.dtype
+    )
+    np.testing.assert_allclose(fn(x_test), x_test @ np.array([[1.0]]))