Commit 56358cc

Fix merge conflicts

2 parents 03902a7 + 981688c

29 files changed: +2081, -230 lines

.github/workflows/pypi.yml

Lines changed: 28 additions & 13 deletions
@@ -30,12 +30,13 @@ jobs:
       - name: Build SDist
         run: pipx run build --sdist

-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
+          name: sdist
           path: dist/*.tar.gz

   build_wheels:
-    name: Build ${{ matrix.python-version }} wheels on ${{ matrix.platform }}
+    name: Build wheels for ${{ matrix.platform }}
     runs-on: ${{ matrix.platform }}
     strategy:
       matrix:
@@ -51,19 +52,27 @@
       - name: Build wheels
         uses: pypa/cibuildwheel@v2.19.2

-      - uses: actions/upload-artifact@v3
+      - uses: actions/upload-artifact@v4
         with:
+          name: wheels-${{ matrix.platform }}
           path: ./wheelhouse/*.whl

   check_dist:
     name: Check dist
     needs: [make_sdist,build_wheels]
     runs-on: ubuntu-22.04
     steps:
-      - uses: actions/download-artifact@v3
+      - uses: actions/download-artifact@v4
         with:
-          name: artifact
+          name: sdist
           path: dist
+
+      - uses: actions/download-artifact@v4
+        with:
+          pattern: wheels-*
+          path: dist
+          merge-multiple: true

       - name: Check SDist
         run: |
           mkdir -p test-sdist
@@ -83,12 +92,18 @@
     runs-on: ubuntu-latest
     if: github.event_name == 'release' && github.event.action == 'published'
     steps:
-      - uses: actions/download-artifact@v3
-        with:
-          name: artifact
-          path: dist
+      - uses: actions/download-artifact@v4
+        with:
+          name: sdist
+          path: dist

-      - uses: pypa/gh-action-pypi-publish@v1.9.0
-        with:
-          user: __token__
-          password: ${{ secrets.pypi_password }}
+      - uses: actions/download-artifact@v4
+        with:
+          pattern: wheels-*
+          path: dist
+          merge-multiple: true
+
+      - uses: pypa/gh-action-pypi-publish@v1.9.0
+        with:
+          user: __token__
+          password: ${{ secrets.pypi_password }}

.github/workflows/test.yml

Lines changed: 5 additions & 4 deletions
@@ -187,9 +187,9 @@
           FLOAT32: ${{ matrix.float32 }}

       - name: Upload coverage file
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
-          name: coverage
+          name: coverage-${{ steps.matrix-id.outputs.id }}
           path: coverage/coverage-${{ steps.matrix-id.outputs.id }}.xml

   benchmarks:
@@ -273,10 +273,11 @@
           python -m pip install -U coverage>=5.1 coveralls

       - name: Download coverage file
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
-          name: coverage
+          pattern: coverage-*
           path: coverage
+          merge-multiple: true

       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v4

pytensor/compile/builders.py

Lines changed: 2 additions & 30 deletions
@@ -8,7 +8,6 @@

 from pytensor.compile.function import function
 from pytensor.compile.function.pfunc import rebuild_collect_shared
-from pytensor.compile.mode import optdb
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType, Rop, grad
@@ -24,7 +23,6 @@
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import HasInnerGraph, Op
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import in2out, node_rewriter
 from pytensor.graph.utils import MissingInputError


@@ -575,7 +573,7 @@ def lop_overrides(inps, grads):
             for inp_grad in input_grads
             if not isinstance(inp_grad.type, DisconnectedType | NullType)
         ]
-        lop_op = type(self)(
+        lop_op = OpFromGraph(
             inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
             outputs=connected_input_grads,
             inline=self.is_inline,
@@ -669,7 +667,7 @@ def _build_and_cache_rop_op(self):
             for out_grad in output_grads
             if not isinstance(out_grad.type, DisconnectedType | NullType)
         ]
-        rop_op = type(self)(
+        rop_op = OpFromGraph(
             inputs=inner_inputs + eval_points,
             outputs=filtered_output_grads,
             inline=self.is_inline,
@@ -852,29 +850,3 @@ def perform(self, node, inputs, outputs):
         assert len(variables) == len(outputs)
         for output, variable in zip(outputs, variables):
             output[0] = variable
-
-
-@node_rewriter([OpFromGraph])
-def inline_ofg_expansion(fgraph, node):
-    """
-    This optimization expands internal graph of OpFromGraph.
-    Only performed if node.op.is_inline == True
-    Doing so can improve optimization at the cost of compilation speed.
-    """
-    op = node.op
-    if not isinstance(op, OpFromGraph):
-        return False
-    if not op.is_inline:
-        return False
-    return clone_replace(op.inner_outputs, dict(zip(op.inner_inputs, node.inputs)))
-
-
-# We want to run this before the first merge optimizer
-# and before the first scan optimizer.
-optdb.register(
-    "inline_ofg_expansion",
-    in2out(inline_ofg_expansion),
-    "fast_compile",
-    "fast_run",
-    position=-0.01,
-)
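For orientation: OpFromGraph wraps an inner graph as a single reusable Op, and inline=True flags it for expansion into the outer graph during rewriting — the job of the inline_ofg_expansion rewrite removed here (presumably relocated elsewhere by this merge). A minimal usage sketch, not part of the diff:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

# Inner graph: two scalar inputs, one output
x, y = pt.scalars("x", "y")
# inline=True asks the rewriter to expand the inner graph in place
ofg = OpFromGraph([x, y], [x * y + x], inline=True)

a, b = pt.scalars("a", "b")
fn = pytensor.function([a, b], ofg(a, b))
print(fn(2.0, 3.0))  # 8.0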

pytensor/link/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from pytensor.link.pytorch.linker import PytorchLinker

pytensor/link/jax/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 import pytensor.link.jax.dispatch.blockwise
 import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.extra_ops
+import pytensor.link.jax.dispatch.pad
 import pytensor.link.jax.dispatch.math
 import pytensor.link.jax.dispatch.nlinalg
 import pytensor.link.jax.dispatch.random

pytensor/link/jax/dispatch/basic.py

Lines changed: 24 additions & 0 deletions
@@ -1,10 +1,13 @@
 import warnings
+from collections.abc import Callable
 from functools import singledispatch

 import jax
 import jax.numpy as jnp
 import numpy as np

+from pytensor.compile import JAX
+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp, ViewOp
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
@@ -114,3 +117,24 @@ def viewop(x):
         return x

     return viewop
+
+
+@jax_funcify.register(OpFromGraph)
+def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
+    _ = kwargs.pop("storage_map", None)
+
+    # Apply inner rewrites
+    JAX.optimizer(ofg.fgraph)
+    fgraph_fn = jax_funcify(ofg.fgraph, **kwargs)
+
+    if len(ofg.fgraph.outputs) == 1:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)[0]
+
+    else:
+
+        def opfromgraph(*inputs):
+            return fgraph_fn(*inputs)
+
+    return opfromgraph
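With this registration, a graph containing an OpFromGraph can be compiled end-to-end on the JAX backend: the inner FunctionGraph is rewritten with the JAX optimizer, then funcified recursively. A hedged sketch (assumes jax is installed; mode="JAX" is PyTensor's documented way to select the JAX linker):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
ofg = OpFromGraph([x], [pt.exp(x).sum()])  # single-output case above

inp = pt.vector("inp")
# mode="JAX" routes compilation through jax_funcify, including the
# new OpFromGraph registration
fn = pytensor.function([inp], ofg(inp), mode="JAX")
print(fn(np.zeros(3)))  # 3.0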

pytensor/link/jax/dispatch/pad.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import jax.numpy as jnp
+import numpy as np
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.pad import Pad
+
+
+@jax_funcify.register(Pad)
+def jax_funcify_pad(op, **kwargs):
+    pad_mode = op.pad_mode
+    reflect_type = op.reflect_type
+    has_stat_length = op.has_stat_length
+
+    if pad_mode == "constant":
+
+        def constant_pad(x, pad_width, constant_values):
+            return jnp.pad(x, pad_width, mode=pad_mode, constant_values=constant_values)
+
+        return constant_pad
+
+    elif pad_mode == "linear_ramp":
+
+        def lr_pad(x, pad_width, end_values):
+            # JAX does not allow a dynamic input if end_values is non-scalar
+            if not isinstance(end_values, int | float):
+                end_values = tuple(np.array(end_values))
+            return jnp.pad(x, pad_width, mode=pad_mode, end_values=end_values)
+
+        return lr_pad
+
+    elif pad_mode in ["maximum", "minimum", "mean"] and has_stat_length:
+
+        def stat_pad(x, pad_width, stat_length):
+            # JAX does not allow a dynamic input here, need to cast to tuple
+            return jnp.pad(
+                x, pad_width, mode=pad_mode, stat_length=tuple(np.array(stat_length))
+            )
+
+        return stat_pad
+
+    elif pad_mode in ["reflect", "symmetric"]:
+
+        def loop_pad(x, pad_width):
+            return jnp.pad(x, pad_width, mode=pad_mode, reflect_type=reflect_type)
+
+        return loop_pad
+
+    else:
+
+        def pad(x, pad_width):
+            return jnp.pad(x, pad_width, mode=pad_mode)
+
+        return pad
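A hedged sketch of how this dispatch is reached from user code, assuming pad mirrors np.pad's signature and is exported as pytensor.tensor.pad (consistent with the pytensor/tensor/__init__.py change below):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# mode="constant" lands in the constant_pad branch above
y = pt.pad(x, pad_width=1, mode="constant", constant_values=0)

fn = pytensor.function([x], y, mode="JAX")
print(fn(np.ones((2, 2))))  # 4x4 array: ones framed by a border of zeros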

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -2,8 +2,10 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify

 # # Load dispatch specializations
+import pytensor.link.pytorch.dispatch.blas
 import pytensor.link.pytorch.dispatch.scalar
 import pytensor.link.pytorch.dispatch.elemwise
+import pytensor.link.pytorch.dispatch.math
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.subtensor
 import pytensor.link.pytorch.dispatch.sort

pytensor/link/pytorch/dispatch/blas.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import torch
+
+from pytensor.link.pytorch.dispatch import pytorch_funcify
+from pytensor.tensor.blas import BatchedDot
+
+
+@pytorch_funcify.register(BatchedDot)
+def pytorch_funcify_BatchedDot(op, **kwargs):
+    def batched_dot(a, b):
+        if a.shape[0] != b.shape[0]:
+            raise TypeError("Shapes must match in the 0-th dimension")
+        return torch.bmm(a, b)
+
+    return batched_dot
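A hedged usage sketch: batched_dot on rank-3 inputs builds a BatchedDot node, which the registration above lowers to torch.bmm. Both pt.batched_dot (carried over from Theano's API) and mode="PYTORCH" selecting the new PytorchLinker are assumptions here:

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.tensor3("a")
b = pt.tensor3("b")
out = pt.batched_dot(a, b)  # lowered to torch.bmm under the hood

fn = pytensor.function([a, b], out, mode="PYTORCH")
print(fn(np.ones((4, 2, 3)), np.ones((4, 3, 5))).shape)  # (4, 2, 5)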

pytensor/link/pytorch/dispatch/math.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+import torch
+
+from pytensor.link.pytorch.dispatch import pytorch_funcify
+from pytensor.tensor.math import Dot
+
+
+@pytorch_funcify.register(Dot)
+def pytorch_funcify_Dot(op, **kwargs):
+    def dot(x, y):
+        return torch.matmul(x, y)
+
+    return dot
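Similarly for plain Dot nodes, under the same mode="PYTORCH" assumption:

import numpy as np
import pytensor
import pytensor.tensor as pt

m = pt.matrix("m")
v = pt.vector("v")
# pt.dot builds a Dot node, translated to torch.matmul by the
# registration above
fn = pytensor.function([m, v], pt.dot(m, v), mode="PYTORCH")
print(fn(np.eye(3), np.arange(3.0)))  # [0. 1. 2.]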

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 from pytensor.tensor.extra_ops import *
 from pytensor.tensor.io import *
 from pytensor.tensor.math import *
+from pytensor.tensor.pad import pad
 from pytensor.tensor.shape import (
     reshape,
     shape,
