Skip to content

Commit b7f745e

Browse files
committed
Making convert_string_keys_to_variables internal to eval
1 parent 8976c94 commit b7f745e

File tree

2 files changed

+30
-48
lines changed

2 files changed

+30
-48
lines changed

pytensor/graph/basic.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -558,46 +558,6 @@ def get_parents(self):
558558
return [self.owner]
559559
return []
560560

561-
def convert_string_keys_to_pytensor_variables(self, inputs_to_values):
562-
r"""Convert the string keys to corresponding `Variable` with nearest name.
563-
564-
Parameters
565-
----------
566-
inputs_to_values :
567-
A dictionary mapping PyTensor `Variable`\s to values.
568-
569-
Examples
570-
--------
571-
572-
>>> import numpy as np
573-
>>> import pytensor.tensor as at
574-
>>> x = at.dscalar('x')
575-
>>> y = at.dscalar('y')
576-
>>> z = x + y
577-
>>> np.allclose(z.eval({'x' : 3, 'y' : 1}), 4)
578-
True
579-
"""
580-
process_input_to_values = {}
581-
for i in inputs_to_values:
582-
if isinstance(i, str):
583-
nodes_with_matching_names = get_var_by_name([self], i)
584-
length_of_nodes_with_matching_names = len(nodes_with_matching_names)
585-
if length_of_nodes_with_matching_names == 0:
586-
raise Exception(f"{i} not found in graph")
587-
else:
588-
if length_of_nodes_with_matching_names > 1:
589-
warnings.warn(
590-
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i} taking the first declared named variable for computation"
591-
)
592-
process_input_to_values[
593-
nodes_with_matching_names[
594-
length_of_nodes_with_matching_names - 1
595-
]
596-
] = inputs_to_values[i]
597-
else:
598-
process_input_to_values[i] = inputs_to_values[i]
599-
return process_input_to_values
600-
601561
def eval(self, inputs_to_values=None):
602562
r"""Evaluate the `Variable`.
603563
@@ -637,9 +597,29 @@ def eval(self, inputs_to_values=None):
637597
if inputs_to_values is None:
638598
inputs_to_values = {}
639599

640-
inputs_to_values = self.convert_string_keys_to_pytensor_variables(
641-
inputs_to_values
642-
)
600+
def convert_string_keys_to_variables():
601+
process_input_to_values = {}
602+
for i in inputs_to_values:
603+
if isinstance(i, str):
604+
nodes_with_matching_names = get_var_by_name([self], i)
605+
length_of_nodes_with_matching_names = len(nodes_with_matching_names)
606+
if length_of_nodes_with_matching_names == 0:
607+
raise Exception(f"{i} not found in graph")
608+
else:
609+
if length_of_nodes_with_matching_names > 1:
610+
raise Exception(
611+
f"Found {length_of_nodes_with_matching_names} pytensor variables with name {i}"
612+
)
613+
process_input_to_values[
614+
nodes_with_matching_names[
615+
length_of_nodes_with_matching_names - 1
616+
]
617+
] = inputs_to_values[i]
618+
else:
619+
process_input_to_values[i] = inputs_to_values[i]
620+
return process_input_to_values
621+
622+
inputs_to_values = convert_string_keys_to_variables()
643623

644624
if not hasattr(self, "_fn_cache"):
645625
self._fn_cache = dict()

tests/graph/test_basic.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,9 @@ def test_outputs_clients(self):
290290

291291
class TestEval:
292292
def setup_method(self):
293-
self.x, self.y, self.e = scalars("x", "y", "e")
293+
self.x, self.y = scalars("x", "y")
294294
self.z = self.x + self.y
295295
self.w = 2 * self.z
296-
self.t = self.e + 1
297-
self.t.name = "e"
298296

299297
def test_eval(self):
300298
assert self.w.eval({self.x: 1.0, self.y: 2.0}) == 6.0
@@ -308,8 +306,12 @@ def test_eval_with_strings(self):
308306
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0
309307
assert self.w.eval({self.z: 3}) == 6.0
310308

311-
def test_eval_with_strings_with_mulitple_same_name(self):
312-
assert self.t.eval({"e": 1.0}) == 2.0
309+
def test_eval_errors_having_mulitple_variables_same_name(self):
310+
e = scalars("e")
311+
t = e + 1
312+
t.name = "e"
313+
with pytest.raises(Exception, match="Found 2 pytensor variables with name e"):
314+
t.eval({"e": 1})
313315

314316

315317
class TestAutoName:

0 commit comments

Comments
 (0)