Skip to content

Commit 8976c94

Browse files
committed
Updated function with warning when multiple variables with same name are defined
1 parent b088b01 commit 8976c94

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

pytensor/graph/basic.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,11 +581,18 @@ def convert_string_keys_to_pytensor_variables(self, inputs_to_values):
581581
for i in inputs_to_values:
582582
if isinstance(i, str):
583583
nodes_with_matching_names = get_var_by_name([self], i)
584-
if len(nodes_with_matching_names) == 0:
584+
length_of_nodes_with_matching_names = len(nodes_with_matching_names)
585+
if length_of_nodes_with_matching_names == 0:
585586
raise Exception(f"{i} not found in graph")
586587
else:
588+
if length_of_nodes_with_matching_names > 1:
589+
warnings.warn(
590+
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i} taking the first declared named variable for computation"
591+
)
587592
process_input_to_values[
588-
nodes_with_matching_names[0]
593+
nodes_with_matching_names[
594+
length_of_nodes_with_matching_names - 1
595+
]
589596
] = inputs_to_values[i]
590597
else:
591598
process_input_to_values[i] = inputs_to_values[i]

tests/graph/test_basic.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,11 @@ def test_outputs_clients(self):
290290

291291
class TestEval:
292292
def setup_method(self):
293-
self.x, self.y = scalars("x", "y")
293+
self.x, self.y, self.e = scalars("x", "y", "e")
294294
self.z = self.x + self.y
295295
self.w = 2 * self.z
296+
self.t = self.e + 1
297+
self.t.name = "e"
296298

297299
def test_eval(self):
298300
assert self.w.eval({self.x: 1.0, self.y: 2.0}) == 6.0
@@ -303,12 +305,11 @@ def test_eval(self):
303305
), "temporary functions must not be serialized"
304306

305307
def test_eval_with_strings(self):
306-
assert self.w.eval({"x": 1.0, "y": 2.0}) == 6.0
308+
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0
307309
assert self.w.eval({self.z: 3}) == 6.0
308-
assert hasattr(self.w, "_fn_cache"), "variable must have cache after eval"
309-
assert not hasattr(
310-
pickle.loads(pickle.dumps(self.w)), "_fn_cache"
311-
), "temporary functions must not be serialized"
310+
311+
def test_eval_with_strings_with_mulitple_same_name(self):
312+
assert self.t.eval({"e": 1.0}) == 2.0
312313

313314

314315
class TestAutoName:

0 commit comments

Comments
 (0)