@@ -558,6 +558,39 @@ def get_parents(self):
558
558
return [self .owner ]
559
559
return []
560
560
561
+ def convert_string_keys_to_pytensor_variables (self , inputs_to_values ):
562
+ r"""Convert the string keys to corresponding `Variable` with nearest name.
563
+
564
+ Parameters
565
+ ----------
566
+ inputs_to_values :
567
+ A dictionary mapping PyTensor `Variable`\s to values.
568
+
569
+ Examples
570
+ --------
571
+
572
+ >>> import numpy as np
573
+ >>> import pytensor.tensor as at
574
+ >>> x = at.dscalar('x')
575
+ >>> y = at.dscalar('y')
576
+ >>> z = x + y
577
+ >>> np.allclose(z.eval({'x' : 3, 'y' : 1}), 4)
578
+ True
579
+ """
580
+ process_input_to_values = {}
581
+ for i in inputs_to_values :
582
+ if isinstance (i , str ):
583
+ nodes_with_matching_names = get_var_by_name ([self ], i )
584
+ if len (nodes_with_matching_names ) == 0 :
585
+ raise Exception (f"{ i } not found in graph" )
586
+ else :
587
+ process_input_to_values [
588
+ nodes_with_matching_names [0 ]
589
+ ] = inputs_to_values [i ]
590
+ else :
591
+ process_input_to_values [i ] = inputs_to_values [i ]
592
+ return process_input_to_values
593
+
561
594
def eval (self , inputs_to_values = None ):
562
595
r"""Evaluate the `Variable`.
563
596
@@ -597,6 +630,10 @@ def eval(self, inputs_to_values=None):
597
630
if inputs_to_values is None :
598
631
inputs_to_values = {}
599
632
633
+ inputs_to_values = self .convert_string_keys_to_pytensor_variables (
634
+ inputs_to_values
635
+ )
636
+
600
637
if not hasattr (self , "_fn_cache" ):
601
638
self ._fn_cache = dict ()
602
639
0 commit comments