Skip to content

Commit 8f213a3

Browse files
Delay setting dtype of xspace Ops until after all computation to match numpy outputs
1 parent c2b8465 commit 8f213a3

File tree

1 file changed

+28
-32
lines changed

1 file changed

+28
-32
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from pytensor.tensor.exceptions import NotScalarConstantError
2828
from pytensor.tensor.math import abs as pt_abs
2929
from pytensor.tensor.math import all as pt_all
30+
from pytensor.tensor.math import eq as pt_eq
3031
from pytensor.tensor.math import (
31-
bitwise_and,
3232
ge,
3333
gt,
3434
log,
@@ -39,7 +39,6 @@
3939
sign,
4040
switch,
4141
)
42-
from pytensor.tensor.math import eq as pt_eq
4342
from pytensor.tensor.math import max as pt_max
4443
from pytensor.tensor.math import sum as pt_sum
4544
from pytensor.tensor.shape import specify_broadcastable
@@ -1618,22 +1617,18 @@ def _linspace_core(
16181617
start: TensorVariable,
16191618
stop: TensorVariable,
16201619
num: int,
1621-
dtype: str,
16221620
endpoint=True,
16231621
retstep=False,
16241622
axis=0,
16251623
) -> TensorVariable | tuple[TensorVariable, TensorVariable]:
16261624
div = (num - 1) if endpoint else num
1627-
delta = (stop - start).astype(dtype)
1628-
samples = ptb.arange(0, num, dtype=dtype).reshape((-1,) + (1,) * delta.ndim)
1625+
delta = stop - start
1626+
samples = ptb.arange(0, num).reshape((-1,) + (1,) * delta.ndim)
16291627

1630-
step = switch(gt(div, 0), delta / div, np.nan)
1628+
step = delta / div
16311629
samples = switch(gt(div, 0), samples * delta / div + start, samples * delta + start)
1632-
samples = switch(
1633-
bitwise_and(gt(num, 1), np.asarray(endpoint)),
1634-
set_subtensor(samples[-1, ...], stop),
1635-
samples,
1636-
)
1630+
if endpoint:
1631+
samples = switch(gt(num, 1), set_subtensor(samples[-1, ...], stop), samples)
16371632

16381633
if axis != 0:
16391634
samples = ptb.moveaxis(samples, 0, axis)
@@ -1644,17 +1639,14 @@ def _linspace_core(
16441639
return samples
16451640

16461641

1647-
def _broadcast_inputs_and_dtypes(*args, dtype=None):
1642+
def _broadcast_inputs(*args):
16481643
args = map(ptb.as_tensor_variable, args)
16491644
args = broadcast_arrays(*args)
16501645

1651-
if dtype is None:
1652-
dtype = pytensor.config.floatX
1653-
1654-
return args, dtype
1646+
return args
16551647

16561648

1657-
def _broadcast_base_with_inputs(start, stop, base, dtype, axis):
1649+
def _broadcast_base_with_inputs(start, stop, base, axis):
16581650
"""
16591651
Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it
16601652
may change how the axis argument is interpreted in the final output.
@@ -1664,14 +1656,13 @@ def _broadcast_base_with_inputs(start, stop, base, dtype, axis):
16641656
start
16651657
stop
16661658
base
1667-
dtype
16681659
axis
16691660
16701661
Returns
16711662
-------
16721663
16731664
"""
1674-
base = ptb.as_tensor_variable(base, dtype=dtype)
1665+
base = ptb.as_tensor_variable(base)
16751666
if base.ndim > 0:
16761667
ndmax = len(broadcast_shape(start, stop, base))
16771668
start, stop, base = (
@@ -1747,19 +1738,22 @@ def linspace(
17471738
step: TensorVariable
17481739
Tensor containing the spacing between samples. Only returned if `retstep` is True.
17491740
"""
1741+
if dtype is None:
1742+
dtype = pytensor.config.floatX
17501743
end, num = _check_deprecated_inputs(stop, end, num, steps)
1751-
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1744+
start, stop = _broadcast_inputs(start, stop)
17521745

1753-
return _linspace_core(
1746+
ls = _linspace_core(
17541747
start=start,
17551748
stop=stop,
17561749
num=num,
1757-
dtype=dtype,
17581750
endpoint=endpoint,
17591751
retstep=retstep,
17601752
axis=axis,
17611753
)
17621754

1755+
return ls.astype(dtype)
1756+
17631757

17641758
def geomspace(
17651759
start: TensorLike,
@@ -1826,9 +1820,11 @@ def geomspace(
18261820
samples: TensorVariable
18271821
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
18281822
"""
1823+
if dtype is None:
1824+
dtype = pytensor.config.floatX
18291825
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1830-
(start, stop), dtype = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1831-
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1826+
start, stop = _broadcast_inputs(start, stop)
1827+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, axis)
18321828

18331829
out_sign = sign(start)
18341830
log_start, log_stop = (
@@ -1840,23 +1836,22 @@ def geomspace(
18401836
stop=log_stop,
18411837
num=num,
18421838
endpoint=endpoint,
1843-
dtype=dtype,
18441839
axis=0,
18451840
retstep=False,
18461841
)
18471842
result = base**result
18481843

18491844
if num > 0:
1850-
set_subtensor(result[0, ...], start, inplace=True)
1845+
result = set_subtensor(result[0, ...], start)
18511846
if num > 1 and endpoint:
1852-
set_subtensor(result[-1, ...], stop, inplace=True)
1847+
result = set_subtensor(result[-1, ...], stop)
18531848

18541849
result = result * out_sign
18551850

18561851
if axis != 0:
18571852
result = ptb.moveaxis(result, 0, axis)
18581853

1859-
return result
1854+
return result.astype(dtype)
18601855

18611856

18621857
def logspace(
@@ -1870,21 +1865,22 @@ def logspace(
18701865
end: TensorLike | None = None,
18711866
steps: TensorLike | None = None,
18721867
) -> TensorVariable:
1868+
if dtype is None:
1869+
dtype = pytensor.config.floatX
18731870
stop, num = _check_deprecated_inputs(stop, end, num, steps)
1874-
(start, stop), type = _broadcast_inputs_and_dtypes(start, stop, dtype=dtype)
1875-
start, stop, base = _broadcast_base_with_inputs(start, stop, base, dtype, axis)
1871+
start, stop = _broadcast_inputs(start, stop)
1872+
start, stop, base = _broadcast_base_with_inputs(start, stop, base, axis)
18761873

18771874
ls = _linspace_core(
18781875
start=start,
18791876
stop=stop,
18801877
num=num,
18811878
endpoint=endpoint,
1882-
dtype=dtype,
18831879
axis=axis,
18841880
retstep=False,
18851881
)
18861882

1887-
return base**ls
1883+
return (base**ls).astype(dtype)
18881884

18891885

18901886
def broadcast_to(

0 commit comments

Comments
 (0)