Skip to content

Commit 38ef6cc

Browse files
committed
Add support for automatic naming of random variables
1 parent 5688555 commit 38ef6cc

File tree

2 files changed

+73
-6
lines changed

2 files changed

+73
-6
lines changed

pymc/distributions/distribution.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Callable, Optional, Sequence, Tuple, Union
2323

2424
import numpy as np
25+
import opcode
2526

2627
from aesara import tensor as at
2728
from aesara.compile.builders import OpFromGraph
@@ -164,6 +165,45 @@ def fn(*args, **kwargs):
164165
return fn
165166

166167

168+
# Helper function from pyprob
169+
def _extract_target_of_assignment(depth):
170+
frame = sys._getframe(depth)
171+
code = frame.f_code
172+
next_instruction = code.co_code[frame.f_lasti + 2]
173+
instruction_arg = code.co_code[frame.f_lasti + 3]
174+
instruction_name = opcode.opname[next_instruction]
175+
if instruction_name == "STORE_FAST":
176+
return code.co_varnames[instruction_arg]
177+
elif instruction_name in ["STORE_NAME", "STORE_GLOBAL"]:
178+
return code.co_names[instruction_arg]
179+
elif (
180+
instruction_name in ["LOAD_FAST", "LOAD_NAME", "LOAD_GLOBAL"]
181+
and opcode.opname[code.co_code[frame.f_lasti + 4]] in ["LOAD_CONST", "LOAD_FAST"]
182+
and opcode.opname[code.co_code[frame.f_lasti + 6]] == "STORE_SUBSCR"
183+
):
184+
if instruction_name == "LOAD_FAST":
185+
base_name = code.co_varnames[instruction_arg]
186+
else:
187+
base_name = code.co_names[instruction_arg]
188+
189+
second_instruction = opcode.opname[code.co_code[frame.f_lasti + 4]]
190+
second_arg = code.co_code[frame.f_lasti + 5]
191+
if second_instruction == "LOAD_CONST":
192+
value = code.co_consts[second_arg]
193+
elif second_instruction == "LOAD_FAST":
194+
var_name = code.co_varnames[second_arg]
195+
value = frame.f_locals[var_name]
196+
else:
197+
value = None
198+
if value is not None:
199+
index_name = repr(value)
200+
return base_name + "[" + index_name + "]"
201+
else:
202+
return None
203+
else:
204+
return None
205+
206+
167207
class SymbolicRandomVariable(OpFromGraph):
168208
"""Symbolic Random Variable
169209
@@ -216,7 +256,6 @@ class Distribution(metaclass=DistributionMeta):
216256

217257
def __new__(
218258
cls,
219-
name: str,
220259
*args,
221260
rng=None,
222261
dims: Optional[Dims] = None,
@@ -234,8 +273,6 @@ def __new__(
234273
----------
235274
cls : type
236275
A PyMC distribution.
237-
name : str
238-
Name for the new model variable.
239276
rng : optional
240277
Random number generator to use with the RandomVariable.
241278
dims : tuple, optional
@@ -277,6 +314,19 @@ def __new__(
277314
"for a standalone distribution."
278315
)
279316

317+
if "name" in kwargs:
318+
name = kwargs.pop("name")
319+
elif len(args) > 0 and isinstance(args[0], string_types):
320+
name = args[0]
321+
args = args[1:]
322+
else:
323+
name = _extract_target_of_assignment(2)
324+
if name is None:
325+
raise TypeError("Name could not be inferred for variable")
326+
327+
if not isinstance(name, string_types):
328+
raise TypeError(f"Name needs to be a string but got: {name}")
329+
280330
if "testval" in kwargs:
281331
initval = kwargs.pop("testval")
282332
warnings.warn(
@@ -285,9 +335,6 @@ def __new__(
285335
stacklevel=2,
286336
)
287337

288-
if not isinstance(name, string_types):
289-
raise TypeError(f"Name needs to be a string but got: {name}")
290-
291338
dims = convert_dims(dims)
292339
if observed is not None:
293340
observed = convert_observed_data(observed)

pymc/tests/distributions/test_distribution.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,3 +416,23 @@ def test_tag_future_warning_dist():
416416
with pytest.warns(FutureWarning, match="Use model.rvs_to_values"):
417417
value_var = new_x.tag.value_var
418418
assert value_var == "1"
419+
420+
421+
def test_autonaming():
422+
"""Test that random variable ends up with same name as assignment variable"""
423+
d = {}
424+
with pm.Model() as m:
425+
x = pm.Normal(0.0, 1.0)
426+
d[2] = pm.Normal(0.0, 1.0)
427+
428+
assert x.name == "x"
429+
assert m["x"] == x
430+
assert d[2].name == "d[2]"
431+
assert m["d[2]"] == d[2]
432+
433+
434+
def test_autonaming_noname():
435+
"""Test that autonaming fails if no assignment can be found"""
436+
with pytest.raises(TypeError, match="Name could not be inferred for variable"):
437+
with pm.Model():
438+
pm.Normal()

0 commit comments

Comments
 (0)