
Add rewrite to optimize block_diag(a, b) @ c #1044

Open
@jessegrabowski

Description


Given a block diagonal matrix formed from matrices $A \in \mathbb R^{n \times m}$ and $B \in \mathbb R^{o \times p}$, such that:

$$ D = \begin{bmatrix} A & 0 \\ 0 & B \end{bmatrix} $$

Computing the matrix multiplication $DC$ can be simplified. Define $C_1 \in \mathbb R^{m \times k}$ and $C_2 \in \mathbb R^{p \times k}$ such that:

$$ C = \begin{bmatrix} C_1 \\ C_2 \end{bmatrix} $$

then:

$$ \begin{align} DC &= \begin{bmatrix} A & 0 \\ 0 & B \end{bmatrix} \begin{bmatrix} C_1 \\ C_2 \end{bmatrix} \\ &= \begin{bmatrix} A C_1 \\ B C_2 \end{bmatrix} \end{align} $$

We can compute the two smaller matrix products and concatenate the results, which is faster than materializing the block-diagonal matrix and doing one large multiplication. Code:

import numpy as np
from scipy import linalg

rng = np.random.default_rng()
n = 1000
A, B = rng.normal(size=(2, n, n))
C = rng.normal(size=(2*n, 2*n))

A, B, C = map(np.ascontiguousarray, [A, B, C])

def direct(A, B, C):
    X = linalg.block_diag(A, B)
    return X @ C

def rewrite(A, B, C):
    # Split C at A's row count, multiply each block separately, and stack
    n = A.shape[0]
    C_1, C_2 = C[:n], C[n:]
    return np.concatenate([A @ C_1, B @ C_2])

np.allclose(direct(A, B, C), rewrite(A, B, C)) # True
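The same identity extends to any number of diagonal blocks, which a general rewrite would need to handle. A minimal sketch (names like `rewrite_multi` are illustrative, not part of any proposed API): split the rows of `C` at the cumulative column counts of the blocks, multiply each block against its slice, and concatenate.

```python
import numpy as np
from scipy import linalg

def rewrite_multi(blocks, C):
    # Row offsets into C are the cumulative column counts of the blocks
    splits = np.cumsum([b.shape[1] for b in blocks])[:-1]
    C_parts = np.split(C, splits, axis=0)
    # Multiply each block against its row-slice of C, then stack the results
    return np.concatenate([b @ c for b, c in zip(blocks, C_parts)])

rng = np.random.default_rng(0)
blocks = [rng.normal(size=(3, 4)), rng.normal(size=(2, 2)), rng.normal(size=(5, 3))]
C = rng.normal(size=(9, 6))  # 4 + 2 + 3 rows
expected = linalg.block_diag(*blocks) @ C
```

Note that this also works for non-square blocks, since only the column counts of the blocks determine how `C` is partitioned.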

Speed test:

direct_time = %timeit -o direct(A, B, C)
rewrite_time = %timeit -o rewrite(A, B, C)

speedup_factor = 1 - rewrite_time.best / direct_time.best  # fraction of runtime saved
print(f'{speedup_factor:0.2%}')

75.8 ms ± 3.05 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
43.5 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
40.77%
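The mirrored case $C D = C \, \mathrm{block\_diag}(A, B)$ admits the same rewrite, splitting $C$ by columns instead of rows. A sketch, not part of the original issue (the helper name `rewrite_right` is hypothetical):

```python
import numpy as np
from scipy import linalg

def rewrite_right(C, A, B):
    # For C @ block_diag(A, B), split C at A's row count along the columns
    n = A.shape[0]
    C_1, C_2 = C[:, :n], C[:, n:]
    # C @ D = [C_1 @ A, C_2 @ B] stacked horizontally
    return np.concatenate([C_1 @ A, C_2 @ B], axis=1)

rng = np.random.default_rng(0)
A, B = rng.normal(size=(2, 4, 4))
C = rng.normal(size=(3, 8))
```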
