Skip to content

Commit d6a3ddf

Browse files
committed
Add diff method to XTensorVariable
1 parent 86bd5e8 commit d6a3ddf

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

pytensor/xtensor/type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,15 @@ def cumsum(self, dim):
500500
def cumprod(self, dim):
501501
return px.reduction.cumprod(self, dim)
502502

503+
def diff(self, dim, n=1):
504+
"""Compute the n-th discrete difference along the given dimension."""
505+
slice1 = {dim: slice(1, None)}
506+
slice2 = {dim: slice(None, -1)}
507+
x = self
508+
for _ in range(n):
509+
x = x[slice1] - x[slice2]
510+
return x
511+
503512

504513
class XTensorConstantSignature(tuple):
505514
def __eq__(self, other):

tests/xtensor/test_indexing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33
from xarray import DataArray
4+
from xtensor.util import xr_arange_like
45

56
from pytensor.xtensor import xtensor
67
from tests.xtensor.util import xr_assert_allclose, xr_function
@@ -40,3 +41,22 @@ def test_basic_indexing(labeled, indices):
4041
res = fn(x_test)
4142
expected_res = x_test[indices]
4243
xr_assert_allclose(res, expected_res)
44+
45+
46+
@pytest.mark.parametrize("n", ["implicit", 1, 2])
47+
@pytest.mark.parametrize("dim", ["a", "b"])
48+
def test_diff(dim, n):
49+
x = xtensor(dims=("a", "b"), shape=(7, 11))
50+
if n == "implicit":
51+
out = x.diff(dim)
52+
else:
53+
out = x.diff(dim, n=n)
54+
55+
fn = xr_function([x], out)
56+
x_test = xr_arange_like(x)
57+
res = fn(x_test)
58+
if n == "implicit":
59+
expected_res = x_test.diff(dim)
60+
else:
61+
expected_res = x_test.diff(dim, n=n)
62+
xr_assert_allclose(res, expected_res)

0 commit comments

Comments
 (0)