
Commit a018283

Merge branch 'main' of github.com:pymc-devs/pytensor into elemwise_torch_improvement
2 parents: f277af7 + e73258b

30 files changed: +1030 -649 lines

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ exclude: |
   )$
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v5.0.0
     hooks:
       - id: debug-statements
         exclude: |
@@ -27,7 +27,7 @@ repos:
       - id: sphinx-lint
         args: ["."]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.6.5
+    rev: v0.7.1
     hooks:
       - id: ruff
         args: ["--fix", "--output-format=full"]

pytensor/compile/function/types.py

Lines changed: 122 additions & 127 deletions
Large diffs are not rendered by default.

pytensor/gradient.py

Lines changed: 0 additions & 3 deletions
@@ -128,9 +128,6 @@ def fiter_variable(self, other):
             " a symbolic placeholder."
         )
 
-    def may_share_memory(a, b):
-        return False
-
     def value_eq(a, b, force_same_dtype=True):
         raise AssertionError(
             "If you're assigning to a DisconnectedType you're"

pytensor/graph/null_type.py

Lines changed: 0 additions & 3 deletions
@@ -26,9 +26,6 @@ def filter(self, data, strict=False, allow_downcast=None):
     def filter_variable(self, other, allow_convert=True):
         raise ValueError("No values may be assigned to a NullType")
 
-    def may_share_memory(a, b):
-        return False
-
     def values_eq(self, a, b, force_same_dtype=True):
         raise ValueError("NullType has no values to compare")
 

pytensor/graph/op.py

Lines changed: 10 additions & 3 deletions
@@ -513,17 +513,24 @@ def make_py_thunk(
         """
         node_input_storage = [storage_map[r] for r in node.inputs]
         node_output_storage = [storage_map[r] for r in node.outputs]
+        node_compute_map = [compute_map[r] for r in node.outputs]
 
         if debug and hasattr(self, "debug_perform"):
             p = node.op.debug_perform
         else:
             p = node.op.perform
 
         @is_thunk_type
-        def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
+        def rval(
+            p=p,
+            i=node_input_storage,
+            o=node_output_storage,
+            n=node,
+            cm=node_compute_map,
+        ):
             r = p(n, [x[0] for x in i], o)
-            for o in node.outputs:
-                compute_map[o][0] = True
+            for entry in cm:
+                entry[0] = True
             return r
 
         rval.inputs = node_input_storage
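The point of this hunk is that the compute-map cells for the node's outputs are now resolved once, at thunk-creation time, and bound as the default argument cm, so each call only flips flags instead of indexing compute_map per output (and the loop no longer shadows the output-storage argument o). A minimal, standalone sketch of the pattern, with hypothetical names rather than PyTensor's API:

# Toy illustration of pre-binding mutable cells as a default argument.
compute_map = {"out0": [False], "out1": [False]}

def make_thunk(output_names, compute_map):
    cells = [compute_map[name] for name in output_names]  # resolved once, here

    def thunk(cm=cells):
        # ... the op's perform() would run here ...
        for entry in cm:
            entry[0] = True

    return thunk

thunk = make_thunk(["out0", "out1"], compute_map)
thunk()
assert all(cell[0] for cell in compute_map.values())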

pytensor/graph/type.py

Lines changed: 1 addition & 4 deletions
@@ -48,10 +48,7 @@ def in_same_class(self, otype: "Type") -> bool | None:
         unique element (i.e. it uses `self.__eq__`).
 
         """
-        if self == otype:
-            return True
-
-        return False
+        return self == otype
 
     def is_super(self, otype: "Type") -> bool | None:
         """Determine if `self` is a supertype of `otype`.

pytensor/link/c/cmodule.py

Lines changed: 6 additions & 2 deletions
@@ -2007,13 +2007,18 @@ def try_blas_flag(flags):
     cflags.extend(f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs())
 
     res = GCC_compiler.try_compile_tmp(
-        test_code, tmp_prefix="try_blas_", flags=cflags, try_run=True
+        test_code, tmp_prefix="try_blas_", flags=cflags, try_run=True, output=True
     )
     # res[0]: shows successful compilation
     # res[1]: shows successful execution
+    # res[2]: shows execution results
+    # res[3]: shows execution or compilation error message
     if res and res[0] and res[1]:
         return " ".join(flags)
     else:
+        _logger.debug(
+            "try_blas_flags of flags: %r\nfailed with error message %s", flags, res[3]
+        )
         return ""
 
@@ -2801,7 +2806,6 @@ def check_libs(
             _logger.debug("The following blas flags will be used: '%s'", res)
             return res
         else:
-            _logger.debug(f"Supplied flags {res} failed to compile")
             _logger.debug("Supplied flags '%s' failed to compile", res)
             raise RuntimeError(f"Supplied flags {flags} failed to compile")
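The second hunk drops a duplicated debug call that used an f-string and keeps the one using logging's lazy %-formatting, which defers interpolation until the DEBUG level is actually enabled. A small standalone illustration of the difference (not PyTensor code):

import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger("blas")

flags = ["-lopenblas"]

# Formatted eagerly, even though DEBUG records are discarded here:
_logger.debug(f"Supplied flags {flags} failed to compile")

# Formatted lazily, only if a handler actually emits the record:
_logger.debug("Supplied flags '%s' failed to compile", flags)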

pytensor/scalar/basic.py

Lines changed: 7 additions & 10 deletions
@@ -303,13 +303,6 @@ def clone(self, dtype=None, **kwargs):
             dtype = self.dtype
         return type(self)(dtype)
 
-    @staticmethod
-    def may_share_memory(a, b):
-        # This class represent basic c type, represented in python
-        # with numpy.scalar. They are read only. So from python, they
-        # can never share memory.
-        return False
-
     def filter(self, data, strict=False, allow_downcast=None):
         py_type = self.dtype_specs()[0]
         if strict and not isinstance(data, py_type):
@@ -4253,7 +4246,11 @@ def __str__(self):
             r.name = f"o{int(i)}"
         io = set(self.fgraph.inputs + self.fgraph.outputs)
         for i, r in enumerate(self.fgraph.variables):
-            if r not in io and len(self.fgraph.clients[r]) > 1:
+            if (
+                not isinstance(r, Constant)
+                and r not in io
+                and len(self.fgraph.clients[r]) > 1
+            ):
                 r.name = f"t{int(i)}"
 
         if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
@@ -4352,7 +4349,7 @@ def c_code_template(self):
             if var not in self.fgraph.inputs:
                 # This is an orphan
                 if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
-                    subd[var] = var.type.c_literal(var.data)
+                    subd[var] = f"({var.type.c_literal(var.data)})"
                 else:
                     raise ValueError(
                         "All orphans in the fgraph to Composite must"
@@ -4411,7 +4408,7 @@ def c_code(self, node, nodename, inames, onames, sub):
         return self.c_code_template % d
 
     def c_code_cache_version_outer(self) -> tuple[int, ...]:
-        return (4,)
+        return (5,)
 
 
 class Compositef32:
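Both Composite (here) and ScalarLoop (below) now wrap orphan constants in parentheses before splicing them into the generated C, and bump their C cache versions accordingly. A toy Python sketch, with a made-up template rather than Composite's real one, of the kind of invalid C an unwrapped negative literal can produce:

# Hypothetical template, not the actual generated code.
template = "%(out)s = -%(a)s * %(a)s;"

unwrapped = template % {"out": "y", "a": "-2"}
wrapped = template % {"out": "y", "a": "(-2)"}

print(unwrapped)  # y = --2 * -2;      "--" parses as a decrement: invalid C
print(wrapped)    # y = -(-2) * (-2);  well-formed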

pytensor/scalar/loop.py

Lines changed: 2 additions & 2 deletions
@@ -239,7 +239,7 @@ def c_code_template(self):
             if var not in self.fgraph.inputs:
                 # This is an orphan
                 if isinstance(var, Constant) and isinstance(var.type, CLinkerType):
-                    subd[var] = var.type.c_literal(var.data)
+                    subd[var] = f"({var.type.c_literal(var.data)})"
                 else:
                     raise ValueError(
                         "All orphans in the fgraph to ScalarLoop must"
@@ -342,4 +342,4 @@ def c_code(self, node, nodename, inames, onames, sub):
         return res
 
     def c_code_cache_version_outer(self):
-        return (2,)
+        return (3,)

pytensor/tensor/blockwise.py

Lines changed: 36 additions & 27 deletions
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from copy import copy
 from typing import Any, cast
 
 import numpy as np
@@ -79,7 +78,6 @@ def __init__(
         self.name = name
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
-        self._gufunc = None
         if destroy_map is not None:
             self.destroy_map = destroy_map
             if self.destroy_map != core_op.destroy_map:
@@ -91,11 +89,6 @@
 
         super().__init__(**kwargs)
 
-    def __getstate__(self):
-        d = copy(self.__dict__)
-        d["_gufunc"] = None
-        return d
-
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -296,32 +289,46 @@ def L_op(self, inputs, outs, ograds):
 
         return rval
 
-    def _create_gufunc(self, node):
+    def _create_node_gufunc(self, node) -> None:
+        """Define (or retrieve) the node gufunc used in `perform`.
+
+        If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
+        Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
+
+        The gufunc is stored in the tag of the node.
+        """
         gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
 
         if gufunc_spec is not None:
-            self._gufunc = import_func_from_string(gufunc_spec[0])
-            if self._gufunc:
-                return self._gufunc
-            else:
+            gufunc = import_func_from_string(gufunc_spec[0])
+            if gufunc is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
 
-        n_outs = len(self.outputs_sig)
-        core_node = self._create_dummy_core_node(node.inputs)
-
-        def core_func(*inner_inputs):
-            inner_outputs = [[None] for _ in range(n_outs)]
+        else:
+            # Wrap core_op perform method in numpy vectorize
+            n_outs = len(self.outputs_sig)
+            core_node = self._create_dummy_core_node(node.inputs)
+            inner_outputs_storage = [[None] for _ in range(n_outs)]
+
+            def core_func(
+                *inner_inputs,
+                core_node=core_node,
+                inner_outputs_storage=inner_outputs_storage,
+            ):
+                self.core_op.perform(
+                    core_node,
+                    [np.asarray(inp) for inp in inner_inputs],
+                    inner_outputs_storage,
+                )
 
-            inner_inputs = [np.asarray(inp) for inp in inner_inputs]
-            self.core_op.perform(core_node, inner_inputs, inner_outputs)
+                if n_outs == 1:
+                    return inner_outputs_storage[0][0]
+                else:
+                    return tuple(r[0] for r in inner_outputs_storage)
 
-            if len(inner_outputs) == 1:
-                return inner_outputs[0][0]
-            else:
-                return tuple(r[0] for r in inner_outputs)
+            gufunc = np.vectorize(core_func, signature=self.signature)
 
-        self._gufunc = np.vectorize(core_func, signature=self.signature)
-        return self._gufunc
+        node.tag.gufunc = gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
@@ -340,10 +347,12 @@ def _check_runtime_broadcast(self, node, inputs):
             )
 
     def perform(self, node, inputs, output_storage):
-        gufunc = self._gufunc
+        gufunc = getattr(node.tag, "gufunc", None)
 
         if gufunc is None:
-            gufunc = self._create_gufunc(node)
+            # Cache it once per node
+            self._create_node_gufunc(node)
+            gufunc = node.tag.gufunc
 
         self._check_runtime_broadcast(node, inputs)
 
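The cached callable now lives on the node's tag instead of on the Op, which removes the need for the __getstate__ hack. When no gufunc_spec is available, the fallback path wraps the core Op's perform in np.vectorize with the Blockwise signature. A self-contained illustration (hypothetical core function, not PyTensor code) of how signature-aware vectorization broadcasts over the leading batch dimensions:

import numpy as np

def core_matvec(m, v):
    # Core operation defined only for a single (m, n) matrix and (n,) vector.
    return m @ v

# `signature` promotes it to a gufunc-like callable that loops over batch dims.
batched_matvec = np.vectorize(core_matvec, signature="(m,n),(n)->(m)")

rng = np.random.default_rng(0)
M = rng.normal(size=(4, 3, 2))  # batch of 4 matrices, each (3, 2)
V = rng.normal(size=(4, 2))     # batch of 4 vectors, each (2,)

print(batched_matvec(M, V).shape)  # (4, 3)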

pytensor/tensor/random/op.py

Lines changed: 5 additions & 12 deletions
@@ -387,24 +387,17 @@ def dist_params(self, node) -> Sequence[Variable]:
         return node.inputs[2:]
 
     def perform(self, node, inputs, outputs):
-        rng_var_out, smpl_out = outputs
-
         rng, size, *args = inputs
 
         # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
         if not self.inplace:
             rng = copy(rng)
 
-        rng_var_out[0] = rng
-
-        if size is not None:
-            size = tuple(size)
-        smpl_val = self.rng_fn(rng, *([*args, size]))
-
-        if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
-            smpl_val = np.asarray(smpl_val, dtype=self.dtype)
-
-        smpl_out[0] = smpl_val
+        outputs[0][0] = rng
+        outputs[1][0] = np.asarray(
+            self.rng_fn(rng, *args, None if size is None else tuple(size)),
+            dtype=self.dtype,
+        )
 
     def grad(self, inputs, outputs):
         return [
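The rewritten perform writes straight into the pre-allocated output storage: cell 0 receives the (possibly copied) rng and cell 1 the draws, cast to the Op's dtype. A toy mimic of that storage convention (not the RandomVariable API itself):

import numpy as np

rng = np.random.default_rng(42)
size = (3,)
dtype = "float64"

# `outputs` is a list of single-element lists, one cell per output.
outputs = [[None], [None]]

outputs[0][0] = rng
outputs[1][0] = np.asarray(rng.normal(size=size), dtype=dtype)

print(type(outputs[0][0]).__name__, outputs[1][0].shape)  # Generator (3,)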

pytensor/tensor/rewriting/blockwise.py

Lines changed: 2 additions & 2 deletions
@@ -127,8 +127,8 @@ def local_blockwise_alloc(fgraph, node):
         value, *shape = inp.owner.inputs
 
         # Check what to do with the value of the Alloc
-        squeezed_value = _squeeze_left(value, batch_ndim)
-        missing_ndim = len(shape) - value.type.ndim
+        missing_ndim = inp.type.ndim - value.type.ndim
+        squeezed_value = _squeeze_left(value, (batch_ndim - missing_ndim))
         if (
             (((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:])
             != inp.type.broadcastable[batch_ndim:]

pytensor/tensor/rewriting/linalg.py

Lines changed: 23 additions & 0 deletions
@@ -4,9 +4,11 @@
 
 from pytensor import Variable
 from pytensor import tensor as pt
+from pytensor.compile import optdb
 from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
+    in2out,
     node_rewriter,
 )
 from pytensor.scalar.basic import Mul
@@ -45,9 +47,11 @@
     Cholesky,
     Solve,
     SolveBase,
+    _bilinear_solve_discrete_lyapunov,
     block_diag,
     cholesky,
     solve,
+    solve_discrete_lyapunov,
     solve_triangular,
 )
 
@@ -966,3 +970,22 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
         non_eye_input = pt.shape_padaxis(non_eye_input, -2)
 
     return [eye_input * (non_eye_input**0.5)]
+
+
+@node_rewriter([_bilinear_solve_discrete_lyapunov])
+def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
+    """
+    Replace BilinearSolveDiscreteLyapunov with a direct computation that is supported by JAX
+    """
+    A, B = (cast(TensorVariable, x) for x in node.inputs)
+    result = solve_discrete_lyapunov(A, B, method="direct")
+
+    return [result]
+
+
+optdb.register(
+    "jax_bilinaer_lyapunov_to_direct",
+    in2out(jax_bilinaer_lyapunov_to_direct),
+    "jax",
+    position=0.9,  # Run before canonicalization
+)
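The new rewrite is registered in optdb under the "jax" tag, so it only fires when compiling to the JAX backend, where the bilinear Lyapunov Op has no dispatch. A hedged usage sketch (assumes jax is installed and that "bilinear" is an accepted method of solve_discrete_lyapunov; neither is shown in this diff):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.slinalg import solve_discrete_lyapunov

A = pt.matrix("A")
Q = pt.matrix("Q")
# Request the bilinear formulation explicitly; under mode="JAX" the rewrite
# above should swap in the "direct" computation before linking.
X = solve_discrete_lyapunov(A, Q, method="bilinear")

fn = pytensor.function([A, Q], X, mode="JAX")

A_val = np.array([[0.5, 0.1], [0.0, 0.3]])
Q_val = np.eye(2)
print(fn(A_val, Q_val))  # X satisfying A X A^T - X + Q = 0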
