Skip to content

Commit b088b01

Browse files
committed
Added functionality that can use name of variable in eval
1 parent 277559b commit b088b01

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

pytensor/graph/basic.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,39 @@ def get_parents(self):
558558
return [self.owner]
559559
return []
560560

561+
def convert_string_keys_to_pytensor_variables(self, inputs_to_values):
562+
r"""Convert the string keys to corresponding `Variable` with nearest name.
563+
564+
Parameters
565+
----------
566+
inputs_to_values :
567+
A dictionary mapping PyTensor `Variable`\s to values.
568+
569+
Examples
570+
--------
571+
572+
>>> import numpy as np
573+
>>> import pytensor.tensor as at
574+
>>> x = at.dscalar('x')
575+
>>> y = at.dscalar('y')
576+
>>> z = x + y
577+
>>> np.allclose(z.eval({'x' : 3, 'y' : 1}), 4)
578+
True
579+
"""
580+
process_input_to_values = {}
581+
for i in inputs_to_values:
582+
if isinstance(i, str):
583+
nodes_with_matching_names = get_var_by_name([self], i)
584+
if len(nodes_with_matching_names) == 0:
585+
raise Exception(f"{i} not found in graph")
586+
else:
587+
process_input_to_values[
588+
nodes_with_matching_names[0]
589+
] = inputs_to_values[i]
590+
else:
591+
process_input_to_values[i] = inputs_to_values[i]
592+
return process_input_to_values
593+
561594
def eval(self, inputs_to_values=None):
562595
r"""Evaluate the `Variable`.
563596
@@ -597,6 +630,10 @@ def eval(self, inputs_to_values=None):
597630
if inputs_to_values is None:
598631
inputs_to_values = {}
599632

633+
inputs_to_values = self.convert_string_keys_to_pytensor_variables(
634+
inputs_to_values
635+
)
636+
600637
if not hasattr(self, "_fn_cache"):
601638
self._fn_cache = dict()
602639

tests/graph/test_basic.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,14 @@ def test_eval(self):
302302
pickle.loads(pickle.dumps(self.w)), "_fn_cache"
303303
), "temporary functions must not be serialized"
304304

305+
def test_eval_with_strings(self):
306+
assert self.w.eval({"x": 1.0, "y": 2.0}) == 6.0
307+
assert self.w.eval({self.z: 3}) == 6.0
308+
assert hasattr(self.w, "_fn_cache"), "variable must have cache after eval"
309+
assert not hasattr(
310+
pickle.loads(pickle.dumps(self.w)), "_fn_cache"
311+
), "temporary functions must not be serialized"
312+
305313

306314
class TestAutoName:
307315
def test_auto_name(self):

0 commit comments

Comments
 (0)