diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index b034fc9c1..a274954c0 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -9,5 +9,5 @@ dependencies: - dask - xhistogram - pip: - - pymc>=5.2.0 # CI was failing to resolve + - pymc>=5.4.1 # CI was failing to resolve - blackjax diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 3e3b61f43..634be930e 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -9,4 +9,4 @@ dependencies: - dask - xhistogram - pip: - - pymc>=5.2.0 # CI was failing to resolve + - pymc>=5.4.1 # CI was failing to resolve diff --git a/pymc_experimental/marginal_model.py b/pymc_experimental/marginal_model.py index ed40325fb..ee74ae00b 100644 --- a/pymc_experimental/marginal_model.py +++ b/pymc_experimental/marginal_model.py @@ -395,44 +395,38 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]: def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): # Clone the inner RV graph of the Marginalized RV marginalized_rvs_node = op.make_node(*inputs) - marginalized_rv, *dependent_rvs = clone_replace( + inner_rvs = clone_replace( op.inner_outputs, replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)}, ) + marginalized_rv = inner_rvs[0] # Obtain the joint_logp graph of the inner RV graph - # Some inputs are not root inputs (such as transformed projections of value variables) - # Or cannot be used as inputs to an OpFromGraph (shared variables and constants) - inputs = list(inputvars(inputs)) - rvs_to_values = {} - dummy_marginalized_value = marginalized_rv.clone() - rvs_to_values[marginalized_rv] = dummy_marginalized_value - rvs_to_values.update(zip(dependent_rvs, values)) - logps_dict = factorized_joint_logprob(rv_values=rvs_to_values, **kwargs) + inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs} + logps_dict = factorized_joint_logprob(rv_values=inner_rvs_to_values, **kwargs) # Reduce logp dimensions corresponding to broadcasted variables - values_axis_bcast = [] - for value in values: - vbcast = value.type.broadcastable - mbcast = dummy_marginalized_value.type.broadcastable + joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]] + for inner_rv, inner_value in inner_rvs_to_values.items(): + if inner_rv is marginalized_rv: + continue + vbcast = inner_value.type.broadcastable + mbcast = marginalized_rv.type.broadcastable mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast - values_axis_bcast.append([i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]) - joint_logp = logps_dict[dummy_marginalized_value] - for value, values_axis_bcast in zip(values, values_axis_bcast): - joint_logp += logps_dict[value].sum(values_axis_bcast, keepdims=True) + values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v] + joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True) # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different # values of the marginalized RV - # OpFromGraph does not accept constant inputs - non_const_values = [ - value - for value in rvs_to_values.values() - if not isinstance(value, (Constant, SharedVariable)) - ] - joint_logp_op = OpFromGraph([*non_const_values, *inputs], [joint_logp], inline=True) + # Some inputs are not root inputs (such as transformed projections of value variables) + # Or cannot be used as inputs to an OpFromGraph (shared variables and constants) + inputs = list(inputvars(inputs)) + joint_logp_op = OpFromGraph( + list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True + ) # Compute the joint_logp for all possible n values of the marginalized RV. We assume - # each original dimension is independent so that it sufficies to evaluate the graph + # each original dimension is independent so that it suffices to evaluate the graph # n times, once with each possible value of the marginalized RV replicated across # batched dimensions of the marginalized RV @@ -449,18 +443,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs): axis2=-1, ) - # OpFromGraph does not accept constant inputs - non_const_values = [ - value for value in values if not isinstance(value, (Constant, SharedVariable)) - ] # Arbitrary cutoff to switch to Scan implementation to keep graph size under control if len(marginalized_rv_domain) <= 10: joint_logps = [ - joint_logp_op(marginalized_rv_domain_tensor[i], *non_const_values, *inputs) + joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs) for i in range(len(marginalized_rv_domain)) ] else: - # Make sure this is rewrite is registered + # Make sure this rewrite is registered from pymc.pytensorf import local_remove_check_parameter def logp_fn(marginalized_rv_const, *non_sequences): @@ -469,7 +459,7 @@ def logp_fn(marginalized_rv_const, *non_sequences): joint_logps, _ = scan_map( fn=logp_fn, sequences=marginalized_rv_domain_tensor, - non_sequences=[*non_const_values, *inputs], + non_sequences=[*values, *inputs], mode=Mode().including("local_remove_check_parameter"), ) diff --git a/requirements.txt b/requirements.txt index 600e3fa57..d6bcb07e6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -pymc>=5.2.0 +pymc>=5.4.1 diff --git a/setup.py b/setup.py index 6ebbae314..4713ea59c 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: Apache Software License", "Intended Audience :: Science/Research", "Topic :: Scientific/Engineering",