Remove joint_logprob function from tests.logprob.util submodule #6619

Closed
@ricardoV94

Description

This function is not very helpful and is only used in tests. We should be able to remove it and use either logp or factorized_joint_logprob in the tests instead.

def joint_logprob(*args, sum: bool = True, **kwargs) -> Optional[TensorVariable]:
    """Create a graph representing the joint log-probability/measure of a graph.

    This function calls `factorized_joint_logprob` and returns the combined
    log-probability factors as a single graph.

    Parameters
    ----------
    sum: bool
        If ``True`` each factor is collapsed to a scalar via ``sum`` before
        being joined with the remaining factors. This may be necessary to
        avoid incorrect broadcasting among independent factors.

    """
    logprob = factorized_joint_logprob(*args, **kwargs)
    if not logprob:
        return None
    if len(logprob) == 1:
        logprob = tuple(logprob.values())[0]
        if sum:
            return pt.sum(logprob)
        return logprob
    if sum:
        return pt.sum([pt.sum(factor) for factor in logprob.values()])
    return pt.add(*logprob.values())
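
For reference, the sum=True path above amounts to collapsing each factor to a scalar and adding the scalars, which a test could do inline on the dict returned by factorized_joint_logprob. The sketch below illustrates that reduction in plain Python only. It does not call PyMC or PyTensor, and the `factors` dict, `normal_logpdf`, and `combine_factors` are hypothetical stand-ins introduced here for illustration.

```python
import math


def normal_logpdf(x, mu=0.0, sigma=1.0):
    """Log-density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


# Hypothetical stand-in for the mapping returned by factorized_joint_logprob:
# one entry per value variable, each holding elementwise log-probability terms.
# The two factors deliberately have different lengths (3 and 2).
factors = {
    "x": [normal_logpdf(v) for v in (0.0, 1.0, -1.0)],
    "y": [normal_logpdf(v, mu=1.0) for v in (1.0, 2.0)],
}


def combine_factors(factors):
    """Sum each factor to a scalar, then add the scalars (the sum=True path)."""
    if not factors:
        return None
    return math.fsum(math.fsum(terms) for terms in factors.values())


joint = combine_factors(factors)
```

Because each factor is reduced before the factors are added, their differing lengths never interact. That is the broadcasting concern the docstring mentions.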
