Skip to content

Commit 962772f

Browse files
committed
mypy: fix graph.py
1 parent f8aa9bd commit 962772f

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

pytensor/gradient.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import warnings
55
from collections.abc import Callable, Mapping, MutableSequence, Sequence
66
from functools import partial, reduce
7-
from typing import TYPE_CHECKING, Literal, TypeVar, Union
7+
from typing import TYPE_CHECKING, Literal, TypeVar, Union, overload
88

99
import numpy as np
1010

@@ -414,6 +414,32 @@ def Lop(
414414
return as_list_or_tuple(using_list, using_tuple, ret)
415415

416416

417+
@overload
418+
def grad(
419+
cost: Variable | None,
420+
wrt: Variable | Sequence[Variable],
421+
consider_constant: Sequence[Variable] | None = ...,
422+
disconnected_inputs: Literal["ignore", "warn", "raise"] = ...,
423+
add_names: bool = ...,
424+
known_grads: Mapping[Variable, Variable] | None = ...,
425+
return_disconnected: Literal["zero", "disconnected"] = ...,
426+
null_gradients: Literal["raise", "return"] = ...,
427+
) -> Variable | None | Sequence[Variable]: ...
428+
429+
430+
@overload
431+
def grad(
432+
cost: Variable | None,
433+
wrt: Variable | Sequence[Variable],
434+
consider_constant: Sequence[Variable] | None = ...,
435+
disconnected_inputs: Literal["ignore", "warn", "raise"] = ...,
436+
add_names: bool = ...,
437+
known_grads: Mapping[Variable, Variable] | None = ...,
438+
return_disconnected: Literal["none"] = ...,
439+
null_gradients: Literal["raise", "return"] = ...,
440+
) -> Variable | None | Sequence[Variable | None]: ...
441+
442+
417443
def grad(
418444
cost: Variable | None,
419445
wrt: Variable | Sequence[Variable],
@@ -423,7 +449,7 @@ def grad(
423449
known_grads: Mapping[Variable, Variable] | None = None,
424450
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
425451
null_gradients: Literal["raise", "return"] = "raise",
426-
) -> Variable | None | Sequence[Variable | None]:
452+
) -> Variable | None | Sequence[Variable | None] | Sequence[Variable]:
427453
"""
428454
Return symbolic gradients of one cost with respect to one or more variables.
429455

pytensor/graph/basic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,8 +1313,9 @@ def clone_get_equiv(
13131313
outputs: Reversible[Variable],
13141314
copy_inputs: bool = True,
13151315
copy_orphans: bool = True,
1316-
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
1317-
| None = None,
1316+
memo: (
1317+
dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] | None
1318+
) = None,
13181319
clone_inner_graphs: bool = False,
13191320
**kwargs,
13201321
) -> dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]:

0 commit comments

Comments
 (0)