Commit 8150ef5

WIP add XTensorVariable properties and methods
1 parent de6e6d4 commit 8150ef5

File tree

1 file changed: +169 -8 lines changed

pytensor/xtensor/type.py

Lines changed: 169 additions & 8 deletions
@@ -1,5 +1,7 @@
 import warnings

+from pytensor.tensor import TensorVariable, mul
+

 try:
     import xarray as xr
@@ -133,10 +135,99 @@ def __complex__(self):
             "Call `.astype(complex)` for the symbolic equivalent."
         )

+    # DataArray-like attributes
+    # https://docs.xarray.dev/en/latest/api.html#id1
+    @property
+    def values(self) -> TensorVariable:
+        from pytensor.xtensor.basic import tensor_from_xtensor
+
+        return tensor_from_xtensor(self)
+
+    data = values
+
+    @property
+    def coords(self):
+        raise NotImplementedError("coords not implemented for XTensorVariable")
+
+    @property
+    def dims(self) -> tuple[str]:
+        return self.type.dims
+
+    @property
+    def sizes(self) -> dict[str, TensorVariable]:
+        return dict(zip(self.dims, self.shape))
+
+    @property
+    def as_numpy(self):
+        # No-op, since the underlying data is always a numpy array
+        return self
+
+    # ndarray attributes
+    # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
+    @property
+    def ndim(self) -> int:
+        return self.type.ndim
+
+    @property
+    def shape(self) -> tuple[TensorVariable]:
+        from pytensor.xtensor.basic import tensor_from_xtensor
+
+        return tuple(tensor_from_xtensor(self).shape)
+
+    @property
+    def size(self):
+        return mul(*self.shape)
+
+    @property
+    def dtype(self):
+        return self.type.dtype
+
+    # DataArray contents
+    # https://docs.xarray.dev/en/latest/api.html#dataarray-contents
+    def rename(self, new_name_or_name_dict, **names):
+        from pytensor.xtensor.basic import rename
+
+        if isinstance(new_name_or_name_dict, str):
+            # TODO: Should we make a symbolic copy?
+            self.name = new_name_or_name_dict
+            name_dict = None
+        else:
+            name_dict = new_name_or_name_dict
+        return rename(name_dict, **names)
+
+    # def swap_dims(self, *args, **kwargs):
+    #     ...
+    #
+    # def expand_dims(self, *args, **kwargs):
+    #     ...
+    #
+    # def squeeze(self):
+    #     ...
+
+    def copy(self):
+        from pytensor.xtensor.math import identity
+
+        return identity(self)
+
+    def astype(self, dtype):
+        from pytensor.xtensor.math import cast
+
+        return cast(self, dtype)
+
+    def item(self):
+        raise NotImplementedError("item not implemented for XTensorVariable")
+
+    # Indexing
+    # https://docs.xarray.dev/en/latest/api.html#id2
     def __setitem__(self, key, value):
-        raise TypeError(
-            "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
-        )
+        raise TypeError("XTensorVariable does not support item assignment.")
+
+    @property
+    def loc(self):
+        raise NotImplementedError("loc not implemented for XTensorVariable")
+
+    def sel(self, *args, **kwargs):
+        raise NotImplementedError("sel not implemented for XTensorVariable")

     def __getitem__(self, idx):
         from pytensor.xtensor.indexing import index
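The new properties mirror xarray's DataArray attributes (the linked xarray docs are the reference), but return symbolic results: values/data is the underlying TensorVariable, shape is a tuple of symbolic lengths, and sizes maps each dim name to its symbolic length. As a point of comparison only (not part of this commit), this is how the corresponding attributes behave on a concrete xarray DataArray:

# Not part of the commit: a minimal xarray sketch of the attributes the new
# properties mirror (values/data, dims, sizes, ndim, shape, size, dtype).
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((2, 3)), dims=("a", "b"))

da.dims    # ('a', 'b')
da.sizes   # mapping of dim name to length: {'a': 2, 'b': 3}
da.ndim    # 2
da.shape   # (2, 3)
da.size    # 6
da.dtype   # float64
da.values  # the underlying numpy array; da.data is an alias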
@@ -159,11 +250,6 @@ def __getitem__(self, idx):

         return index(self, *idx)

-    def sel(self, *args, **kwargs):
-        raise NotImplementedError(
-            "sel not implemented for XTensorVariable, use isel instead"
-        )
-
     def isel(
         self,
         indexers: dict[str, Any] | None = None,
@@ -208,6 +294,81 @@ def isel(

         return index(self, *indices)

+    def _head_tail_or_thin(
+        self,
+        indexers: dict[str, Any] | int | None,
+        indexers_kwargs: dict[str, Any],
+        *,
+        kind: Literal["head", "tail", "thin"],
+    ):
+        if indexers_kwargs:
+            if indexers is not None:
+                raise ValueError(
+                    "Cannot pass both indexers and indexers_kwargs to head"
+                )
+            indexers = indexers_kwargs
+
+        if indexers is None:
+            if kind == "thin":
+                raise TypeError(
+                    "thin() indexers must be either dict-like or a single integer"
+                )
+            else:
+                # Default to 5 for head and tail
+                indexers = {dim: 5 for dim in self.type.dims}
+
+        elif not isinstance(indexers, dict):
+            indexers = {dim: indexers for dim in self.type.dims}
+
+        if kind == "head":
+            indices = {dim: slice(None, value) for dim, value in indexers.items()}
+        elif kind == "tail":
+            sizes = self.sizes
+            # Can't use slice(-value, None), in case value is zero
+            indices = {
+                dim: slice(sizes[dim] - value, None) for dim, value in indexers.items()
+            }
+        elif kind == "thin":
+            indices = {dim: slice(None, None, value) for dim, value in indexers.items()}
+        return self.isel(indices)
+
+    def head(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="head")
+
+    def tail(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="tail")
+
+    def thin(self, indexers: dict[str, Any] | int | None = None, **indexers_kwargs):
+        return self._head_tail_or_thin(indexers, indexers_kwargs, kind="thin")
+
+    # ndarray methods
+    # https://docs.xarray.dev/en/latest/api.html#id7
+    def clip(self, min, max):
+        from pytensor.xtensor.math import clip
+
+        return clip(self, min, max)
+
+    def conj(self):
+        from pytensor.xtensor.math import conj
+
+        return conj(self)
+
+    @property
+    def imag(self):
+        from pytensor.xtensor.math import imag
+
+        return imag(self)
+
+    @property
+    def real(self):
+        from pytensor.xtensor.math import real
+
+        return real(self)
+
+    # @property
+    # def T(self):
+    #     ...
+

 class XTensorConstantSignature(tuple):
     def __eq__(self, other):
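For reference (not part of this commit), _head_tail_or_thin reduces head, tail, and thin to plain slices handed to isel, following the xarray methods of the same names. A minimal xarray sketch of the equivalences the helper builds, with n standing in for the symbolic size used in the diff:

# Not part of the commit: the slice equivalences built by _head_tail_or_thin,
# demonstrated with xarray, whose head/tail/thin API these methods mirror.
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10), dims=("x",))
n = da.sizes["x"]

# kind == "head": slice(None, value)
assert da.head(x=3).equals(da.isel(x=slice(None, 3)))
# kind == "tail": the diff builds slice(size - value, None) rather than
# slice(-value, None) so that value == 0 still selects nothing
assert da.tail(x=3).equals(da.isel(x=slice(n - 3, None)))
# kind == "thin": slice(None, None, value)
assert da.thin(x=2).equals(da.isel(x=slice(None, None, 2)))

# With no indexers, head and tail default to 5 along every dimension; thin raises.
assert da.head().equals(da.isel(x=slice(None, 5)))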
