Commit 3876e73

Correct bad imports in optimize.py and expose it via pytensor.tensor.__init__ (#1464)
* Expose `pt.optimize`
* Clean up imports in optimize.py
* Add example notebook for `optimize.root`
* Small updates
* Add doc file
* Link to optimize docs in tensor index
* Add docs to user-facing functions
* Move `scipy.optimize` imports into `perform` methods
* Use global import strategy
* Remove props and overload `__str__`
* Rerun example notebook
1 parent b56bff5 commit 3876e73
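
The headline change is that the optimize module is now reachable as `pt.optimize`. The following minimal sketch is not part of the commit; the variable names and values are illustrative, and the call pattern follows the `minimize` docstring added in this diff, which returns a symbolic `(solution, success)` pair:

    import pytensor
    import pytensor.tensor as pt

    x = pt.scalar("x")
    a = pt.scalar("a")
    objective = (x - a) ** 2  # minimized at x = a

    # minimize returns a symbolic (solution, success) pair
    solution, success = pt.optimize.minimize(objective, x)

    # When the graph is compiled, the value passed for `x` serves as the
    # initial point of the optimization (see the x0 handling in perform below)
    fn = pytensor.function([x, a], [solution, success])
    x_star, converged = fn(0.0, 3.0)  # x_star should be approximately 3.0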

File tree: 6 files changed (+2258 lines, -29 lines)

doc/gallery/optimization/root.ipynb (2081 additions, 0 deletions)

Large diffs are not rendered by default.

doc/library/tensor/index.rst (1 addition, 0 deletions)

@@ -31,3 +31,4 @@ symbolic expressions using calls that look just like numpy calls, such as
     math_opt
     basic_opt
     functional
+    optimize

doc/library/tensor/optimize.rst (11 additions, 0 deletions)

@@ -0,0 +1,11 @@
+========================================================
+:mod:`tensor.optimize` -- Symbolic Optimization Routines
+========================================================
+
+.. module:: tensor.conv
+   :platform: Unix, Windows
+   :synopsis: Symbolic Optimization Routines
+.. moduleauthor:: LISA, PyMC Developers, PyTensor Developers
+
+.. automodule:: pytensor.tensor.optimize
+   :members:

pytensor/tensor/__init__.py (1 addition, 0 deletions)

@@ -118,6 +118,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 from pytensor.tensor import linalg
 from pytensor.tensor import special
 from pytensor.tensor import signal
+from pytensor.tensor import optimize
 
 # For backward compatibility
 from pytensor.tensor import nlinalg

pytensor/tensor/optimize.py (163 additions, 29 deletions)

@@ -4,17 +4,14 @@
 from typing import cast
 
 import numpy as np
-from scipy.optimize import minimize as scipy_minimize
-from scipy.optimize import minimize_scalar as scipy_minimize_scalar
-from scipy.optimize import root as scipy_root
-from scipy.optimize import root_scalar as scipy_root_scalar
 
 import pytensor.scalar as ps
-from pytensor import Variable, function, graph_replace
+from pytensor.compile.function import function
 from pytensor.gradient import grad, hessian, jacobian
 from pytensor.graph import Apply, Constant, FunctionGraph
 from pytensor.graph.basic import ancestors, truncated_graph_inputs
 from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
+from pytensor.graph.replace import graph_replace
 from pytensor.tensor.basic import (
     atleast_2d,
     concatenate,
@@ -24,7 +21,12 @@
 )
 from pytensor.tensor.math import dot
 from pytensor.tensor.slinalg import solve
-from pytensor.tensor.variable import TensorVariable
+from pytensor.tensor.variable import TensorVariable, Variable
+
+
+# scipy.optimize can be slow to import, and will not be used by most users
+# We import scipy.optimize lazily inside optimization perform methods to avoid this.
+optimize = None
 
 
 _log = logging.getLogger(__name__)
@@ -352,8 +354,6 @@ def implict_optimization_grads(
 
 
 class MinimizeScalarOp(ScipyScalarWrapperOp):
-    __props__ = ("method",)
-
     def __init__(
         self,
         x: Variable,
@@ -377,15 +377,22 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        return f"{self.__class__.__name__}(method={self.method})"
+
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         f.clear_cache()
 
         # minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
         # the args of the objective function), but it is not used in the optimization.
         x0, *args = inputs
 
-        res = scipy_minimize_scalar(
+        res = optimize.minimize_scalar(
             fun=f.value,
             args=tuple(args),
             method=self.method,
@@ -426,6 +433,27 @@ def minimize_scalar(
 ):
     """
     Minimize a scalar objective function using scipy.optimize.minimize_scalar.
+
+    Parameters
+    ----------
+    objective : TensorVariable
+        The objective function to minimize. This should be a PyTensor variable representing a scalar value.
+    x : TensorVariable
+        The variable with respect to which the objective function is minimized. It must be a scalar and an
+        input to the computational graph of `objective`.
+    method : str, optional
+        The optimization method to use. Default is "brent". See `scipy.optimize.minimize_scalar` for other options.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.minimize_scalar`.
+
+    Returns
+    -------
+    solution: TensorVariable
+        Value of `x` that minimizes `objective(x, *args)`. If the success flag is False, this will be the
+        final state returned by the minimization routine, not necessarily a minimum.
+    success : TensorVariable
+        Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
+        value, based on the requested convergence criteria.
     """
 
     args = _find_optimization_parameters(objective, x)
@@ -438,12 +466,14 @@
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return minimize_scalar_op(x, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], minimize_scalar_op(x, *args)
+    )
 
+    return solution, success
 
-class MinimizeOp(ScipyWrapperOp):
-    __props__ = ("method", "jac", "hess", "hessp")
 
+class MinimizeOp(ScipyWrapperOp):
     def __init__(
         self,
         x: Variable,
@@ -487,11 +517,24 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        str_args = ", ".join(
+            [
+                f"{arg}={getattr(self, arg)}"
+                for arg in ["method", "jac", "hess", "hessp"]
+            ]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         x0, *args = inputs
 
-        res = scipy_minimize(
+        res = optimize.minimize(
             fun=f.value_and_grad if self.jac else f.value,
             jac=self.jac,
             x0=x0,
@@ -538,7 +581,7 @@ def minimize(
     jac: bool = True,
     hess: bool = False,
     optimizer_kwargs: dict | None = None,
-):
+) -> tuple[TensorVariable, TensorVariable]:
     """
     Minimize a scalar objective function using scipy.optimize.minimize.
 
@@ -563,9 +606,13 @@ def minimize(
 
     Returns
     -------
-    TensorVariable
-        The optimized value of x that minimizes the objective function.
+    solution: TensorVariable
+        The optimized value of the vector of inputs `x` that minimizes `objective(x, *args)`. If the success flag
+        is False, this will be the final state of the minimization routine, but not necessarily a minimum.
 
+    success: TensorVariable
+        Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
+        value, based on the requested convergence criteria.
     """
     args = _find_optimization_parameters(objective, x)
 
@@ -579,12 +626,14 @@
         optimizer_kwargs=optimizer_kwargs,
    )
 
-    return minimize_op(x, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], minimize_op(x, *args)
+    )
+
+    return solution, success
 
 
 class RootScalarOp(ScipyScalarWrapperOp):
-    __props__ = ("method", "jac", "hess")
-
     def __init__(
         self,
         variables,
@@ -633,14 +682,24 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        str_args = ", ".join(
+            [f"{arg}={getattr(self, arg)}" for arg in ["method", "jac", "hess"]]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         f.clear_cache()
         # f.copy_x = True
 
         variables, *args = inputs
 
-        res = scipy_root_scalar(
+        res = optimize.root_scalar(
             f=f.value,
             fprime=f.grad if self.jac else None,
             fprime2=f.hess if self.hess else None,
@@ -676,19 +735,48 @@ def L_op(self, inputs, outputs, output_grads):
 
 def root_scalar(
     equation: TensorVariable,
-    variables: TensorVariable,
+    variable: TensorVariable,
     method: str = "secant",
     jac: bool = False,
     hess: bool = False,
     optimizer_kwargs: dict | None = None,
-):
+) -> tuple[TensorVariable, TensorVariable]:
     """
     Find roots of a scalar equation using scipy.optimize.root_scalar.
+
+    Parameters
+    ----------
+    equation : TensorVariable
+        The equation for which to find roots. This should be a PyTensor variable representing a single equation in one
+        variable. The function will find `variables` such that `equation(variables, *args) = 0`.
+    variable : TensorVariable
+        The variable with respect to which the equation is solved. It must be a scalar and an input to the
+        computational graph of `equation`.
+    method : str, optional
+        The root-finding method to use. Default is "secant". See `scipy.optimize.root_scalar` for other options.
+    jac : bool, optional
+        Whether to compute and use the first derivative of the equation with respect to `variables`.
+        Default is False. Some methods require this.
+    hess : bool, optional
+        Whether to compute and use the second derivative of the equation with respect to `variables`.
+        Default is False. Some methods require this.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.root_scalar`.
+
+    Returns
+    -------
+    solution: TensorVariable
+        The final state of the root-finding routine. When `success` is True, this is the value of `variables` that
+        causes `equation` to evaluate to zero. Otherwise it is the final state returned by the root-finding
+        routine, but not necessarily a root.
+
+    success: TensorVariable
+        Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation
     """
-    args = _find_optimization_parameters(equation, variables)
+    args = _find_optimization_parameters(equation, variable)
 
     root_scalar_op = RootScalarOp(
-        variables,
+        variable,
         *args,
         equation=equation,
         method=method,
@@ -697,7 +785,11 @@ def root_scalar(
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return root_scalar_op(variables, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], root_scalar_op(variable, *args)
+    )
+
+    return solution, success
 
 
 class RootOp(ScipyWrapperOp):
@@ -734,6 +826,12 @@ def __init__(
         self._fn = None
         self._fn_wrapped = None
 
+    def __str__(self):
+        str_args = ", ".join(
+            [f"{arg}={getattr(self, arg)}" for arg in ["method", "jac"]]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
     def build_fn(self):
         outputs = self.inner_outputs
         variables, *args = self.inner_inputs
@@ -761,13 +859,17 @@ def build_fn(self):
         self._fn_wrapped = LRUCache1(fn)
 
     def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
         f = self.fn_wrapped
         f.clear_cache()
         f.copy_x = True
 
         variables, *args = inputs
 
-        res = scipy_root(
+        res = optimize.root(
             fun=f,
             jac=self.jac,
             x0=variables,
@@ -815,8 +917,36 @@ def root(
     method: str = "hybr",
     jac: bool = True,
     optimizer_kwargs: dict | None = None,
-):
-    """Find roots of a system of equations using scipy.optimize.root."""
+) -> tuple[TensorVariable, TensorVariable]:
+    """
+    Find roots of a system of equations using scipy.optimize.root.
+
+    Parameters
+    ----------
+    equations : TensorVariable
+        The system of equations for which to find roots. This should be a PyTensor variable representing a
+        vector (or scalar) value. The function will find `variables` such that `equations(variables, *args) = 0`.
+    variables : TensorVariable
+        The variable(s) with respect to which the system of equations is solved. It must be an input to the
+        computational graph of `equations` and have the same number of dimensions as `equations`.
+    method : str, optional
+        The root-finding method to use. Default is "hybr". See `scipy.optimize.root` for other options.
+    jac : bool, optional
+        Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
+        Default is True. Most methods require this.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.root`.
+
+    Returns
+    -------
+    solution: TensorVariable
+        The final state of the root-finding routine. When `success` is True, this is the value of `variables` that
+        causes all `equations` to evaluate to zero. Otherwise it is the final state returned by the root-finding
+        routine, but not necessarily a root.
+
+    success: TensorVariable
+        Boolean indicating whether the root-finding was successful. If True, the solution is a root of the equation
+    """
 
     args = _find_optimization_parameters(equations, variables)
 
@@ -829,7 +959,11 @@ def root(
         optimizer_kwargs=optimizer_kwargs,
     )
 
-    return root_op(variables, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], root_op(variables, *args)
+    )
+
+    return solution, success
 
 
 __all__ = ["minimize_scalar", "minimize", "root_scalar", "root"]
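
The commit message also mentions a new example notebook for `optimize.root`. As a minimal sketch of a vector-valued system, assuming the `root` signature documented above (the system and the initial guess are illustrative, not taken from the notebook):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    # Find x such that equations(x) = 0, i.e. solve x**2 = [1, 4]
    equations = x**2 - pt.as_tensor([1.0, 4.0])

    solution, success = pt.optimize.root(equations, x)

    # As with minimize, the value supplied for `x` is the initial guess
    fn = pytensor.function([x], [solution, success])
    root_value, converged = fn([0.5, 3.0])  # expect approximately [1.0, 2.0]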

scripts/generate_gallery.py (1 addition, 0 deletions)

@@ -58,6 +58,7 @@
     "introduction": "Introduction",
     "rewrites": "Graph Rewriting",
     "scan": "Looping in Pytensor",
+    "optimize": "Optimization in Pytensor",
 }