Skip to content

Add support for automatic naming of random variables #6364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Callable, Optional, Sequence, Tuple, Union

import numpy as np
import opcode

from aesara import tensor as at
from aesara.compile.builders import OpFromGraph
Expand Down Expand Up @@ -164,6 +165,45 @@ def fn(*args, **kwargs):
return fn


# Helper function from pyprob
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

URL to original commit on pyprob: pyprob/pyprob@6aac747

def _extract_target_of_assignment(depth):
frame = sys._getframe(depth)
code = frame.f_code
next_instruction = code.co_code[frame.f_lasti + 2]
instruction_arg = code.co_code[frame.f_lasti + 3]
instruction_name = opcode.opname[next_instruction]
if instruction_name == "STORE_FAST":
return code.co_varnames[instruction_arg]
elif instruction_name in ["STORE_NAME", "STORE_GLOBAL"]:
return code.co_names[instruction_arg]
elif (
instruction_name in ["LOAD_FAST", "LOAD_NAME", "LOAD_GLOBAL"]
and opcode.opname[code.co_code[frame.f_lasti + 4]] in ["LOAD_CONST", "LOAD_FAST"]
and opcode.opname[code.co_code[frame.f_lasti + 6]] == "STORE_SUBSCR"
):
if instruction_name == "LOAD_FAST":
base_name = code.co_varnames[instruction_arg]
else:
base_name = code.co_names[instruction_arg]

second_instruction = opcode.opname[code.co_code[frame.f_lasti + 4]]
second_arg = code.co_code[frame.f_lasti + 5]
if second_instruction == "LOAD_CONST":
value = code.co_consts[second_arg]
elif second_instruction == "LOAD_FAST":
var_name = code.co_varnames[second_arg]
value = frame.f_locals[var_name]
else:
value = None
if value is not None:
index_name = repr(value)
return base_name + "[" + index_name + "]"
else:
return None
else:
return None
Comment on lines +168 to +204
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Helper function from pyprob
def _extract_target_of_assignment(depth):
# Purpose: look at the bytecode of the frame `depth` levels up the call stack
# and return the name of the target that the currently executing instruction's
# result is about to be stored into, or None when no supported pattern is found.
frame = sys._getframe(depth)
code = frame.f_code
# The byte at f_lasti + 2 is treated as the next opcode and the byte at
# f_lasti + 3 as that opcode's argument.
# NOTE(review): these fixed offsets assume two-byte instructions with nothing
# between the current instruction and the store (no EXTENDED_ARG prefix, no
# inline cache entries) - confirm for the Python versions that are supported.
next_instruction = code.co_code[frame.f_lasti + 2]
instruction_arg = code.co_code[frame.f_lasti + 3]
instruction_name = opcode.opname[next_instruction]
# Case 1: result stored into a local variable; the argument indexes co_varnames.
if instruction_name == "STORE_FAST":
return code.co_varnames[instruction_arg]
# Case 2: result stored into a (global) name; the argument indexes co_names.
elif instruction_name in ["STORE_NAME", "STORE_GLOBAL"]:
return code.co_names[instruction_arg]
# Case 3: a name load at +2, a constant/local load at +4 and STORE_SUBSCR at +6
# (presumably the bytecode of `base[index] = <call>` - verify).
elif (
instruction_name in ["LOAD_FAST", "LOAD_NAME", "LOAD_GLOBAL"]
and opcode.opname[code.co_code[frame.f_lasti + 4]] in ["LOAD_CONST", "LOAD_FAST"]
and opcode.opname[code.co_code[frame.f_lasti + 6]] == "STORE_SUBSCR"
):
# Name of the object being subscripted.
if instruction_name == "LOAD_FAST":
base_name = code.co_varnames[instruction_arg]
else:
base_name = code.co_names[instruction_arg]
# Resolve the index: a constant comes from co_consts, a local variable is
# looked up by name in the inspected frame's locals.
second_instruction = opcode.opname[code.co_code[frame.f_lasti + 4]]
second_arg = code.co_code[frame.f_lasti + 5]
if second_instruction == "LOAD_CONST":
value = code.co_consts[second_arg]
elif second_instruction == "LOAD_FAST":
var_name = code.co_varnames[second_arg]
value = frame.f_locals[var_name]
else:
value = None
# The result is "<base_name>[<repr of index value>]"; an index whose value
# is None is treated the same as an unresolved index.
if value is not None:
index_name = repr(value)
return base_name + "[" + index_name + "]"
else:
return None
else:
# No supported store pattern follows the current instruction.
return None
def _extract_target_of_assignment(depth) -> str:
"""Helper function to infer RV names from outer code.
Adapted from pyprob.
"""
try:
frame = sys._getframe(depth)
code = frame.f_code
next_instruction = code.co_code[frame.f_lasti + 2]
instruction_arg = code.co_code[frame.f_lasti + 3]
instruction_name = opcode.opname[next_instruction]
if instruction_name == "STORE_FAST":
return code.co_varnames[instruction_arg]
elif instruction_name in ["STORE_NAME", "STORE_GLOBAL"]:
return code.co_names[instruction_arg]
elif (
instruction_name in ["LOAD_FAST", "LOAD_NAME", "LOAD_GLOBAL"]
and opcode.opname[code.co_code[frame.f_lasti + 4]] in ["LOAD_CONST", "LOAD_FAST"]
and opcode.opname[code.co_code[frame.f_lasti + 6]] == "STORE_SUBSCR"
):
if instruction_name == "LOAD_FAST":
base_name = code.co_varnames[instruction_arg]
else:
base_name = code.co_names[instruction_arg]
second_instruction = opcode.opname[code.co_code[frame.f_lasti + 4]]
second_arg = code.co_code[frame.f_lasti + 5]
if second_instruction == "LOAD_CONST":
value = code.co_consts[second_arg]
elif second_instruction == "LOAD_FAST":
var_name = code.co_varnames[second_arg]
value = frame.f_locals[var_name]
else:
value = None
if value is not None:
index_name = repr(value)
return base_name + "[" + index_name + "]"
else:
raise Exception()
else:
raise Exception()
except Exception as ex:
raise TypeError(
"Name could not be inferred for variable from surrounding "
"context. Pass a name explicitly as the first argument to "
"the Distribution."
) from ex

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks that clarifies things!



class SymbolicRandomVariable(OpFromGraph):
"""Symbolic Random Variable

Expand Down Expand Up @@ -216,7 +256,6 @@ class Distribution(metaclass=DistributionMeta):

def __new__(
cls,
name: str,
*args,
rng=None,
dims: Optional[Dims] = None,
Expand All @@ -234,8 +273,6 @@ def __new__(
----------
cls : type
A PyMC distribution.
name : str
Name for the new model variable.
rng : optional
Random number generator to use with the RandomVariable.
dims : tuple, optional
Expand Down Expand Up @@ -277,6 +314,23 @@ def __new__(
"for a standalone distribution."
)

if "name" in kwargs:
name = kwargs.pop("name")
elif len(args) > 0 and isinstance(args[0], string_types):
name = args[0]
args = args[1:]
else:
name = _extract_target_of_assignment(2)
Copy link
Member

@michaelosthege michaelosthege Dec 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For situations where automated name determination fails, this part should raise an informative error.
For example, one might simply forget to pass the name:

with pmodel:
    pm.Normal("Y", pm.Uniform(), observed=2)

There are exceptions a few lines downstream, but here we should do a try/except

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or even better: make _extract_target_of_assignment raise the informative error and unit test that directly.

Copy link
Contributor Author

@zaxtax zaxtax Dec 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. What do you mean by an informative error message? I'm not sure what more I can say beyond that I failed to infer the name, and in _extract_target_of_assignment that no assignment instruction was found at that frame depth.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe something about "inferring the name from the outer code"

if you have a handle on the line of code print that

definitely instructions on what to do about it: please pass a name as the first argument

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed. Does that look ok?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the message looks good, but I'd recommend raising it in the extraction function instead of returning None. It gives a cleaner signature (str or raise) and separation of responsibilities.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the error message changes. Are we ok with an Assignment not found error that is raised, caught, and re-raised as a Name needs to be provided error message?

if name is None:
raise TypeError(
"Name could not be inferred for variable from surrounding "
"context. Pass a name explicitly as the first argument to "
"the Distribution."
)

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

Comment on lines +317 to +333
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if "name" in kwargs:
name = kwargs.pop("name")
elif len(args) > 0 and isinstance(args[0], string_types):
name = args[0]
args = args[1:]
else:
name = _extract_target_of_assignment(2)
if name is None:
raise TypeError(
"Name could not be inferred for variable from surrounding "
"context. Pass a name explicitly as the first argument to "
"the Distribution."
)
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")
name = kwargs.pop("name", None)
if name is None:
if args and isinstance(args[0], string_types):
# Name was provided as the first argument
name = args[0]
args = args[1:]
else:
# Try to infer name from the outer context.
# This may raise an error, but then there's nothing else we can do.
name = _extract_target_of_assignment(2)
if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

if "testval" in kwargs:
initval = kwargs.pop("testval")
warnings.warn(
Expand All @@ -285,9 +339,6 @@ def __new__(
stacklevel=2,
)

if not isinstance(name, string_types):
raise TypeError(f"Name needs to be a string but got: {name}")

dims = convert_dims(dims)
if observed is not None:
observed = convert_observed_data(observed)
Expand Down
27 changes: 27 additions & 0 deletions pymc/tests/distributions/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,30 @@ def test_tag_future_warning_dist():
with pytest.warns(FutureWarning, match="Use model.rvs_to_values"):
value_var = new_x.tag.value_var
assert value_var == "1"


def test_autonaming():
"""Test that random variable ends up with same name as assignment variable"""
d = {}
with pm.Model() as m:
# No name string is passed to either distribution; the assertions below expect
# the name to be taken from the assignment target instead.
# Plain variable target.
x = pm.Normal(0.0, 1.0)
# Subscript target with a constant index.
d[2] = pm.Normal(0.0, 1.0)

# The inferred name is set on the variable and is also the key under which the
# variable is retrievable from the model.
assert x.name == "x"
assert m["x"] == x
assert d[2].name == "d[2]"
assert m["d[2]"] == d[2]


def test_autonaming_noname():
    """Test that autonaming fails if no assignment can be found.

    ``pm.Normal()`` is called as a bare expression statement without a name,
    so there is no assignment target to take a name from, and a ``TypeError``
    asking for an explicit name is expected.
    """
    import re

    expected_message = (
        "Name could not be inferred for variable from surrounding "
        "context. Pass a name explicitly as the first argument to "
        "the Distribution."
    )
    # ``match`` is applied as a regular expression, so the message is escaped
    # to make its "." characters match literally instead of as wildcards.
    with pytest.raises(TypeError, match=re.escape(expected_message)):
        with pm.Model():
            pm.Normal()