Skip to content

Commit 28789ec

Browse files
committed
Rewrite ExtractDiagonal of AllocDiagonal
1 parent 07d97ad commit 28789ec

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@
1919
from pytensor.scalar import constant as scalar_constant
2020
from pytensor.tensor.basic import (
2121
Alloc,
22+
ExtractDiag,
2223
Join,
2324
ScalarFromTensor,
2425
TensorFromScalar,
2526
alloc,
2627
cast,
2728
concatenate,
2829
expand_dims,
30+
full,
2931
get_scalar_constant_value,
3032
get_underlying_scalar_constant_value,
3133
register_infer_shape,
@@ -1793,3 +1795,96 @@ def ravel_multidimensional_int_idx(fgraph, node):
17931795
"numba",
17941796
use_db_name_as_tag=False, # Not included if only "specialize" is requested
17951797
)
1798+
1799+
1800+
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_diagonal_set_subtensor(fgraph, node):
    """Undo extract diagonal from a set diagonal.

    This rewrites the following pattern:
        y = write_diagonal*(x, x_diag, offset=k1)
        z = extract_diag(y, offset=k2)

    as:
        z = x_diag, if k1 == k2 and the whole diagonal was written
        z = extract_diag(x, offset=k2), if k1 != k2

    * write_diagonal is not an atomic operation, but a sequence of Arange/SetSubtensor operations.

    The rewrite is only attempted when ``x`` has a static square shape in its
    last two dimensions, the diagonal is extracted along those same two
    dimensions, and the written indices are constant integer aranges of equal
    length. In every other situation (including a partially written diagonal)
    the node is left untouched.

    Returns
    -------
    list or None
        A single-element list with the replacement variable, or ``None`` when
        the rewrite does not apply.
    """

    def is_constant_arange(var) -> bool:
        """Return True if `var` is a non-empty 1d integer constant equal to arange(start, stop)."""
        if not (isinstance(var, TensorConstant) and var.type.ndim == 1):
            return False

        data = var.data
        if data.size == 0:
            # There is no first/last element to read from an empty index
            return False
        if not np.issubdtype(data.dtype, np.integer):
            # A boolean mask such as [False, True] would otherwise compare equal to arange(0, 2)
            return False

        start, stop = data[0], data[-1] + 1
        return data.size == (stop - start) and (data == np.arange(start, stop)).all()

    [diag_x] = node.inputs
    if not (
        diag_x.owner is not None
        and isinstance(diag_x.owner.op, AdvancedIncSubtensor)
        and diag_x.owner.op.set_instead_of_inc
    ):
        return None

    x, y, *idxs = diag_x.owner.inputs

    if not (
        x.type.ndim >= 2
        and None not in x.type.shape[-2:]
        and x.type.shape[-2] == x.type.shape[-1]
    ):
        # TODO: for now we only support rewrite with static square shape for x
        return None

    op = node.op
    if (op.axis1, op.axis2) != (x.type.ndim - 2, x.type.ndim - 1):
        # dim_length and the broadcast shape used below are read from the last two
        # dimensions of x, so the diagonal must be extracted along those same axes
        return None

    if op.axis2 >= len(idxs):
        # There is no index for axis2 (idxs[op.axis2] would be out of range)
        return None

    # Check all non-axis indices are full slices
    axis = {op.axis1, op.axis2}
    if not all(is_full_slice(idx) for i, idx in enumerate(idxs) if i not in axis):
        return None

    # Check axis indices are arange we would expect from setting on the diagonal
    axis1_idx = idxs[op.axis1]
    axis2_idx = idxs[op.axis2]
    if not (is_constant_arange(axis1_idx) and is_constant_arange(axis2_idx)):
        return None

    if axis1_idx.data.size != axis2_idx.data.size:
        # Indices of different length do not pair up one-to-one (a length-1 index would be
        # broadcast against the other), so the written entries are not a single diagonal
        return None

    dim_length = x.type.shape[-1]
    offset = op.offset
    start_stop1 = (axis1_idx.data[0], axis1_idx.data[-1] + 1)
    start_stop2 = (axis2_idx.data[0], axis2_idx.data[-1] + 1)
    orig_start1, orig_start2 = start_stop1[0], start_stop2[0]

    if offset < 0:
        # The logic for checking if we are selecting or not a diagonal for negative offset is the same
        # as the one with positive offset but swapped axis
        start_stop1, start_stop2 = start_stop2, start_stop1
        offset = -offset

    start1, stop1 = start_stop1
    start2, stop2 = start_stop2
    if (
        start1 == 0
        and start2 == offset
        and stop1 == dim_length - offset
        and stop2 == dim_length
    ):
        # We are extracting the just written diagonal
        out_type = node.outputs[0].type
        if y.type.ndim < out_type.ndim or y.type.shape[-1] == 1:
            # We may need to broadcast y
            y = full((*x.shape[:-2], dim_length - offset), y, dtype=x.type.dtype)
        elif y.type.dtype != out_type.dtype:
            # Returning y as is would change the dtype of the output, so leave the graph alone
            return None
        return [y]
    elif (orig_start2 - orig_start1) != op.offset:
        # Some other diagonal was written, ignore it
        return [op(x)]
    else:
        # A portion, but not the whole diagonal was written, don't do anything
        return None

tests/tensor/rewriting/test_subtensor.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import random
2+
13
import numpy as np
24
import pytest
35

@@ -9,7 +11,7 @@
911
from pytensor.compile.mode import Mode, get_default_mode, get_mode
1012
from pytensor.compile.ops import DeepCopyOp
1113
from pytensor.configdefaults import config
12-
from pytensor.graph import vectorize_graph
14+
from pytensor.graph import rewrite_graph, vectorize_graph
1315
from pytensor.graph.basic import Constant, Variable, ancestors, equal_computations
1416
from pytensor.graph.rewriting.basic import check_stack_trace
1517
from pytensor.raise_op import Assert
@@ -1956,3 +1958,37 @@ def test_unknown_step(self):
19561958
f(test_x, -2),
19571959
test_x[0:3:-2, -1:-6:2, ::],
19581960
)
1961+
1962+
1963+
def test_extract_diag_of_diagonal_set_subtensor():
    """Reading back diagonals that were fully written should bypass the writes."""

    def diagonal_value(k):
        # Distinct fill value per diagonal, so a mix-up between offsets is detectable
        return k + 0.1 * k

    A = pt.full((2, 6, 6), np.nan)
    rows = pt.arange(A.shape[-2])
    cols = pt.arange(A.shape[-1])

    # Randomize order of write operations, to make sure rewrite is not sensitive to it
    write_offsets = [-2, -1, 0, 1, 2]
    random.shuffle(write_offsets)
    for k in write_offsets:
        if k > 0:
            row_idx, col_idx = rows[:-k], cols[k:]
        elif k < 0:
            row_idx, col_idx = rows[-k:], cols[:k]
        else:
            row_idx, col_idx = rows, cols
        A = A[..., row_idx, col_idx].set(diagonal_value(k))

    # Add a partial diagonal along offset 3
    A = A[..., rows[1:-3], cols[4:]].set(np.pi)

    read_offsets = [-2, -1, 0, 1, 2, 3]
    outs = [A.diagonal(offset=k, axis1=-2, axis2=-1) for k in read_offsets]
    rewritten_outs = rewrite_graph(outs, include=("ShapeOpt", "canonicalize"))

    # Every fully written diagonal should just be an Alloc with its value
    expected_outs = [
        pt.full(
            (np.int64(2), np.int8(6 - abs(k))),
            np.asarray(diagonal_value(k), dtype=A.type.dtype),
        )
        for k in read_offsets[:-1]
    ]
    # The partial diagonal shouldn't be rewritten
    expected_outs.append(outs[-1])

    assert equal_computations(rewritten_outs, expected_outs)

0 commit comments

Comments
 (0)