1
1
import warnings
2
2
3
+ from pytensor .tensor import TensorVariable , mul
4
+
3
5
4
6
try :
5
7
import xarray as xr
@@ -133,10 +135,99 @@ def __complex__(self):
133
135
"Call `.astype(complex)` for the symbolic equivalent."
134
136
)
135
137
138
+ # DataArray-like attributes
139
+ # https://docs.xarray.dev/en/latest/api.html#id1
140
+ @property
141
+ def values (self ) -> TensorVariable :
142
+ from pytensor .xtensor .basic import tensor_from_xtensor
143
+
144
+ return tensor_from_xtensor (self )
145
+
146
+ data = values
147
+
148
+ @property
149
+ def coords (self ):
150
+ raise NotImplementedError ("coords not implemented for XTensorVariable" )
151
+
152
+ @property
153
+ def dims (self ) -> tuple [str ]:
154
+ return self .type .dims
155
+
156
+ @property
157
+ def sizes (self ) -> dict [str , TensorVariable ]:
158
+ return dict (zip (self .dims , self .shape ))
159
+
160
+ @property
161
+ def as_numpy (self ):
162
+ # No-op, since the underlying data is always a numpy array
163
+ return self
164
+
165
+ # ndarray attributes
166
+ # https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
167
+ @property
168
+ def ndim (self ) -> int :
169
+ return self .type .ndim
170
+
171
+ @property
172
+ def shape (self ) -> tuple [TensorVariable ]:
173
+ from pytensor .xtensor .basic import tensor_from_xtensor
174
+
175
+ return tuple (tensor_from_xtensor (self ).shape )
176
+
177
+ @property
178
+ def size (self ):
179
+ return mul (* self .shape )
180
+
181
+ @property
182
+ def dtype (self ):
183
+ return self .type .dtype
184
+
185
+ # DataArray contents
186
+ # https://docs.xarray.dev/en/latest/api.html#dataarray-contents
187
+ def rename (self , new_name_or_name_dict , ** names ):
188
+ from pytensor .xtensor .basic import rename
189
+
190
+ if isinstance (new_name_or_name_dict , str ):
191
+ # TODO: Should we make a symbolic copy?
192
+ self .name = new_name_or_name_dict
193
+ name_dict = None
194
+ else :
195
+ name_dict = new_name_or_name_dict
196
+ return rename (name_dict , ** names )
197
+
198
+ # def swap_dims(self, *args, **kwargs):
199
+ # ...
200
+ #
201
+ # def expand_dims(self, *args, **kwargs):
202
+ # ...
203
+ #
204
+ # def squeeze(self):
205
+ # ...
206
+
207
+ def copy (self ):
208
+ from pytensor .xtensor .math import identity
209
+
210
+ return identity (self )
211
+
212
+ def astype (self , dtype ):
213
+ from pytensor .xtensor .math import cast
214
+
215
+ return cast (self , dtype )
216
+
217
+ def item (self ):
218
+ raise NotImplementedError ("item not implemented for XTensorVariable" )
219
+
220
+ # Indexing
221
+ # https://docs.xarray.dev/en/latest/api.html#id2
136
222
def __setitem__ (self , key , value ):
137
- raise TypeError (
138
- "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
139
- )
223
+ raise TypeError ("XTensorVariable does not support item assignment." )
224
+
225
+ @property
226
+ def loc (self ):
227
+ raise NotImplementedError ("loc not implemented for XTensorVariable" )
228
+
229
+ def sel (self , * args , ** kwargs ):
230
+ raise NotImplementedError ("sel not implemented for XTensorVariable" )
140
231
141
232
def __getitem__ (self , idx ):
142
233
from pytensor .xtensor .indexing import index
@@ -159,11 +250,6 @@ def __getitem__(self, idx):
159
250
160
251
return index (self , * idx )
161
252
162
- def sel (self , * args , ** kwargs ):
163
- raise NotImplementedError (
164
- "sel not implemented for XTensorVariable, use isel instead"
165
- )
166
-
167
253
def isel (
168
254
self ,
169
255
indexers : dict [str , Any ] | None = None ,
@@ -208,6 +294,81 @@ def isel(
208
294
209
295
return index (self , * indices )
210
296
297
+ def _head_tail_or_thin (
298
+ self ,
299
+ indexers : dict [str , Any ] | int | None ,
300
+ indexers_kwargs : dict [str , Any ],
301
+ * ,
302
+ kind : Literal ["head" , "tail" , "thin" ],
303
+ ):
304
+ if indexers_kwargs :
305
+ if indexers is not None :
306
+ raise ValueError (
307
+ "Cannot pass both indexers and indexers_kwargs to head"
308
+ )
309
+ indexers = indexers_kwargs
310
+
311
+ if indexers is None :
312
+ if kind == "thin" :
313
+ raise TypeError (
314
+ "thin() indexers must be either dict-like or a single integer"
315
+ )
316
+ else :
317
+ # Default to 5 for head and tail
318
+ indexers = {dim : 5 for dim in self .type .dims }
319
+
320
+ elif not isinstance (indexers , dict ):
321
+ indexers = {dim : indexers for dim in self .type .dims }
322
+
323
+ if kind == "head" :
324
+ indices = {dim : slice (None , value ) for dim , value in indexers .items ()}
325
+ elif kind == "tail" :
326
+ sizes = self .sizes
327
+ # Can't use slice(-value, None), in case value is zero
328
+ indices = {
329
+ dim : slice (sizes [dim ] - value , None ) for dim , value in indexers .items ()
330
+ }
331
+ elif kind == "thin" :
332
+ indices = {dim : slice (None , None , value ) for dim , value in indexers .items ()}
333
+ return self .isel (indices )
334
+
335
+ def head (self , indexers : dict [str , Any ] | int | None = None , ** indexers_kwargs ):
336
+ return self ._head_tail_or_thin (indexers , indexers_kwargs , kind = "head" )
337
+
338
+ def tail (self , indexers : dict [str , Any ] | int | None = None , ** indexers_kwargs ):
339
+ return self ._head_tail_or_thin (indexers , indexers_kwargs , kind = "tail" )
340
+
341
+ def thin (self , indexers : dict [str , Any ] | int | None = None , ** indexers_kwargs ):
342
+ return self ._head_tail_or_thin (indexers , indexers_kwargs , kind = "thin" )
343
+
344
+ # ndarray methods
345
+ # https://docs.xarray.dev/en/latest/api.html#id7
346
+ def clip (self , min , max ):
347
+ from pytensor .xtensor .math import clip
348
+
349
+ return clip (self , min , max )
350
+
351
+ def conj (self ):
352
+ from pytensor .xtensor .math import conj
353
+
354
+ return conj (self )
355
+
356
+ @property
357
+ def imag (self ):
358
+ from pytensor .xtensor .math import imag
359
+
360
+ return imag (self )
361
+
362
+ @property
363
+ def real (self ):
364
+ from pytensor .xtensor .math import real
365
+
366
+ return real (self )
367
+
368
+ # @property
369
+ # def T(self):
370
+ # ...
371
+
211
372
212
373
class XTensorConstantSignature (tuple ):
213
374
def __eq__ (self , other ):
0 commit comments