Skip to content

Commit b7eb139

Browse files
committed
Cleanup Function.__call__
1 parent 5c75d1a commit b7eb139

File tree

1 file changed

+57
-64
lines changed

1 file changed

+57
-64
lines changed

pytensor/compile/function/types.py

Lines changed: 57 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -326,8 +326,8 @@ class Function:
326326
def __init__(
327327
self,
328328
vm: "VM",
329-
input_storage,
330-
output_storage,
329+
input_storage: list[Container],
330+
output_storage: list[Container],
331331
indices,
332332
outputs,
333333
defaults,
@@ -388,6 +388,8 @@ def __init__(
388388
self.nodes_with_inner_function = []
389389
self.output_keys = output_keys
390390

391+
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
392+
391393
# See if we have any mutable / borrow inputs
392394
# TODO: this only need to be set if there is more than one input
393395
self._check_for_aliased_inputs = False
@@ -408,11 +410,6 @@ def __init__(
408410
finder = {}
409411
inv_finder = {}
410412

411-
def distribute(indices, cs, value):
412-
input.distribute(value, indices, cs)
413-
for c in cs:
414-
c.provided += 1
415-
416413
# Store the list of names of named inputs.
417414
named_inputs = []
418415
# Count the number of un-named inputs.
@@ -777,6 +774,13 @@ def checkSV(sv_ori, sv_rpl):
777774
f_cpy.maker.fgraph.name = name
778775
return f_cpy
779776

777+
def _restore_defaults(self):
778+
for i, (required, refeed, value) in enumerate(self.defaults):
779+
if refeed:
780+
if isinstance(value, Container):
781+
value = value.storage[0]
782+
self[i] = value
783+
780784
def __call__(self, *args, **kwargs):
781785
"""
782786
Evaluates value of a function on given arguments.
@@ -814,43 +818,42 @@ def restore_defaults():
814818
self[i] = value
815819

816820
profile = self.profile
817-
t0 = time.perf_counter()
821+
if profile:
822+
t0 = time.perf_counter()
818823

819824
output_subset = kwargs.pop("output_subset", None)
820825
if output_subset is not None and self.output_keys is not None:
821826
output_subset = [self.output_keys.index(key) for key in output_subset]
822827

823828
# Reinitialize each container's 'provided' counter
824829
if self.trust_input:
825-
i = 0
826-
for arg in args:
827-
s = self.input_storage[i]
828-
s.storage[0] = arg
829-
i += 1
830+
for arg_container, arg in zip(self.input_storage, args, strict=False):
831+
arg_container.storage[0] = arg
830832
else:
831-
for c in self.input_storage:
832-
c.provided = 0
833+
for arg_container in self.input_storage:
834+
arg_container.provided = 0
833835

834836
if len(args) + len(kwargs) > len(self.input_storage):
835837
raise TypeError("Too many parameter passed to pytensor function")
836838

837839
# Set positional arguments
838-
i = 0
839-
for arg in args:
840+
for arg_container, arg in zip(self.input_storage, args, strict=False):
840841
# TODO: provide a option for skipping the filter if we really
841842
# want speed.
842-
s = self.input_storage[i]
843843
# see this emails for a discuation about None as input
844844
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
845845
if arg is None:
846-
s.storage[0] = arg
846+
arg_container.storage[0] = arg
847847
else:
848848
try:
849-
s.storage[0] = s.type.filter(
850-
arg, strict=s.strict, allow_downcast=s.allow_downcast
849+
arg_container.storage[0] = arg_container.type.filter(
850+
arg,
851+
strict=arg_container.strict,
852+
allow_downcast=arg_container.allow_downcast,
851853
)
852854

853855
except Exception as e:
856+
i = self.input_storage.index(arg_container)
854857
function_name = "pytensor function"
855858
argument_name = "argument"
856859
if self.name:
@@ -875,27 +878,23 @@ def restore_defaults():
875878
+ function_name
876879
+ f" at index {int(i)} (0-based). {where}"
877880
) + e.args
878-
restore_defaults()
881+
self._restore_defaults()
879882
raise
880-
s.provided += 1
881-
i += 1
883+
arg_container.provided += 1
882884

883885
# Set keyword arguments
884886
if kwargs: # for speed, skip the items for empty kwargs
885887
for k, arg in kwargs.items():
886888
self[k] = arg
887889

888-
if (
889-
not self.trust_input
890-
and
891-
# The getattr is only needed for old pickle
892-
getattr(self, "_check_for_aliased_inputs", True)
893-
):
890+
if not self.trust_input and self._check_for_aliased_inputs:
894891
# Collect aliased inputs among the storage space
895892
args_share_memory = []
896-
for i in range(len(self.input_storage)):
897-
i_var = self.maker.inputs[i].variable
898-
i_val = self.input_storage[i].storage[0]
893+
for i, (inp, arg_storage) in enumerate(
894+
zip(self.maker.inputs, self.input_storage)
895+
):
896+
i_var = inp.variable
897+
i_val = arg_storage.storage[0]
899898
if hasattr(i_var.type, "may_share_memory"):
900899
is_aliased = False
901900
for j in range(len(args_share_memory)):
@@ -932,36 +931,36 @@ def restore_defaults():
932931
self.input_storage[j].storage[0]
933932
)
934933

935-
# Check if inputs are missing, or if inputs were set more than once, or
936-
# if we tried to provide inputs that are supposed to be implicit.
937-
if not self.trust_input:
938-
for c in self.input_storage:
939-
if c.required and not c.provided:
940-
restore_defaults()
934+
# Check if inputs are missing, or if inputs were set more than once, or
935+
# if we tried to provide inputs that are supposed to be implicit.
936+
for arg_container in self.input_storage:
937+
if arg_container.required and not arg_container.provided:
938+
self._restore_defaults()
941939
raise TypeError(
942-
f"Missing required input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
940+
f"Missing required input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
943941
)
944-
if c.provided > 1:
945-
restore_defaults()
942+
if arg_container.provided > 1:
943+
self._restore_defaults()
946944
raise TypeError(
947-
f"Multiple values for input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
945+
f"Multiple values for input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
948946
)
949-
if c.implicit and c.provided > 0:
950-
restore_defaults()
947+
if arg_container.implicit and arg_container.provided > 0:
948+
self._restore_defaults()
951949
raise TypeError(
952-
f"Tried to provide value for implicit input: {getattr(self.inv_finder[c], 'variable', self.inv_finder[c])}"
950+
f"Tried to provide value for implicit input: {getattr(self.inv_finder[arg_container], 'variable', self.inv_finder[arg_container])}"
953951
)
954952

955953
# Do the actual work
956-
t0_fn = time.perf_counter()
954+
if profile:
955+
t0_fn = time.perf_counter()
957956
try:
958957
outputs = (
959958
self.vm()
960959
if output_subset is None
961960
else self.vm(output_subset=output_subset)
962961
)
963962
except Exception:
964-
restore_defaults()
963+
self._restore_defaults()
965964
if hasattr(self.vm, "position_of_error"):
966965
# this is a new vm-provided function or c linker
967966
# they need this because the exception manipulation
@@ -979,9 +978,9 @@ def restore_defaults():
979978
# old-style linkers raise their own exceptions
980979
raise
981980

982-
dt_fn = time.perf_counter() - t0_fn
983-
self.maker.mode.fn_time += dt_fn
984981
if profile:
982+
dt_fn = time.perf_counter() - t0_fn
983+
self.maker.mode.fn_time += dt_fn
985984
profile.vm_call_time += dt_fn
986985

987986
# Retrieve the values that were computed
@@ -991,14 +990,13 @@ def restore_defaults():
991990

992991
# Remove internal references to required inputs.
993992
# These cannot be re-used anyway.
994-
for c in self.input_storage:
995-
if c.required:
996-
c.storage[0] = None
993+
for arg_container in self.input_storage:
994+
if arg_container.required:
995+
arg_container.storage[0] = None
997996

998997
# if we are allowing garbage collection, remove the
999998
# output reference from the internal storage cells
1000999
if getattr(self.vm, "allow_gc", False):
1001-
assert len(self.output_storage) == len(self.maker.fgraph.outputs)
10021000
for o_container, o_variable in zip(
10031001
self.output_storage, self.maker.fgraph.outputs
10041002
):
@@ -1020,17 +1018,12 @@ def restore_defaults():
10201018
outputs = outputs[: self.n_returned_outputs]
10211019

10221020
# Put default values back in the storage
1023-
restore_defaults()
1024-
#
1025-
# NOTE: This logic needs to be replicated in
1026-
# scan.
1027-
# grep for 'PROFILE_CODE'
1028-
#
1029-
1030-
dt_call = time.perf_counter() - t0
1031-
pytensor.compile.profiling.total_fct_exec_time += dt_call
1032-
self.maker.mode.call_time += dt_call
1021+
self._restore_defaults()
1022+
10331023
if profile:
1024+
dt_call = time.perf_counter() - t0
1025+
pytensor.compile.profiling.total_fct_exec_time += dt_call
1026+
self.maker.mode.call_time += dt_call
10341027
profile.fct_callcount += 1
10351028
profile.fct_call_time += dt_call
10361029
if hasattr(self.vm, "update_profile"):

0 commit comments

Comments
 (0)