Skip to content

Commit 93c6e47

Browse files
Ch0ronomatoIan Schweer
authored and
Ian Schweer
committed
Merge branch 'main' of github.com:pymc-devs/pytensor into blockwise
2 parents d639fd9 + e73258b commit 93c6e47

File tree

136 files changed

+3446
-2197
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

136 files changed

+3446
-2197
lines changed

.github/workflows/mypy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
- uses: actions/checkout@v4
1616
- uses: mamba-org/setup-micromamba@v1
1717
with:
18-
micromamba-version: "latest" # any version from https://github.com/mamba-org/micromamba-releases
18+
micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
1919
environment-file: environment.yml
2020
init-shell: bash
2121
cache-environment: true

.github/workflows/pypi.yml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,18 @@ jobs:
3535
name: sdist
3636
path: dist/*.tar.gz
3737

38+
run_checks:
39+
name: Build & inspect our package.
40+
# Note: the resulting builds are not actually published.
41+
# This is purely for additional testing and diagnostic purposes.
42+
runs-on: ubuntu-latest
43+
44+
steps:
45+
- uses: actions/checkout@v4
46+
with:
47+
fetch-depth: 0
48+
- uses: hynek/build-and-inspect-python-package@v2
49+
3850
build_wheels:
3951
name: Build wheels for ${{ matrix.platform }}
4052
runs-on: ${{ matrix.platform }}
@@ -50,7 +62,7 @@ jobs:
5062
fetch-depth: 0
5163

5264
- name: Build wheels
53-
uses: pypa/cibuildwheel@v2.19.2
65+
uses: pypa/cibuildwheel@v2.21.2
5466

5567
- uses: actions/upload-artifact@v4
5668
with:
@@ -133,7 +145,7 @@ jobs:
133145
name: universal_wheel
134146
path: dist
135147

136-
- uses: pypa/gh-action-pypi-publish@v1.9.0
148+
- uses: pypa/gh-action-pypi-publish@v1.10.3
137149
with:
138150
user: __token__
139151
password: ${{ secrets.pypi_password }}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
name: Read the Docs Pull Request Preview
2+
on:
3+
pull_request_target:
4+
types:
5+
- opened
6+
7+
permissions:
8+
pull-requests: write
9+
10+
jobs:
11+
documentation-links:
12+
runs-on: ubuntu-latest
13+
steps:
14+
- uses: readthedocs/actions/preview@v1
15+
with:
16+
project-slug: "pytensor"

.github/workflows/test.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ jobs:
7878
install-jax: [0]
7979
install-torch: [0]
8080
part:
81-
- "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
81+
- "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
8282
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
8383
- "tests/scan"
8484
- "tests/sparse"
@@ -97,9 +97,9 @@ jobs:
9797
part: "tests/tensor/test_math.py"
9898
- fast-compile: 1
9999
float32: 1
100-
- part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
100+
- part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
101101
float32: 1
102-
- part: "--doctest-modules --ignore=pytensor/misc/check_duplicate_key.py pytensor --ignore=pytensor/link"
102+
- part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
103103
fast-compile: 1
104104
include:
105105
- install-numba: 1
@@ -135,7 +135,7 @@ jobs:
135135
uses: mamba-org/setup-micromamba@v1
136136
with:
137137
environment-name: pytensor-test
138-
micromamba-version: "latest"
138+
micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
139139
init-shell: bash
140140
post-cleanup: "all"
141141
create-args: python=${{ matrix.python-version }}
@@ -157,7 +157,7 @@ jobs:
157157
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
158158
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
159159
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
160-
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
160+
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
161161
pip install pytest-sphinx
162162
163163
pip install -e ./
@@ -209,13 +209,13 @@ jobs:
209209
uses: mamba-org/setup-micromamba@v1
210210
with:
211211
environment-name: pytensor-test
212-
micromamba-version: "latest"
212+
micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved
213213
init-shell: bash
214214
post-cleanup: "all"
215215
- name: Install dependencies
216216
shell: micromamba-shell {0}
217217
run: |
218-
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
218+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
219219
pip install -e ./
220220
micromamba list && pip freeze
221221
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

.pre-commit-config.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ exclude: |
99
)$
1010
repos:
1111
- repo: https://github.com/pre-commit/pre-commit-hooks
12-
rev: v4.6.0
12+
rev: v5.0.0
1313
hooks:
1414
- id: debug-statements
1515
exclude: |
@@ -21,8 +21,13 @@ repos:
2121
pytensor/tensor/variable\.py|
2222
)$
2323
- id: check-merge-conflict
24+
- repo: https://github.com/sphinx-contrib/sphinx-lint
25+
rev: v1.0.0
26+
hooks:
27+
- id: sphinx-lint
28+
args: ["."]
2429
- repo: https://github.com/astral-sh/ruff-pre-commit
25-
rev: v0.6.3
30+
rev: v0.7.1
2631
hooks:
2732
- id: ruff
2833
args: ["--fix", "--output-format=full"]

readthedocs.yml renamed to .readthedocs.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ sphinx:
44
conda:
55
environment: doc/environment.yml
66
build:
7-
os: "ubuntu-20.04"
7+
os: "ubuntu-lts-latest"
88
tools:
9-
python: "mambaforge-4.10"
9+
python: "mambaforge-latest"

doc/extending/creating_a_c_op.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ This distance between consecutive elements of an array over a given dimension,
152152
is called the stride of that dimension.
153153

154154

155-
Accessing NumPy :class`ndarray`\s' data and properties
155+
Accessing NumPy :class:`ndarray`'s data and properties
156156
------------------------------------------------------
157157

158158
The following macros serve to access various attributes of NumPy :class:`ndarray`\s.

doc/extending/creating_a_numba_jax_op.rst

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Adding JAX, Numba and Pytorch support for `Op`\s
44
PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do
55
this, each :class:`Op` in an PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function.
66

7-
This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.
7+
This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.
88

99
Step 1: Identify the PyTensor :class:`Op` you'd like to implement
1010
------------------------------------------------------------------------
@@ -60,7 +60,7 @@ could also have any data type (e.g. floats, ints), so our implementation
6060
must be able to handle all the possible data types.
6161

6262
It also tells us that there's only one return value, that it has a data type
63-
determined by :meth:`x.type()` i.e., the data type of the original tensor.
63+
determined by :meth:`x.type` i.e., the data type of the original tensor.
6464
This implies that the result is necessarily a matrix.
6565

6666
Some class may have a more complex behavior. For example, the :class:`CumOp`\ :class:`Op`
@@ -116,7 +116,7 @@ Here's an example for :class:`DimShuffle`:
116116

117117
.. tab-set::
118118

119-
.. tab-item:: JAX
119+
.. tab-item:: JAX
120120

121121
.. code:: python
122122
@@ -134,7 +134,7 @@ Here's an example for :class:`DimShuffle`:
134134
res = jnp.copy(res)
135135
136136
return res
137-
137+
138138
.. tab-item:: Numba
139139

140140
.. code:: python
@@ -465,7 +465,7 @@ Step 4: Write tests
465465
.. tab-item:: JAX
466466

467467
Test that your registered `Op` is working correctly by adding tests to the
468-
appropriate test suites in PyTensor (e.g. in ``tests.link.jax``).
468+
appropriate test suites in PyTensor (e.g. in ``tests.link.jax``).
469469
The tests should ensure that your implementation can
470470
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
471471
Check the existing tests for the general outline of these kinds of tests. In
@@ -478,7 +478,7 @@ Step 4: Write tests
478478
Here's a small example of a test for :class:`CumOp` above:
479479

480480
.. code:: python
481-
481+
482482
import numpy as np
483483
import pytensor.tensor as pt
484484
from pytensor.configdefaults import config
@@ -514,22 +514,22 @@ Step 4: Write tests
514514
.. code:: python
515515
516516
import pytest
517-
517+
518518
def test_jax_CumOp():
519519
"""Test JAX conversion of the `CumOp` `Op`."""
520520
a = pt.matrix("a")
521521
a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
522-
522+
523523
with pytest.raises(NotImplementedError):
524524
out = pt.cumprod(a, axis=1)
525525
fgraph = FunctionGraph([a], [out])
526526
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
527-
528-
527+
528+
529529
.. tab-item:: Numba
530530

531531
Test that your registered `Op` is working correctly by adding tests to the
532-
appropriate test suites in PyTensor (e.g. in ``tests.link.numba``).
532+
appropriate test suites in PyTensor (e.g. in ``tests.link.numba``).
533533
The tests should ensure that your implementation can
534534
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
535535
Check the existing tests for the general outline of these kinds of tests. In
@@ -542,7 +542,7 @@ Step 4: Write tests
542542
Here's a small example of a test for :class:`CumOp` above:
543543

544544
.. code:: python
545-
545+
546546
from tests.link.numba.test_basic import compare_numba_and_py
547547
from pytensor.graph import FunctionGraph
548548
from pytensor.compile.sharedvalue import SharedVariable
@@ -561,11 +561,11 @@ Step 4: Write tests
561561
if not isinstance(i, SharedVariable | Constant)
562562
],
563563
)
564-
564+
565565
566566
567567
.. tab-item:: Pytorch
568-
568+
569569
Test that your registered `Op` is working correctly by adding tests to the
570570
appropriate test suites in PyTensor (``tests.link.pytorch``). The tests should ensure that your implementation can
571571
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
@@ -579,7 +579,7 @@ Step 4: Write tests
579579
Here's a small example of a test for :class:`CumOp` above:
580580

581581
.. code:: python
582-
582+
583583
import numpy as np
584584
import pytest
585585
import pytensor.tensor as pt
@@ -592,7 +592,7 @@ Step 4: Write tests
592592
["float64", "int64"],
593593
)
594594
@pytest.mark.parametrize(
595-
"axis",
595+
"axis",
596596
[None, 1, (0,)],
597597
)
598598
def test_pytorch_CumOp(axis, dtype):
@@ -650,4 +650,4 @@ as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
650650
All jitted functions now must have constant shape, which means a graph like the
651651
one of :class:`Eye` can never be translated to JAX, since it's fundamentally a
652652
function with dynamic shapes. In other words, only PyTensor graphs with static shapes
653-
can be translated to JAX at the moment.
653+
can be translated to JAX at the moment.

doc/extending/type.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ returns eitehr a new transferred variable (which can be the same as
333333
the input if no transfer is necessary) or returns None if the transfer
334334
can't be done.
335335

336-
Then register that function by calling :func:`register_transfer()`
336+
Then register that function by calling :func:`register_transfer`
337337
with it as argument.
338338

339339
An example

doc/library/compile/io.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ The ``inputs`` argument to ``pytensor.function`` is a list, containing the ``Var
3636
``self.<name>``. The default value is ``None``.
3737

3838
``value``: literal or ``Container``. The initial/default value for this
39-
input. If update is`` None``, this input acts just like
39+
input. If update is ``None``, this input acts just like
4040
an argument with a default value in Python. If update is not ``None``,
4141
changes to this
4242
value will "stick around", whether due to an update or a user's

doc/library/config.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ import ``pytensor`` and print the config variable, as in:
226226
in the future.
227227

228228
The ``'numpy+floatX'`` setting attempts to mimic NumPy casting rules,
229-
although it prefers to use ``float32` `numbers instead of ``float64`` when
229+
although it prefers to use ``float32`` numbers instead of ``float64`` when
230230
``config.floatX`` is set to ``'float32'`` and the associated data is not
231231
explicitly typed as ``float64`` (e.g. regular Python floats). Note that
232232
``'numpy+floatX'`` is not currently behaving exactly as planned (it is a

doc/library/graph/graph.rst

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,5 @@
44
:mod:`graph` -- Interface for the PyTensor graph
55
================================================
66

7-
---------
8-
Reference
9-
---------
10-
117
.. automodule:: pytensor.graph.basic
12-
:platform: Unix, Windows
13-
:synopsis: Interface for types of symbolic variables
148
:members:
15-
.. moduleauthor:: LISA

doc/library/graph/index.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11

22
.. _libdoc_graph:
33

4-
================================================
5-
:mod:`graph` -- Theano Internals [doc TODO]
6-
================================================
4+
========================================
5+
:mod:`graph` -- PyTensor Graph Internals
6+
========================================
77

88
.. module:: graph
9-
:platform: Unix, Windows
10-
:synopsis: Theano Internals
9+
1110
.. moduleauthor:: LISA
1211

1312
.. toctree::
1413
:maxdepth: 1
1514

1615
graph
1716
fgraph
17+
replace
1818
features
1919
op
2020
type

doc/library/graph/op.rst

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
21
.. _libdoc_graph_op:
32

4-
==============================================================
5-
:mod:`graph` -- Objects and functions for computational graphs
6-
==============================================================
3+
===========================================
4+
:mod:`op` -- Objects that define operations
5+
===========================================
76

87
.. automodule:: pytensor.graph.op
9-
:platform: Unix, Windows
10-
:synopsis: Interface for types of symbolic variables
118
:members:
12-
.. moduleauthor:: LISA

doc/library/graph/replace.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
.. _libdoc_graph_replace:
2+
3+
==================================================
4+
:mod:`replace` -- High level graph transformations
5+
==================================================
6+
7+
.. automodule:: pytensor.graph.replace
8+
:members:

0 commit comments

Comments
 (0)