Skip to content

Commit 61aedd7

Browse files
committed
Implement casting for XTensorVariables
1 parent f8dbd5c commit 61aedd7

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

pytensor/xtensor/math.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import inspect
22
import sys
33

4+
import numpy as np
5+
46
import pytensor.scalar as ps
7+
from pytensor import config
58
from pytensor.scalar import ScalarOp
9+
from pytensor.scalar.basic import _cast_mapping
10+
from pytensor.xtensor.basic import as_xtensor
611
from pytensor.xtensor.vectorization import XElemwise
712

813

@@ -29,3 +34,26 @@ def get_all_scalar_ops():
2934

3035
for name, op in get_all_scalar_ops().items():
3136
setattr(this_module, name, op)
37+
38+
39+
_xelemwise_cast_op: dict[str, XElemwise] = {}
40+
41+
42+
def cast(x, dtype):
43+
if dtype == "floatX":
44+
dtype = config.floatX
45+
else:
46+
dtype = np.dtype(dtype).name
47+
48+
x = as_xtensor(x)
49+
if x.type.dtype == dtype:
50+
return x
51+
if x.type.dtype.startswith("complex") and not dtype.startswith("complex"):
52+
raise TypeError(
53+
"Casting from complex to real is ambiguous: consider"
54+
" real(), imag(), angle() or abs()"
55+
)
56+
57+
if dtype not in _xelemwise_cast_op:
58+
_xelemwise_cast_op[dtype] = XElemwise(scalar_op=_cast_mapping[dtype])
59+
return _xelemwise_cast_op[dtype](x)

tests/xtensor/test_math.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# ruff: noqa: E402
22
import pytest
3+
from xtensor.util import xr_arange_like
34

45

56
pytest.importorskip("xarray") #
@@ -107,3 +108,21 @@ def test_multiple_constant():
107108
res = fn(x_test)
108109
expected_res = np.exp(x_test * 2) + 2
109110
np.testing.assert_allclose(res, expected_res)
111+
112+
113+
def test_cast():
114+
x = xtensor("x", shape=(2, 3), dims=("a", "b"), dtype="float32")
115+
yf64 = x.astype("float64")
116+
yi16 = x.astype("int16")
117+
ybool = x.astype("bool")
118+
119+
fn = xr_function([x], [yf64, yi16, ybool])
120+
x_test = xr_arange_like(x)
121+
res_f64, res_i16, res_bool = fn(x_test)
122+
xr_assert_allclose(res_f64, x_test.astype("float64"))
123+
xr_assert_allclose(res_i16, x_test.astype("int16"))
124+
xr_assert_allclose(res_bool, x_test.astype("bool"))
125+
126+
yc64 = x.astype("complex64")
127+
with pytest.raises(TypeError, match="Casting from complex to real is ambiguous"):
128+
yc64.astype("float64")

0 commit comments

Comments
 (0)