
Matrix determinant lemma example shows several missed optimizations #1040

Open
@ricardoV94

Description

https://en.wikipedia.org/wiki/Matrix_determinant_lemma

Ideally, in the example below we would rewrite the naive out into opt_out whenever we can be confident that computing the determinant of A plus solving for v.T A^-1 u is cheaper than computing the determinant of B directly.
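For reference, this is the identity from the linked article (valid for invertible A). In the example, opt_out is its right-hand side with A = diag(a), so that A^-1 u = u / a and det(A) = prod(a):

$$
\det\left(A + u v^\top\right) = \left(1 + v^\top A^{-1} u\right)\det(A)
$$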

In practice, this may require tentative rewrites, potentially a flop estimate (#875 would be nice), and backtracking if we don't find anything cheaper for the det of A.
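To make the "tentative rewrite plus flop estimate" idea concrete, here is a rough sketch of the decision it would need to make. Everything in it is hypothetical: the function names are made up for illustration, the flop counts are textbook approximations (about 2n^3/3 for a dense LU), and it is not a PyTensor API; #875 is where real flop estimation would come from.

```python
# Hypothetical, rough flop model for deciding whether to apply the lemma.
# NOT a PyTensor API; real flop estimation is what #875 is about.

def dense_det_flops(n):
    # LU factorization dominates a dense determinant: about (2/3) n^3 flops.
    return 2 * n**3 / 3

def naive_flops(n):
    # Form B = A + outer(u, v) (~2 n^2 flops), then take its dense determinant.
    return 2 * n**2 + dense_det_flops(n)

def lemma_flops(n, structure):
    # Rough cost of (1 + v @ solve(A, u)) * det(A) for a given structure of A.
    dot = 2 * n  # v @ (A^-1 u)
    if structure == "diagonal":
        return n + n + dot  # u / a and prod(a)
    if structure == "triangular":
        return n**2 + n + dot  # substitution and prod(diag(A))
    # Unstructured A: one LU could serve both the solve and det(A), but that
    # alone already costs as much as the naive determinant of B.
    return dense_det_flops(n) + 2 * n**2 + n + dot

def should_apply_lemma(n, structure):
    # Keep the tentative lemma form only if it looks cheaper,
    # otherwise backtrack to the naive det(B).
    return lemma_flops(n, structure) < naive_flops(n)

for structure in ("diagonal", "triangular", "dense"):
    print(structure, should_apply_lemma(40, structure))
```

Under this model, the lemma form is never cheaper than det(B) unless the structure of A is exploited, which is consistent with alt_fn being slower than fn in the timings of the example. The model ignores call overhead and memory traffic, so it is only a first-order heuristic.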

The example below is set up so that this is the case, by making A a diagonal matrix.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.vector("a")
A = pt.diag(a)  # diagonal matrix built from the vector a
u = pt.vector("u")
v = pt.vector("v")

# Naive graph: determinant of the rank-1 update B = A + u v^T
B = A + pt.outer(u, v)
out = pt.linalg.det(B)
fn = pytensor.function([a, u, v], out)

# Matrix determinant lemma applied literally, without exploiting that A is diagonal
alt_out = (1 + v @ pt.linalg.solve(A, u)) * pt.linalg.det(A)
alt_fn = pytensor.function([a, u, v], alt_out)

# Lemma with the diagonal structure exploited: solve(A, u) == u / a, det(A) == prod(a)
opt_out = (1 + v @ ((1 / a) * u)) * pt.prod(a)
opt_fn = pytensor.function([a, u, v], opt_out)

rng = np.random.default_rng()
a_test, u_test, v_test = rng.normal(size=(3, 40))

# All three graphs agree numerically
np.testing.assert_allclose(
    fn(a_test, u_test, v_test),
    alt_fn(a_test, u_test, v_test),
)
np.testing.assert_allclose(
    fn(a_test, u_test, v_test),
    opt_fn(a_test, u_test, v_test),
)

# Timings (IPython/Jupyter %timeit magic)
%timeit fn(a_test, u_test, v_test)  # 49.2 μs ± 1.45 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit alt_fn(a_test, u_test, v_test)  # 110 μs ± 10.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
%timeit opt_fn(a_test, u_test, v_test)  # 17.9 μs ± 1.05 μs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```

The intermediate form alt_out highlights some additional missing rewrites, such as solves with a diagonal matrix and vector @ diagonal-matrix products.
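A small NumPy sketch of the identities those diagonal-matrix rewrites would rely on. It only checks the math; it is not how PyTensor implements (or would implement) the rewrites.

```python
import numpy as np

rng = np.random.default_rng(0)
a, u, v = rng.normal(size=(3, 5))
A = np.diag(a)

# solve(diag(a), u): elementwise division instead of a dense O(n^3) solve
solve_diag = u / a

# v @ diag(a) and diag(a) @ u: elementwise products instead of O(n^2) matmuls
row_times_diag = v * a
diag_times_col = a * u

# det(diag(a)): product of the diagonal instead of an LU factorization
det_diag = np.prod(a)

assert np.allclose(np.linalg.solve(A, u), solve_diag)
assert np.allclose(v @ A, row_times_diag)
assert np.allclose(A @ u, diag_times_col)
assert np.isclose(np.linalg.det(A), det_diag)
```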

The matrix determinant lemma rewrite can be even more beneficial if the graph also requires a solve with (or the inverse of) B, since some of the new intermediate expressions can be reused via the Sherman-Morrison formula: https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula
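A NumPy sketch of that reuse for the diagonal case: A^-1 u and the denominator 1 + v.T A^-1 u are computed once and feed both det(B) (lemma) and B^-1 b (Sherman-Morrison), all in O(n). The function name is hypothetical, and it assumes all entries of a are nonzero and the denominator is nonzero.

```python
import numpy as np

def det_and_solve_rank1_diag(a, u, v, b):
    """Hypothetical helper: det(B) and B^-1 b for B = diag(a) + outer(u, v).

    Assumes a has no zeros and 1 + v @ (u / a) != 0.
    """
    Ainv_u = u / a              # A^-1 u, shared by both results
    denom = 1.0 + v @ Ainv_u    # 1 + v^T A^-1 u, also shared
    det_B = denom * np.prod(a)  # matrix determinant lemma
    Ainv_b = b / a
    # Sherman-Morrison: B^-1 b = A^-1 b - A^-1 u (v^T A^-1 b) / denom
    x = Ainv_b - Ainv_u * (v @ Ainv_b) / denom
    return det_B, x

# Usage: compare against dense NumPy on a well-conditioned example
rng = np.random.default_rng(0)
n = 10
a = rng.uniform(1.0, 2.0, size=n)
u, v, b = rng.normal(size=(3, n))
B = np.diag(a) + np.outer(u, v)
det_B, x = det_and_solve_rank1_diag(a, u, v, b)
```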

Other cases besides a diagonal A would be triangular or circulant matrices. There's a nice blog post on the latter.
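A NumPy/SciPy sketch of how the lemma could pay off for those structures. It assumes A is either lower triangular (det is the product of the diagonal, and the solve is O(n^2) substitution via scipy.linalg.solve_triangular) or circulant with first column c (its eigenvalues are fft(c), so the solve is an O(n log n) FFT deconvolution). The function names are hypothetical.

```python
import numpy as np
from scipy.linalg import solve_triangular

def rank1_det_triangular(T, u, v, lower=True):
    # det(T + u v^T) = (1 + v^T T^-1 u) * prod(diag(T))
    Tinv_u = solve_triangular(T, u, lower=lower)  # O(n^2) substitution
    return (1.0 + v @ Tinv_u) * np.prod(np.diag(T))

def circulant(c):
    # Circulant matrix with first column c: C[i, j] = c[(i - j) % n]
    n = len(c)
    idx = (np.arange(n)[:, None] - np.arange(n)[None, :]) % n
    return c[idx]

def rank1_det_circulant(c, u, v):
    # C @ x is the circular convolution of c and x, so the eigenvalues of C
    # are fft(c): det(C) = prod(fft(c)) and C^-1 u = ifft(fft(u) / fft(c)).
    fc = np.fft.fft(c)
    Cinv_u = np.fft.ifft(np.fft.fft(u) / fc).real
    det_C = np.prod(fc).real  # imaginary part is roundoff for real c
    return (1.0 + v @ Cinv_u) * det_C

# Usage on small, well-conditioned examples
rng = np.random.default_rng(0)
n = 8
u, v = rng.normal(size=(2, n))
T = np.tril(0.5 * rng.normal(size=(n, n)), k=-1) + np.diag(rng.uniform(1.0, 2.0, size=n))
c = rng.uniform(-1.0, 1.0, size=n)
c[0] = n  # diagonally dominant, hence nonsingular
det_tri = rank1_det_triangular(T, u, v)
det_circ = rank1_det_circulant(c, u, v)
```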
