4
4
import warnings
5
5
from collections .abc import Callable , Mapping , MutableSequence , Sequence
6
6
from functools import partial , reduce
7
- from typing import TYPE_CHECKING , Literal , TypeVar , Union
7
+ from typing import TYPE_CHECKING , Literal , TypeVar , Union , overload
8
8
9
9
import numpy as np
10
10
@@ -414,6 +414,32 @@ def Lop(
414
414
return as_list_or_tuple (using_list , using_tuple , ret )
415
415
416
416
417
+ @overload
418
+ def grad (
419
+ cost : Variable | None ,
420
+ wrt : Variable | Sequence [Variable ],
421
+ consider_constant : Sequence [Variable ] | None = ...,
422
+ disconnected_inputs : Literal ["ignore" , "warn" , "raise" ] = ...,
423
+ add_names : bool = ...,
424
+ known_grads : Mapping [Variable , Variable ] | None = ...,
425
+ return_disconnected : Literal ["zero" , "disconnected" ] = ...,
426
+ null_gradients : Literal ["raise" , "return" ] = ...,
427
+ ) -> Variable | None | Sequence [Variable ]: ...
428
+
429
+
430
+ @overload
431
+ def grad (
432
+ cost : Variable | None ,
433
+ wrt : Variable | Sequence [Variable ],
434
+ consider_constant : Sequence [Variable ] | None = ...,
435
+ disconnected_inputs : Literal ["ignore" , "warn" , "raise" ] = ...,
436
+ add_names : bool = ...,
437
+ known_grads : Mapping [Variable , Variable ] | None = ...,
438
+ return_disconnected : Literal ["none" ] = ...,
439
+ null_gradients : Literal ["raise" , "return" ] = ...,
440
+ ) -> Variable | None | Sequence [Variable | None ]: ...
441
+
442
+
417
443
def grad (
418
444
cost : Variable | None ,
419
445
wrt : Variable | Sequence [Variable ],
@@ -423,7 +449,7 @@ def grad(
423
449
known_grads : Mapping [Variable , Variable ] | None = None ,
424
450
return_disconnected : Literal ["none" , "zero" , "disconnected" ] = "zero" ,
425
451
null_gradients : Literal ["raise" , "return" ] = "raise" ,
426
- ) -> Variable | None | Sequence [Variable | None ]:
452
+ ) -> Variable | None | Sequence [Variable | None ] | Sequence [ Variable ] :
427
453
"""
428
454
Return symbolic gradients of one cost with respect to one or more variables.
429
455
0 commit comments