Skip to content

Commit 086323f

Browse files
ArmavicaricardoV94
authored andcommitted
Add some types to printing.py
1 parent 1a92165 commit 086323f

File tree

1 file changed

+19
-17
lines changed

1 file changed

+19
-17
lines changed

pytensor/printing.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,18 +1200,18 @@ def __call__(self, *args):
12001200

12011201
def pydotprint(
12021202
fct,
1203-
outfile=None,
1204-
compact=True,
1205-
format="png",
1206-
with_ids=False,
1207-
high_contrast=True,
1203+
outfile: str | None = None,
1204+
compact: bool = True,
1205+
format: str = "png",
1206+
with_ids: bool = False,
1207+
high_contrast: bool = True,
12081208
cond_highlight=None,
1209-
colorCodes=None,
1210-
max_label_size=70,
1211-
scan_graphs=False,
1212-
var_with_name_simple=False,
1213-
print_output_file=True,
1214-
return_image=False,
1209+
colorCodes: dict | None = None,
1210+
max_label_size: int = 70,
1211+
scan_graphs: bool = False,
1212+
var_with_name_simple: bool = False,
1213+
print_output_file: bool = True,
1214+
return_image: bool = False,
12151215
):
12161216
"""Print to a file the graph of a compiled pytensor function's ops. Supports
12171217
all pydot output formats, including png and svg.
@@ -1676,7 +1676,9 @@ def get_tag(self):
16761676
return rval
16771677

16781678

1679-
def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None):
1679+
def min_informative_str(
1680+
obj, indent_level: int = 0, _prev_obs: dict | None = None, _tag_generator=None
1681+
) -> str:
16801682
"""
16811683
Returns a string specifying to the user what obj is
16821684
The string will print out as much of the graph as is needed
@@ -1776,7 +1778,7 @@ def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None
17761778
return rval
17771779

17781780

1779-
def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
1781+
def var_descriptor(obj, _prev_obs: dict | None = None, _tag_generator=None) -> str:
17801782
"""
17811783
Returns a string, with no endlines, fully specifying
17821784
how a variable is computed. Does not include any memory
@@ -1832,7 +1834,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
18321834
return rval
18331835

18341836

1835-
def position_independent_str(obj):
1837+
def position_independent_str(obj) -> str:
18361838
if isinstance(obj, Variable):
18371839
rval = "pytensor_var"
18381840
rval += "{type=" + str(obj.type) + "}"
@@ -1842,7 +1844,7 @@ def position_independent_str(obj):
18421844
return rval
18431845

18441846

1845-
def hex_digest(x):
1847+
def hex_digest(x: np.ndarray) -> str:
18461848
"""
18471849
Returns a short, mostly hexadecimal hash of a numpy ndarray
18481850
"""
@@ -1852,8 +1854,8 @@ def hex_digest(x):
18521854
# because the buffer interface only exposes the raw data, not
18531855
# any info about the semantics of how that data should be arranged
18541856
# into a tensor
1855-
rval = rval + "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
1856-
rval = rval + "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
1857+
rval += "|strides=[" + ",".join(str(stride) for stride in x.strides) + "]"
1858+
rval += "|shape=[" + ",".join(str(s) for s in x.shape) + "]"
18571859
return rval
18581860

18591861

0 commit comments

Comments
 (0)