Skip to content

XTensorVariable indexing update #1438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jun 2, 2025

This builds on top of #1429 to support index assignment and increment x[idx] = y and x[idx] += y, following xarray indexing and broadcasting semantics.

As with regular TensorVariables, due to immutability/hashability constraints of PyTensor, it is not possible to use the python native operations, so one has to write z = x[idx].set(y) and z = x[idx].inc(y)


📚 Documentation preview 📚: https://pytensor--1438.org.readthedocs.build/en/1438/

@ricardoV94 ricardoV94 requested a review from OriolAbril June 2, 2025 09:38
@ricardoV94 ricardoV94 changed the base branch from main to labeled_tensors June 2, 2025 09:39
@ricardoV94 ricardoV94 changed the title Labeled indexing update XTensorVariable indexing update Jun 2, 2025
@ricardoV94 ricardoV94 force-pushed the labeled_indexing_update branch from 4e105b4 to e5fbf7a Compare June 3, 2025 16:14
@ricardoV94 ricardoV94 marked this pull request as ready for review June 3, 2025 16:15
@ricardoV94 ricardoV94 force-pushed the labeled_indexing_update branch from e5fbf7a to 04f057c Compare June 3, 2025 16:17
Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also unable to find a combination of indexes/setter data that I think should work but doesn't.

This one already works but I'd suggest adding as a test case too:

x = xtensor("x", shape=(5, 5), dims=("a", "b"))
y = xtensor("y", shape=(3,), dims=("d",))
idx = xtensor("idx", dtype=int, shape=(None,), dims=("d",))

# define "d" dimension by slicing the "a" dimension
# so we can set y into x
orthogonal_update1 = x[idx].set(y)
fn = xr_function([x, idx, y], [orthogonal_update1])

x_test = np.abs(xr_random_like(x))
y_test = -np.abs(xr_random_like(y))
idx_test = DataArray([0, 2, 3], dims=("d",))

result = fn(x_test, idx_test, y_test)
x_test[idx_test] = y_test
expected_result = [x_test]
xr_assert_allclose(result, expected_result)

@ricardoV94
Copy link
Member Author

This one already works but I'd suggest adding as a test case too:

Will do!

@ricardoV94 ricardoV94 force-pushed the labeled_indexing_update branch from 04f057c to 2ad906a Compare June 4, 2025 15:10
@ricardoV94 ricardoV94 merged commit ea690e6 into pymc-devs:labeled_tensors Jun 4, 2025
4 of 5 checks passed
@ricardoV94 ricardoV94 deleted the labeled_indexing_update branch June 4, 2025 15:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants