Skip to content

Bump PyMC dependency #183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ dependencies:
- dask
- xhistogram
- pip:
- pymc>=5.2.0 # CI was failing to resolve
- pymc>=5.4.1 # CI was failing to resolve
- blackjax
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ dependencies:
- dask
- xhistogram
- pip:
- pymc>=5.2.0 # CI was failing to resolve
- pymc>=5.4.1 # CI was failing to resolve
54 changes: 22 additions & 32 deletions pymc_experimental/marginal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,44 +395,38 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
marginalized_rvs_node = op.make_node(*inputs)
marginalized_rv, *dependent_rvs = clone_replace(
inner_rvs = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)
marginalized_rv = inner_rvs[0]

# Obtain the joint_logp graph of the inner RV graph
# Some inputs are not root inputs (such as transformed projections of value variables)
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
inputs = list(inputvars(inputs))
rvs_to_values = {}
dummy_marginalized_value = marginalized_rv.clone()
rvs_to_values[marginalized_rv] = dummy_marginalized_value
rvs_to_values.update(zip(dependent_rvs, values))
logps_dict = factorized_joint_logprob(rv_values=rvs_to_values, **kwargs)
inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
logps_dict = factorized_joint_logprob(rv_values=inner_rvs_to_values, **kwargs)

# Reduce logp dimensions corresponding to broadcasted variables
values_axis_bcast = []
for value in values:
vbcast = value.type.broadcastable
mbcast = dummy_marginalized_value.type.broadcastable
joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
for inner_rv, inner_value in inner_rvs_to_values.items():
if inner_rv is marginalized_rv:
continue
vbcast = inner_value.type.broadcastable
mbcast = marginalized_rv.type.broadcastable
mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
values_axis_bcast.append([i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v])
joint_logp = logps_dict[dummy_marginalized_value]
for value, values_axis_bcast in zip(values, values_axis_bcast):
joint_logp += logps_dict[value].sum(values_axis_bcast, keepdims=True)
values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)

# Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
# values of the marginalized RV
# OpFromGraph does not accept constant inputs
non_const_values = [
value
for value in rvs_to_values.values()
if not isinstance(value, (Constant, SharedVariable))
]
joint_logp_op = OpFromGraph([*non_const_values, *inputs], [joint_logp], inline=True)
# Some inputs are not root inputs (such as transformed projections of value variables)
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
inputs = list(inputvars(inputs))
joint_logp_op = OpFromGraph(
list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True
)

# Compute the joint_logp for all possible n values of the marginalized RV. We assume
# each original dimension is independent so that it sufficies to evaluate the graph
# each original dimension is independent so that it suffices to evaluate the graph
# n times, once with each possible value of the marginalized RV replicated across
# batched dimensions of the marginalized RV

Expand All @@ -449,18 +443,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
axis2=-1,
)

# OpFromGraph does not accept constant inputs
non_const_values = [
value for value in values if not isinstance(value, (Constant, SharedVariable))
]
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
if len(marginalized_rv_domain) <= 10:
joint_logps = [
joint_logp_op(marginalized_rv_domain_tensor[i], *non_const_values, *inputs)
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
for i in range(len(marginalized_rv_domain))
]
else:
# Make sure this is rewrite is registered
# Make sure this rewrite is registered
from pymc.pytensorf import local_remove_check_parameter

def logp_fn(marginalized_rv_const, *non_sequences):
Expand All @@ -469,7 +459,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
joint_logps, _ = scan_map(
fn=logp_fn,
sequences=marginalized_rv_domain_tensor,
non_sequences=[*non_const_values, *inputs],
non_sequences=[*values, *inputs],
mode=Mode().including("local_remove_check_parameter"),
)

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pymc>=5.2.0
pymc>=5.4.1
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering",
Expand Down