Skip to content

Commit feccc41

Browse files
authored
Allow string keys in eval utility (#242)
1 parent 74dca20 commit feccc41

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

pytensor/graph/basic.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,22 @@ def eval(self, inputs_to_values=None):
597597
if inputs_to_values is None:
598598
inputs_to_values = {}
599599

600+
def convert_string_keys_to_variables(input_to_values):
601+
new_input_to_values = {}
602+
for key, value in inputs_to_values.items():
603+
if isinstance(key, str):
604+
matching_vars = get_var_by_name([self], key)
605+
if not matching_vars:
606+
raise Exception(f"{key} not found in graph")
607+
elif len(matching_vars) > 1:
608+
raise Exception(f"Found multiple variables with name {key}")
609+
new_input_to_values[matching_vars[0]] = value
610+
else:
611+
new_input_to_values[key] = value
612+
return new_input_to_values
613+
614+
inputs_to_values = convert_string_keys_to_variables(inputs_to_values)
615+
600616
if not hasattr(self, "_fn_cache"):
601617
self._fn_cache = dict()
602618

tests/graph/test_basic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,24 @@ def test_eval(self):
302302
pickle.loads(pickle.dumps(self.w)), "_fn_cache"
303303
), "temporary functions must not be serialized"
304304

305+
def test_eval_with_strings(self):
306+
assert self.w.eval({"x": 1.0, self.y: 2.0}) == 6.0
307+
assert self.w.eval({self.z: 3}) == 6.0
308+
309+
def test_eval_with_strings_multiple_matches(self):
310+
e = scalars("e")
311+
t = e + 1
312+
t.name = "e"
313+
with pytest.raises(Exception, match="Found multiple variables with name e"):
314+
t.eval({"e": 1})
315+
316+
def test_eval_with_strings_no_match(self):
317+
e = scalars("e")
318+
t = e + 1
319+
t.name = "p"
320+
with pytest.raises(Exception, match="o not found in graph"):
321+
t.eval({"o": 1})
322+
305323

306324
class TestAutoName:
307325
def test_auto_name(self):

0 commit comments

Comments
 (0)