diff --git a/advanced_source/cpp_extension.rst b/advanced_source/cpp_extension.rst index 735ad6d6263..52f541a31bb 100644 --- a/advanced_source/cpp_extension.rst +++ b/advanced_source/cpp_extension.rst @@ -428,7 +428,7 @@ class citizens of PyTorch:: def forward(ctx, input, weights, bias, old_h, old_cell): outputs = lltm.forward(input, weights, bias, old_h, old_cell) new_h, new_cell = outputs[:2] - variables = outputs[1:] + [weights, old_cell] + variables = outputs[1:] + [weights] ctx.save_for_backward(*variables) return new_h, new_cell @@ -437,7 +437,7 @@ class citizens of PyTorch:: def backward(ctx, grad_h, grad_cell): outputs = lltm.backward( grad_h.contiguous(), grad_cell.contiguous(), *ctx.saved_variables) - d_old_h, d_input, d_weights, d_bias, d_old_cell, d_gates = outputs + d_old_h, d_input, d_weights, d_bias, d_old_cell = outputs return d_input, d_weights, d_bias, d_old_h, d_old_cell