Skip to content

Commit 29b954a

Browse files
committed
Add diff method to XTensorVariable
1 parent 30e1a42 commit 29b954a

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pytensor/xtensor/type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,15 @@ def cumsum(self, dim):
502502
def cumprod(self, dim):
503503
return px.reduction.cumprod(self, dim)
504504

505+
def diff(self, dim, n=1):
506+
"""Compute the n-th discrete difference along the given dimension."""
507+
slice1 = {dim: slice(1, None)}
508+
slice2 = {dim: slice(None, -1)}
509+
x = self
510+
for _ in range(n):
511+
x = x[slice1] - x[slice2]
512+
return x
513+
505514

506515
class XTensorConstantSignature(tuple):
507516
def __eq__(self, other):

tests/xtensor/test_indexing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,22 @@ def test_single_vector_indexing_interacting_with_exisiting_dim():
117117
res = fn(x_test, xidx_test)
118118
expected_res = x_test[xidx_test.rename(a="b"), 1:]
119119
xr_assert_allclose(res, expected_res)
120+
121+
122+
@pytest.mark.parametrize("n", ["implicit", 1, 2])
123+
@pytest.mark.parametrize("dim", ["a", "b"])
124+
def test_diff(dim, n):
125+
x = xtensor(dims=("a", "b"), shape=(7, 11))
126+
if n == "implicit":
127+
out = x.diff(dim)
128+
else:
129+
out = x.diff(dim, n=n)
130+
131+
fn = xr_function([x], out)
132+
x_test = xr_arange_like(x)
133+
res = fn(x_test)
134+
if n == "implicit":
135+
expected_res = x_test.diff(dim)
136+
else:
137+
expected_res = x_test.diff(dim, n=n)
138+
xr_assert_allclose(res, expected_res)

0 commit comments

Comments
 (0)