import pytensor
from pytensor.compile.ops import ViewOp
from pytensor.configdefaults import config
-from pytensor.graph import utils
+from pytensor.graph import utils, vectorize_graph
from pytensor.graph.basic import Apply, NominalVariable, Variable
from pytensor.graph.null_type import NullType, null_type
from pytensor.graph.op import get_test_values
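
The import hunk pulls in `vectorize_graph`, the helper that later replaces the scan loop in `jacobian`. As a minimal sketch of what it does (variable names here are illustrative, not part of this patch): substituting a batched input for an existing one rebuilds every downstream node to act over the extra leading dimension.

    import pytensor.tensor as pt
    from pytensor.graph import vectorize_graph

    x = pt.vector("x")
    y = x**2 + 1  # an elementwise graph built on a vector

    # Swapping a matrix in for the vector input rebuilds each node so the
    # whole graph operates over the extra leading (batch) dimension.
    x_batch = pt.matrix("x_batch")
    y_batch = vectorize_graph(y, replace={x: x_batch})  # matrix-valued result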
@@ -703,15 +703,15 @@ def grad(
        grad_dict[var] = g_var

    def handle_disconnected(var):
-        message = (
-            "grad method was asked to compute the gradient "
-            "with respect to a variable that is not part of "
-            "the computational graph of the cost, or is used "
-            f"only by a non-differentiable operator: {var}"
-        )
        if disconnected_inputs == "ignore":
-            pass
+            return
        elif disconnected_inputs == "warn":
+            message = (
+                "grad method was asked to compute the gradient "
+                "with respect to a variable that is not part of "
+                "the computational graph of the cost, or is used "
+                f"only by a non-differentiable operator: {var}"
+            )
            warnings.warn(message, stacklevel=2)
        elif disconnected_inputs == "raise":
            message = utils.get_variable_trace_string(var)
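
The behavioral change in this hunk: the "ignore" branch now returns early, so the warning message is only built when it is actually needed. A small sketch of the three modes (variable names illustrative):

    import pytensor
    import pytensor.tensor as pt

    x = pt.scalar("x")
    z = pt.scalar("z")  # never used in the cost, so it is disconnected
    cost = x**2

    # "ignore" silently yields a zero gradient for z; "warn" emits the
    # message that this patch moves into the warn branch; "raise" errors.
    gx, gz = pytensor.grad(cost, [x, z], disconnected_inputs="ignore")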
@@ -2021,13 +2021,19 @@ def __str__(self):
        Exception args: {args_msg}"""


-def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise"):
+def jacobian(
+    expression,
+    wrt,
+    consider_constant=None,
+    disconnected_inputs="raise",
+    vectorize: bool = False,
+):
    """
    Compute the full Jacobian, row by row.

    Parameters
    ----------
-    expression : Vector (1-dimensional) :class:`~pytensor.graph.basic.Variable`
+    expression : :class:`~pytensor.graph.basic.Variable`
        Values that we are differentiating (that we want the Jacobian of)
    wrt : :class:`~pytensor.graph.basic.Variable` or list of Variables
        Term[s] with respect to which we compute the Jacobian
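
Because the body hunk below drops the old 1-D restriction, the docstring no longer limits `expression` to vectors. A hedged sketch of a call that this patch newly allows (per the reshape logic at the end of the diff, the result has shape (*expression.shape, *wrt.shape)):

    import pytensor.tensor as pt
    from pytensor.gradient import jacobian

    x = pt.vector("x")
    y = pt.outer(x, x)  # 2-D expression, previously rejected with a ValueError

    # No manual flatten needed; the Jacobian is reshaped back to
    # (*y.shape, *x.shape) after the internal ravel.
    J = jacobian(y, x)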
@@ -2051,62 +2057,73 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
        output, then a zero variable is returned. The return value is
        of same type as `wrt`: a list/tuple or TensorVariable in all cases.
    """
+    from pytensor.tensor import broadcast_to, eye

    if not isinstance(expression, Variable):
        raise TypeError("jacobian expects a Variable as `expression`")

-    if expression.ndim > 1:
-        raise ValueError(
-            "jacobian expects a 1 dimensional variable as `expression`."
-            " If not use flatten to make it a vector"
-        )
-
    using_list = isinstance(wrt, list)
    using_tuple = isinstance(wrt, tuple)
+    grad_kwargs = {
+        "consider_constant": consider_constant,
+        "disconnected_inputs": disconnected_inputs,
+    }

    if isinstance(wrt, list | tuple):
        wrt = list(wrt)
    else:
        wrt = [wrt]

    if all(expression.type.broadcastable):
-        # expression is just a scalar, use grad
-        return as_list_or_tuple(
-            using_list,
-            using_tuple,
-            grad(
-                expression.squeeze(),
-                wrt,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
-            ),
+        jacobian_matrices = grad(expression.squeeze(), wrt, **grad_kwargs)
+
+    elif vectorize:
+        expression_flat = expression.ravel()
+        row_tangent = _float_ones_like(expression_flat).type("row_tangent")
+        jacobian_single_rows = Lop(expression.ravel(), wrt, row_tangent, **grad_kwargs)
+
+        n_rows = expression_flat.size
+        jacobian_matrices = vectorize_graph(
+            jacobian_single_rows,
+            replace={row_tangent: eye(n_rows, dtype=row_tangent.dtype)},
        )
+        if disconnected_inputs != "raise":
+            # If the input is disconnected from the cost, `vectorize_graph` has no effect on the respective jacobian
+            # We have to broadcast the zeros explicitly here
+            for i, (jacobian_single_row, jacobian_matrix) in enumerate(
+                zip(jacobian_single_rows, jacobian_matrices, strict=True)
+            ):
+                if jacobian_single_row.ndim == jacobian_matrix.ndim:
+                    jacobian_matrices[i] = broadcast_to(
+                        jacobian_matrix, shape=(n_rows, *jacobian_matrix.shape)
+                    )

-    def inner_function(*args):
-        idx = args[0]
-        expr = args[1]
-        rvals = []
-        for inp in args[2:]:
-            rval = grad(
-                expr[idx],
-                inp,
-                consider_constant=consider_constant,
-                disconnected_inputs=disconnected_inputs,
+    else:
+
+        def inner_function(*args):
+            idx, expr, *wrt = args
+            return grad(expr[idx], wrt, **grad_kwargs)
+
+        jacobian_matrices, updates = pytensor.scan(
+            inner_function,
+            sequences=pytensor.tensor.arange(expression.size),
+            non_sequences=[expression.ravel(), *wrt],
+            return_list=True,
+        )
+        if updates:
+            raise ValueError(
+                "The scan used to build the jacobian matrices returned a list of updates"
            )
-            rvals.append(rval)
-        return rvals
-
-    # Computing the gradients does not affect the random seeds on any random
-    # generator used n expression (because during computing gradients we are
-    # just backtracking over old values. (rp Jan 2012 - if anyone has a
-    # counter example please show me)
-    jacobs, updates = pytensor.scan(
-        inner_function,
-        sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression, *wrt],
-    )
-    assert not updates, "Scan has returned a list of updates; this should not happen."
-    return as_list_or_tuple(using_list, using_tuple, jacobs)
+
+    if jacobian_matrices[0].ndim < (expression.ndim + wrt[0].ndim):
+        # There was some raveling or squeezing done prior to getting the jacobians
+        # Reshape into original shapes
+        jacobian_matrices = [
+            jac_matrix.reshape((*expression.shape, *w.shape))
+            for jac_matrix, w in zip(jacobian_matrices, wrt, strict=True)
+        ]
+
+    return as_list_or_tuple(using_list, using_tuple, jacobian_matrices)


def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
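
Taken together, the new flag trades the symbolic scan loop for a single L-operator row that is then vectorized over an identity matrix of seed vectors. A hedged usage sketch of the feature this diff adds:

    import pytensor
    import pytensor.tensor as pt
    from pytensor.gradient import jacobian

    x = pt.vector("x")
    y = pt.stack([x.sum() ** 2, pt.exp(x).prod()])

    # vectorize=True builds one Jacobian row with Lop and expands it over
    # eye(n_rows) via vectorize_graph, avoiding a scan at runtime.
    J = jacobian(y, x, vectorize=True)
    f = pytensor.function([x], J)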