27
27
from pytensor .tensor .exceptions import NotScalarConstantError
28
28
from pytensor .tensor .math import abs as pt_abs
29
29
from pytensor .tensor .math import all as pt_all
30
+ from pytensor .tensor .math import eq as pt_eq
30
31
from pytensor .tensor .math import (
31
- bitwise_and ,
32
32
ge ,
33
33
gt ,
34
34
log ,
39
39
sign ,
40
40
switch ,
41
41
)
42
- from pytensor .tensor .math import eq as pt_eq
43
42
from pytensor .tensor .math import max as pt_max
44
43
from pytensor .tensor .math import sum as pt_sum
45
44
from pytensor .tensor .shape import specify_broadcastable
@@ -1618,22 +1617,18 @@ def _linspace_core(
1618
1617
start : TensorVariable ,
1619
1618
stop : TensorVariable ,
1620
1619
num : int ,
1621
- dtype : str ,
1622
1620
endpoint = True ,
1623
1621
retstep = False ,
1624
1622
axis = 0 ,
1625
1623
) -> TensorVariable | tuple [TensorVariable , TensorVariable ]:
1626
1624
div = (num - 1 ) if endpoint else num
1627
- delta = ( stop - start ). astype ( dtype )
1628
- samples = ptb .arange (0 , num , dtype = dtype ).reshape ((- 1 ,) + (1 ,) * delta .ndim )
1625
+ delta = stop - start
1626
+ samples = ptb .arange (0 , num ).reshape ((- 1 ,) + (1 ,) * delta .ndim )
1629
1627
1630
- step = switch ( gt ( div , 0 ), delta / div , np . nan )
1628
+ step = delta / div
1631
1629
samples = switch (gt (div , 0 ), samples * delta / div + start , samples * delta + start )
1632
- samples = switch (
1633
- bitwise_and (gt (num , 1 ), np .asarray (endpoint )),
1634
- set_subtensor (samples [- 1 , ...], stop ),
1635
- samples ,
1636
- )
1630
+ if endpoint :
1631
+ samples = switch (gt (num , 1 ), set_subtensor (samples [- 1 , ...], stop ), samples )
1637
1632
1638
1633
if axis != 0 :
1639
1634
samples = ptb .moveaxis (samples , 0 , axis )
@@ -1644,17 +1639,14 @@ def _linspace_core(
1644
1639
return samples
1645
1640
1646
1641
1647
- def _broadcast_inputs_and_dtypes (* args , dtype = None ):
1642
+ def _broadcast_inputs (* args ):
1648
1643
args = map (ptb .as_tensor_variable , args )
1649
1644
args = broadcast_arrays (* args )
1650
1645
1651
- if dtype is None :
1652
- dtype = pytensor .config .floatX
1653
-
1654
- return args , dtype
1646
+ return args
1655
1647
1656
1648
1657
- def _broadcast_base_with_inputs (start , stop , base , dtype , axis ):
1649
+ def _broadcast_base_with_inputs (start , stop , base , axis ):
1658
1650
"""
1659
1651
Broadcast the base tensor with the start and stop tensors if base is not a scalar. This is important because it
1660
1652
may change how the axis argument is interpreted in the final output.
@@ -1664,14 +1656,13 @@ def _broadcast_base_with_inputs(start, stop, base, dtype, axis):
1664
1656
start
1665
1657
stop
1666
1658
base
1667
- dtype
1668
1659
axis
1669
1660
1670
1661
Returns
1671
1662
-------
1672
1663
1673
1664
"""
1674
- base = ptb .as_tensor_variable (base , dtype = dtype )
1665
+ base = ptb .as_tensor_variable (base )
1675
1666
if base .ndim > 0 :
1676
1667
ndmax = len (broadcast_shape (start , stop , base ))
1677
1668
start , stop , base = (
@@ -1747,19 +1738,22 @@ def linspace(
1747
1738
step: TensorVariable
1748
1739
Tensor containing the spacing between samples. Only returned if `retstep` is True.
1749
1740
"""
1741
+ if dtype is None :
1742
+ dtype = pytensor .config .floatX
1750
1743
end , num = _check_deprecated_inputs (stop , end , num , steps )
1751
- ( start , stop ), type = _broadcast_inputs_and_dtypes (start , stop , dtype = dtype )
1744
+ start , stop = _broadcast_inputs (start , stop )
1752
1745
1753
- return _linspace_core (
1746
+ ls = _linspace_core (
1754
1747
start = start ,
1755
1748
stop = stop ,
1756
1749
num = num ,
1757
- dtype = dtype ,
1758
1750
endpoint = endpoint ,
1759
1751
retstep = retstep ,
1760
1752
axis = axis ,
1761
1753
)
1762
1754
1755
+ return ls .astype (dtype )
1756
+
1763
1757
1764
1758
def geomspace (
1765
1759
start : TensorLike ,
@@ -1826,9 +1820,11 @@ def geomspace(
1826
1820
samples: TensorVariable
1827
1821
Tensor containing `num` evenly-spaced values between [start, stop]. The range is inclusive if `endpoint` is True.
1828
1822
"""
1823
+ if dtype is None :
1824
+ dtype = pytensor .config .floatX
1829
1825
stop , num = _check_deprecated_inputs (stop , end , num , steps )
1830
- ( start , stop ), dtype = _broadcast_inputs_and_dtypes (start , stop , dtype = dtype )
1831
- start , stop , base = _broadcast_base_with_inputs (start , stop , base , dtype , axis )
1826
+ start , stop = _broadcast_inputs (start , stop )
1827
+ start , stop , base = _broadcast_base_with_inputs (start , stop , base , axis )
1832
1828
1833
1829
out_sign = sign (start )
1834
1830
log_start , log_stop = (
@@ -1840,23 +1836,22 @@ def geomspace(
1840
1836
stop = log_stop ,
1841
1837
num = num ,
1842
1838
endpoint = endpoint ,
1843
- dtype = dtype ,
1844
1839
axis = 0 ,
1845
1840
retstep = False ,
1846
1841
)
1847
1842
result = base ** result
1848
1843
1849
1844
if num > 0 :
1850
- set_subtensor (result [0 , ...], start , inplace = True )
1845
+ result = set_subtensor (result [0 , ...], start )
1851
1846
if num > 1 and endpoint :
1852
- set_subtensor (result [- 1 , ...], stop , inplace = True )
1847
+ result = set_subtensor (result [- 1 , ...], stop )
1853
1848
1854
1849
result = result * out_sign
1855
1850
1856
1851
if axis != 0 :
1857
1852
result = ptb .moveaxis (result , 0 , axis )
1858
1853
1859
- return result
1854
+ return result . astype ( dtype )
1860
1855
1861
1856
1862
1857
def logspace (
@@ -1870,21 +1865,22 @@ def logspace(
1870
1865
end : TensorLike | None = None ,
1871
1866
steps : TensorLike | None = None ,
1872
1867
) -> TensorVariable :
1868
+ if dtype is None :
1869
+ dtype = pytensor .config .floatX
1873
1870
stop , num = _check_deprecated_inputs (stop , end , num , steps )
1874
- ( start , stop ), type = _broadcast_inputs_and_dtypes (start , stop , dtype = dtype )
1875
- start , stop , base = _broadcast_base_with_inputs (start , stop , base , dtype , axis )
1871
+ start , stop = _broadcast_inputs (start , stop )
1872
+ start , stop , base = _broadcast_base_with_inputs (start , stop , base , axis )
1876
1873
1877
1874
ls = _linspace_core (
1878
1875
start = start ,
1879
1876
stop = stop ,
1880
1877
num = num ,
1881
1878
endpoint = endpoint ,
1882
- dtype = dtype ,
1883
1879
axis = axis ,
1884
1880
retstep = False ,
1885
1881
)
1886
1882
1887
- return base ** ls
1883
+ return ( base ** ls ). astype ( dtype )
1888
1884
1889
1885
1890
1886
def broadcast_to (
0 commit comments