from typing import cast

import numpy as np

-from scipy.optimize import minimize as scipy_minimize
-from scipy.optimize import minimize_scalar as scipy_minimize_scalar
-from scipy.optimize import root as scipy_root
-from scipy.optimize import root_scalar as scipy_root_scalar

import pytensor.scalar as ps
-from pytensor import Variable, function, graph_replace
+from pytensor.compile.function import function
from pytensor.gradient import grad, hessian, jacobian
from pytensor.graph import Apply, Constant, FunctionGraph
from pytensor.graph.basic import ancestors, truncated_graph_inputs
from pytensor.graph.op import ComputeMapType, HasInnerGraph, Op, StorageMapType
+from pytensor.graph.replace import graph_replace
from pytensor.tensor.basic import (
    atleast_2d,
    concatenate,
)
from pytensor.tensor.math import dot
from pytensor.tensor.slinalg import solve
-from pytensor.tensor.variable import TensorVariable
+from pytensor.tensor.variable import TensorVariable, Variable
+
+
+# scipy.optimize can be slow to import, and will not be used by most users.
+# We import scipy.optimize lazily inside the optimization perform methods to avoid this.
+optimize = None


_log = logging.getLogger(__name__)
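The `optimize = None` sentinel added above supports a standard lazy-import idiom: the first call to a `perform` method rebinds the module-level name to the real `scipy.optimize` module, and later calls reuse it. A minimal, self-contained sketch of the same idiom (the helper name `_ensure_scipy_optimize` is illustrative, not part of this diff):

    # Deferred import: scipy.optimize is only loaded on first use.
    optimize = None


    def _ensure_scipy_optimize():
        # With the `global` declaration, `import ... as optimize` rebinds the
        # module-level name rather than a local one.
        global optimize
        if optimize is None:
            import scipy.optimize as optimize
        return optimize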
@@ -352,8 +354,6 @@ def implict_optimization_grads(


class MinimizeScalarOp(ScipyScalarWrapperOp):
-    __props__ = ("method",)
-
    def __init__(
        self,
        x: Variable,
@@ -377,15 +377,22 @@ def __init__(
        self._fn = None
        self._fn_wrapped = None

+    def __str__(self):
+        return f"{self.__class__.__name__}(method={self.method})"
+
    def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
        f = self.fn_wrapped
        f.clear_cache()

        # minimize_scalar doesn't take x0 as an argument. The Op still needs this input (to symbolically determine
        # the args of the objective function), but it is not used in the optimization.
        x0, *args = inputs

-        res = scipy_minimize_scalar(
+        res = optimize.minimize_scalar(
            fun=f.value,
            args=tuple(args),
            method=self.method,
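As the comment in `perform` notes, `scipy.optimize.minimize_scalar` takes no `x0`; any extra data reaches the objective through `args`, which is how the Op forwards the remaining graph inputs. A plain-scipy illustration (the quadratic objective is just an example):

    from scipy.optimize import minimize_scalar

    # Minimize (x - mu)**2 for mu = 3.0; mu is threaded through `args`.
    res = minimize_scalar(lambda x, mu: (x - mu) ** 2, args=(3.0,), method="brent")
    print(res.x)  # -> approximately 3.0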
@@ -426,6 +433,27 @@ def minimize_scalar(
):
    """
    Minimize a scalar objective function using scipy.optimize.minimize_scalar.
+
+    Parameters
+    ----------
+    objective : TensorVariable
+        The objective function to minimize. This should be a PyTensor variable representing a scalar value.
+    x : TensorVariable
+        The variable with respect to which the objective function is minimized. It must be a scalar and an
+        input to the computational graph of `objective`.
+    method : str, optional
+        The optimization method to use. Default is "brent". See `scipy.optimize.minimize_scalar` for other options.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.minimize_scalar`.
+
+    Returns
+    -------
+    solution : TensorVariable
+        Value of `x` that minimizes `objective(x, *args)`. If the success flag is False, this will be the
+        final state returned by the minimization routine, not necessarily a minimum.
+    success : TensorVariable
+        Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
+        value, based on the requested convergence criteria.
    """

    args = _find_optimization_parameters(objective, x)
@@ -438,12 +466,14 @@ def minimize_scalar(
        optimizer_kwargs=optimizer_kwargs,
    )

-    return minimize_scalar_op(x, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], minimize_scalar_op(x, *args)
+    )

+    return solution, success
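A usage sketch matching the docstring above, assuming this module is importable as `pytensor.tensor.optimize` (the import path and the toy objective are assumptions, not part of this diff):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.optimize import minimize_scalar  # assumed import path

    x = pt.scalar("x")
    mu = pt.scalar("mu")
    objective = (x - mu) ** 2

    solution, success = minimize_scalar(objective, x, method="brent")
    fn = pytensor.function([x, mu], [solution, success])
    # x only anchors the graph; brent does not use it as an initial point.
    print(fn(0.0, 3.0))  # solution close to 3.0, success True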


-class MinimizeOp(ScipyWrapperOp):
-    __props__ = ("method", "jac", "hess", "hessp")

+class MinimizeOp(ScipyWrapperOp):
    def __init__(
        self,
        x: Variable,
@@ -487,11 +517,24 @@ def __init__(
        self._fn = None
        self._fn_wrapped = None

+    def __str__(self):
+        str_args = ", ".join(
+            [
+                f"{arg}={getattr(self, arg)}"
+                for arg in ["method", "jac", "hess", "hessp"]
+            ]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
    def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
        f = self.fn_wrapped
        x0, *args = inputs

-        res = scipy_minimize(
+        res = optimize.minimize(
            fun=f.value_and_grad if self.jac else f.value,
            jac=self.jac,
            x0=x0,
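The `fun=f.value_and_grad if self.jac else f.value` line relies on scipy's documented convention that when `jac=True`, `fun` returns the objective value and its gradient as a pair, saving a second evaluation. In plain scipy terms:

    import numpy as np
    from scipy.optimize import minimize

    def value_and_grad(x):
        # One callback returns (objective, gradient) together.
        return np.sum((x - 1.0) ** 2), 2.0 * (x - 1.0)

    res = minimize(value_and_grad, x0=np.zeros(3), jac=True, method="L-BFGS-B")
    print(res.x, res.success)  # -> approximately [1. 1. 1.], True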
@@ -538,7 +581,7 @@ def minimize(
    jac: bool = True,
    hess: bool = False,
    optimizer_kwargs: dict | None = None,
-):
+) -> tuple[TensorVariable, TensorVariable]:
    """
    Minimize a scalar objective function using scipy.optimize.minimize.

@@ -563,9 +606,13 @@ def minimize(

    Returns
    -------
-    TensorVariable
-        The optimized value of x that minimizes the objective function.
+    solution : TensorVariable
+        The optimized value of the vector of inputs `x` that minimizes `objective(x, *args)`. If the success flag
+        is False, this will be the final state of the minimization routine, but not necessarily a minimum.

+    success : TensorVariable
+        Symbolic boolean flag indicating whether the minimization routine reported convergence to a minimum
+        value, based on the requested convergence criteria.
    """
    args = _find_optimization_parameters(objective, x)
@@ -579,12 +626,14 @@ def minimize(
        optimizer_kwargs=optimizer_kwargs,
    )

-    return minimize_op(x, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], minimize_op(x, *args)
+    )
+
+    return solution, success
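A usage sketch for the vector case, again assuming the `pytensor.tensor.optimize` import path:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.optimize import minimize  # assumed import path

    x = pt.vector("x")
    a = pt.vector("a")
    objective = pt.sum((x - a) ** 2)

    solution, success = minimize(objective, x, method="L-BFGS-B", jac=True)
    fn = pytensor.function([x, a], [solution, success])
    print(fn([0.0, 0.0], [1.0, 2.0]))  # solution close to [1., 2.]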


class RootScalarOp(ScipyScalarWrapperOp):
-    __props__ = ("method", "jac", "hess")
-
    def __init__(
        self,
        variables,
@@ -633,14 +682,24 @@ def __init__(
        self._fn = None
        self._fn_wrapped = None

+    def __str__(self):
+        str_args = ", ".join(
+            [f"{arg}={getattr(self, arg)}" for arg in ["method", "jac", "hess"]]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
    def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
        f = self.fn_wrapped
        f.clear_cache()
        # f.copy_x = True

        variables, *args = inputs

-        res = scipy_root_scalar(
+        res = optimize.root_scalar(
            f=f.value,
            fprime=f.grad if self.jac else None,
            fprime2=f.hess if self.hess else None,
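Mapping `jac` and `hess` onto scipy's `fprime` and `fprime2` is what enables the derivative-based scalar root finders. For comparison, solving x**2 = 2 with plain scipy and Halley's method:

    from scipy.optimize import root_scalar

    res = root_scalar(
        f=lambda x: x**2 - 2.0,
        fprime=lambda x: 2.0 * x,  # first derivative (what jac=True supplies)
        fprime2=lambda x: 2.0,     # second derivative (what hess=True supplies)
        x0=1.0,
        method="halley",
    )
    print(res.root, res.converged)  # -> approximately 1.41421356, True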
@@ -676,19 +735,48 @@ def L_op(self, inputs, outputs, output_grads):


def root_scalar(
    equation: TensorVariable,
-    variables: TensorVariable,
+    variable: TensorVariable,
    method: str = "secant",
    jac: bool = False,
    hess: bool = False,
    optimizer_kwargs: dict | None = None,
-):
+) -> tuple[TensorVariable, TensorVariable]:
    """
    Find roots of a scalar equation using scipy.optimize.root_scalar.
+
+    Parameters
+    ----------
+    equation : TensorVariable
+        The equation for which to find roots. This should be a PyTensor variable representing a single equation
+        in one variable. The function will find `variable` such that `equation(variable, *args) = 0`.
+    variable : TensorVariable
+        The variable with respect to which the equation is solved. It must be a scalar and an input to the
+        computational graph of `equation`.
+    method : str, optional
+        The root-finding method to use. Default is "secant". See `scipy.optimize.root_scalar` for other options.
+    jac : bool, optional
+        Whether to compute and use the first derivative of the equation with respect to `variable`.
+        Default is False. Some methods require this.
+    hess : bool, optional
+        Whether to compute and use the second derivative of the equation with respect to `variable`.
+        Default is False. Some methods require this.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.root_scalar`.
+
+    Returns
+    -------
+    solution : TensorVariable
+        The final state of the root-finding routine. When `success` is True, this is the value of `variable` that
+        causes `equation` to evaluate to zero. Otherwise it is the final state returned by the root-finding
+        routine, but not necessarily a root.
+
+    success : TensorVariable
+        Symbolic boolean flag indicating whether the root-finding was successful. If True, the solution is a root
+        of the equation.
    """
-    args = _find_optimization_parameters(equation, variables)
+    args = _find_optimization_parameters(equation, variable)

    root_scalar_op = RootScalarOp(
-        variables,
+        variable,
        *args,
        equation=equation,
        method=method,
@@ -697,7 +785,11 @@ def root_scalar(
        optimizer_kwargs=optimizer_kwargs,
    )

-    return root_scalar_op(variables, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], root_scalar_op(variable, *args)
+    )
+
+    return solution, success
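A usage sketch matching the docstring above (import path assumed); Newton's method needs the first derivative, which `jac=True` derives symbolically:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.optimize import root_scalar  # assumed import path

    x = pt.scalar("x")
    c = pt.scalar("c")
    equation = x**2 - c

    solution, success = root_scalar(equation, x, method="newton", jac=True)
    fn = pytensor.function([x, c], [solution, success])
    print(fn(1.0, 2.0))  # solution close to sqrt(2)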


class RootOp(ScipyWrapperOp):
@@ -734,6 +826,12 @@ def __init__(
        self._fn = None
        self._fn_wrapped = None

+    def __str__(self):
+        str_args = ", ".join(
+            [f"{arg}={getattr(self, arg)}" for arg in ["method", "jac"]]
+        )
+        return f"{self.__class__.__name__}({str_args})"
+
    def build_fn(self):
        outputs = self.inner_outputs
        variables, *args = self.inner_inputs
@@ -761,13 +859,17 @@ def build_fn(self):
        self._fn_wrapped = LRUCache1(fn)

    def perform(self, node, inputs, outputs):
+        global optimize
+        if optimize is None:
+            import scipy.optimize as optimize
+
        f = self.fn_wrapped
        f.clear_cache()
        f.copy_x = True

        variables, *args = inputs

-        res = scipy_root(
+        res = optimize.root(
            fun=f,
            jac=self.jac,
            x0=variables,
@@ -815,8 +917,36 @@ def root(
    method: str = "hybr",
    jac: bool = True,
    optimizer_kwargs: dict | None = None,
-):
-    """Find roots of a system of equations using scipy.optimize.root."""
+) -> tuple[TensorVariable, TensorVariable]:
+    """
+    Find roots of a system of equations using scipy.optimize.root.
+
+    Parameters
+    ----------
+    equations : TensorVariable
+        The system of equations for which to find roots. This should be a PyTensor variable representing a
+        vector (or scalar) value. The function will find `variables` such that `equations(variables, *args) = 0`.
+    variables : TensorVariable
+        The variable(s) with respect to which the system of equations is solved. It must be an input to the
+        computational graph of `equations` and have the same number of dimensions as `equations`.
+    method : str, optional
+        The root-finding method to use. Default is "hybr". See `scipy.optimize.root` for other options.
+    jac : bool, optional
+        Whether to compute and use the Jacobian of the `equations` with respect to `variables`.
+        Default is True. Most methods require this.
+    optimizer_kwargs : dict, optional
+        Additional keyword arguments to pass to `scipy.optimize.root`.
+
+    Returns
+    -------
+    solution : TensorVariable
+        The final state of the root-finding routine. When `success` is True, this is the value of `variables` that
+        causes all `equations` to evaluate to zero. Otherwise it is the final state returned by the root-finding
+        routine, but not necessarily a root.
+
+    success : TensorVariable
+        Symbolic boolean flag indicating whether the root-finding was successful. If True, the solution is a root
+        of the equations.
+    """

    args = _find_optimization_parameters(equations, variables)
@@ -829,7 +959,11 @@ def root(
        optimizer_kwargs=optimizer_kwargs,
    )

-    return root_op(variables, *args)
+    solution, success = cast(
+        tuple[TensorVariable, TensorVariable], root_op(variables, *args)
+    )
+
+    return solution, success


__all__ = ["minimize_scalar", "minimize", "root_scalar", "root"]
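Finally, a usage sketch for the vector case (import path assumed; the elementwise system is a toy example):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.optimize import root  # assumed import path

    x = pt.vector("x")
    b = pt.vector("b")
    equations = x**2 - b  # elementwise system x_i**2 - b_i = 0

    solution, success = root(equations, x, method="hybr", jac=True)
    fn = pytensor.function([x, b], [solution, success])
    print(fn([1.0, 1.0], [4.0, 9.0]))  # solution close to [2., 3.]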