|
19 | 19 | from pytensor.scalar import constant as scalar_constant
|
20 | 20 | from pytensor.tensor.basic import (
|
21 | 21 | Alloc,
|
| 22 | + ExtractDiag, |
22 | 23 | Join,
|
23 | 24 | ScalarFromTensor,
|
24 | 25 | TensorFromScalar,
|
25 | 26 | alloc,
|
26 | 27 | cast,
|
27 | 28 | concatenate,
|
28 | 29 | expand_dims,
|
| 30 | + full, |
29 | 31 | get_scalar_constant_value,
|
30 | 32 | get_underlying_scalar_constant_value,
|
31 | 33 | register_infer_shape,
|
@@ -1793,3 +1795,96 @@ def ravel_multidimensional_int_idx(fgraph, node):
|
1793 | 1795 | "numba",
|
1794 | 1796 | use_db_name_as_tag=False, # Not included if only "specialize" is requested
|
1795 | 1797 | )
|
| 1798 | + |
| 1799 | + |
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([ExtractDiag])
def extract_diag_of_diagonal_set_subtensor(fgraph, node):
    """Undo extract diagonal from a set diagonal

    This rewrites the following pattern:
        y = write_diagonal*(x, x_diag, offset=k1)
        z = extract_diag(y, offset=k2)

    as:
        z = x_diag, if k1 == k2 and the whole diagonal was written
        z = extract_diag(x, offset=k2), if k1 != k2

    * write_diagonal is not an atomic operation, but a sequence of Arange/SetSubtensor operations.

    The rewrite only applies when the input of the ExtractDiag node is an
    AdvancedIncSubtensor in "set" mode, `x` has static and equal last two
    dimensions, the indices on the diagonal axes are constant aranges, and all
    other indices are full slices.

    Returns
    -------
    list or None
        A single-element list with the replacement variable, or None when the
        pattern does not match (including when only part of the extracted
        diagonal was written).
    """

    def is_constant_arange(var) -> bool:
        # True if `var` is a non-empty 1d integer constant equal to arange(var[0], var[-1] + 1)
        if not (isinstance(var, TensorConstant) and var.type.ndim == 1):
            return False

        data = var.data
        if data.size == 0:
            # Nothing is written by an empty index, and data[0] would raise
            return False
        if data.dtype.kind not in "iu":
            # A boolean vector like [False, True] would otherwise compare equal to arange(0, 2)
            return False

        # Use Python ints so that `+ 1` cannot overflow narrow integer dtypes
        start, stop = int(data[0]), int(data[-1]) + 1
        return data.size == (stop - start) and bool(
            (data == np.arange(start, stop)).all()
        )

    [written_x] = node.inputs
    if not (
        written_x.owner is not None
        and isinstance(written_x.owner.op, AdvancedIncSubtensor)
        and written_x.owner.op.set_instead_of_inc
    ):
        return None

    x, y, *idxs = written_x.owner.inputs

    if not (
        x.type.ndim >= 2
        and None not in x.type.shape[-2:]
        and x.type.shape[-2] == x.type.shape[-1]
    ):
        # TODO: for now we only support rewrite with static square shape for x
        return None

    op = node.op
    # Both `idxs[op.axis1]` and `idxs[op.axis2]` are read below, so both must be valid positions
    if max(op.axis1, op.axis2) >= len(idxs):
        return None

    # Check all non-axis indices are full slices
    axis = {op.axis1, op.axis2}
    if not all(is_full_slice(idx) for i, idx in enumerate(idxs) if i not in axis):
        return None

    # Check axis indices are arange we would expect from setting on the diagonal
    axis1_idx = idxs[op.axis1]
    axis2_idx = idxs[op.axis2]
    if not (is_constant_arange(axis1_idx) and is_constant_arange(axis2_idx)):
        return None

    if axis1_idx.data.size != axis2_idx.data.size:
        # Index vectors of different lengths either fail at runtime or broadcast
        # (length 1), in which case a row/column is written rather than a diagonal.
        # Neither conclusion below would be valid then.
        return None

    # NOTE(review): the static shape check above, `dim_length` and the broadcast shape
    # below all look at the last two axes of x, while the indices are looked up with
    # op.axis1/op.axis2 -- this presumably assumes the diagonal axes are the last two; confirm.
    dim_length = x.type.shape[-1]
    offset = op.offset
    start_stop1 = (int(axis1_idx.data[0]), int(axis1_idx.data[-1]) + 1)
    start_stop2 = (int(axis2_idx.data[0]), int(axis2_idx.data[-1]) + 1)
    orig_start1, orig_start2 = start_stop1[0], start_stop2[0]

    if offset < 0:
        # The logic for checking if we are selecting or not a diagonal for negative offset is the same
        # as the one with positive offset but swapped axis
        start_stop1, start_stop2 = start_stop2, start_stop1
        offset = -offset

    start1, stop1 = start_stop1
    start2, stop2 = start_stop2
    if (
        start1 == 0
        and start2 == offset
        and stop1 == dim_length - offset
        and stop2 == dim_length
    ):
        # We are extracting the just written diagonal
        if y.type.ndim == 0 or y.type.shape[-1] == 1:
            # We may need to broadcast y
            y = full((*x.shape[:-2], dim_length - offset), y, dtype=x.type.dtype)
        return [y]
    elif (orig_start2 - orig_start1) != op.offset:
        # Some other diagonal was written, ignore it
        return [op(x)]
    else:
        # A portion, but not the whole diagonal was written, don't do anything
        return None
0 commit comments