Add MvStudentT and MatrixNormal moment #5173

Merged on Nov 21, 2021 (15 commits)

Conversation

patel-zeel
Contributor

@patel-zeel patel-zeel commented Nov 11, 2021

Add moments and tests for the distributions below, as part of #5078:

  • pymc.distributions.multivariate.MvStudentT
  • pymc.distributions.multivariate.MatrixNormal

@codecov

codecov bot commented Nov 11, 2021

Codecov Report

Merging #5173 (e23a672) into main (45b3339) will decrease coverage by 0.01%.
The diff coverage is 58.33%.

@@            Coverage Diff             @@
##             main    #5173      +/-   ##
==========================================
- Coverage   77.98%   77.96%   -0.02%     
==========================================
  Files          88       88              
  Lines       14227    14239      +12     
==========================================
+ Hits        11095    11102       +7     
- Misses       3132     3137       +5     

Impacted Files                        Coverage Δ
pymc/distributions/multivariate.py    71.48% <58.33%> (-0.24%) ⬇️

@ricardoV94 ricardoV94 mentioned this pull request Nov 11, 2021
51 tasks
@patel-zeel patel-zeel changed the title from "Add MvStudentT moment" to "Add MvStudentT and MatrixNormal moment" Nov 12, 2021
@ricardoV94 ricardoV94 requested a review from Sayam753 November 13, 2021 07:17
Member

@ricardoV94 ricardoV94 left a comment

Found an issue with the MvStudentT moment

(2, np.ones(1), np.eye(1), None, np.ones(1)),
(2, rand1d, np.eye(2), None, rand1d),
(2, rand1d, np.eye(2), 2, np.full((2, 2), rand1d)),
(2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
Member

@ricardoV94 ricardoV94 Nov 15, 2021

This test combination seems to be failing because the moment ignores the broadcasting induced by nu.

Suggested change
- (2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
+ (2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
+ (np.array([3, 4]), np.ones(2), np.eye(2), None, np.ones((2, 2))),

Contributor Author

Tried to fix it. I suppose nu and size are not working well together. Is the following test correct? If not, why?

(np.array([3, 4]), np.ones(2), np.eye(2), 2, np.ones((2, 2, 2))),

Member

@Sayam753 Can you confirm this is the expected behavior with nu broadcasting?

Member

@Sayam753 Sayam753 Nov 17, 2021

This is the first time I have seen a non-scalar nu, and the behavior is not so straightforward.

In the case:

(np.array([3, 4]), np.ones(2), np.eye(2), 2, np.ones((2, 2, 2)))
size = 2
nu = np.array([3, 4])
mean = np.ones(2)

Because nu has two entries, we form two independent (batched) MvStudentT distributions. Each random sample from this batch has shape (2, 2): the first dimension is implied by nu and the second by cov.shape[1].
Then size comes into play, giving the shape (2, 2, 2).


For the case,

size = 2
nu = np.array([3, 4])
mean = np.ones((2, 2))

The resulting shape would still be (2, 2, 2), not (2, 2, 2, 2), because mean.shape[:-1] and nu.shape broadcast with each other. The current implementation of the get_moment method does not handle this case.
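
The shape reasoning for both cases can be checked with a small NumPy sampler. This is only a sketch under stated assumptions: `sample_mvstudentt` is a hypothetical helper, not PyMC's `rng_fn`, and it assumes `size` is a tuple that is prepended to the batch shape, as described in this comment.

```python
import numpy as np


def sample_mvstudentt(nu, mu, cov, size=(), rng=None):
    """Draw multivariate Student-t samples with NumPy (illustrative only)."""
    rng = np.random.default_rng() if rng is None else rng
    nu = np.asarray(nu, dtype=float)
    mu = np.asarray(mu, dtype=float)
    cov = np.asarray(cov, dtype=float)

    # nu has no event dimension; the last axis of mu is the event dimension.
    batch_shape = np.broadcast_shapes(nu.shape, mu.shape[:-1])
    lead_shape = tuple(size) + batch_shape

    # Gaussian part, shape lead_shape + (k,).
    z = rng.multivariate_normal(np.zeros(cov.shape[-1]), cov, size=lead_shape)
    # Chi-square part, one draw per leading element, with nu broadcast.
    chi2 = rng.chisquare(nu, size=lead_shape)
    return mu + z / np.sqrt(chi2 / nu)[..., None]


rng = np.random.default_rng(0)
a = sample_mvstudentt([3, 4], np.ones(2), np.eye(2), size=(2,), rng=rng)
b = sample_mvstudentt([3, 4], np.ones((2, 2)), np.eye(2), size=(2,), rng=rng)
print(a.shape, b.shape)  # (2, 2, 2) (2, 2, 2)
```

Both calls give the same shape because `mu.shape[:-1]` and `nu.shape` are broadcast together rather than stacked.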

Member

Some pseudocode to fix this would be:

event_shape = cov.shape[-1:]
batch_shape = broadcast_shapes(nu.shape, mu.shape[:-1])
output_shape = size + batch_shape + event_shape
at.fill(output_shape, mu)
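
A runnable NumPy rendering of this pseudocode might look like the following. It is a sketch, not the code merged in this PR: `mvstudentt_moment` is a hypothetical name, `size` is assumed to be a tuple, and `np.broadcast_to` stands in for the Aesara fill step.

```python
import numpy as np


def mvstudentt_moment(nu, mu, cov, size=()):
    """Broadcast mu to size + batch_shape + event_shape (illustrative only)."""
    nu = np.asarray(nu)
    mu = np.asarray(mu)
    cov = np.asarray(cov)

    event_shape = cov.shape[-1:]
    # Batch dimensions come from nu and from the leading axes of mu.
    batch_shape = np.broadcast_shapes(nu.shape, mu.shape[:-1])
    output_shape = tuple(size) + batch_shape + event_shape
    # Copy so the result is a writable array rather than a read-only view.
    return np.broadcast_to(mu, output_shape).copy()


moment = mvstudentt_moment(np.array([3, 4]), np.ones(2), np.eye(2), size=(2,))
print(moment.shape)  # (2, 2, 2)
```

With these inputs the result equals `np.ones((2, 2, 2))`, the expected value in the test case proposed earlier in the thread.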

Member

@ricardoV94 ricardoV94 Nov 18, 2021

> it should imply batch dimensions and rng_fn method should work out of the box

@Sayam753, does this mean the rng_fn is working out of the box? We can tweak the shape inference if that's not giving the right results. I am not sure, but I think Aesara will not broadcast any parameters before calling rng_fn; broadcasting is only done internally when trying to infer the shape of the RV. I would need to double check.

If neither the rng_fn nor the Aesara shape inference is working for vector nu, we can also force it to be a scalar in make_node and raise otherwise.

Also, we would need to check whether the logp/logcdf behave properly with vector nu, so maybe forcing it to be a scalar is the safest course of action for now? We can open an issue to investigate whether things can work, or already work, properly with vector nu.
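
The "force a scalar and raise otherwise" option could be as simple as the check below. This is a plain NumPy sketch with a hypothetical helper name, `require_scalar_nu`; how such a check would actually be wired into `make_node` is not shown in this thread.

```python
import numpy as np


def require_scalar_nu(nu):
    """Return nu as a 0-d array, or raise if it has any dimensions."""
    nu = np.asarray(nu)
    if nu.ndim != 0:
        raise ValueError(f"nu must be a scalar, got shape {nu.shape}")
    return nu


require_scalar_nu(2)  # passes
try:
    require_scalar_nu([3, 4])
except ValueError as err:
    print(err)  # nu must be a scalar, got shape (2,)
```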

Member

> Does this mean the rng_fn is working out of the box?

Yes. As per @patel-zeel's comment two messages above, the shape of the random samples does not match the shape obtained from the initial moments.

> so maybe forcing it to be a scalar is the safest course of action for now?

Agreed. @patel-zeel, please revert the nu broadcasting logic. Let's only deal with mu for moments this time.

Member

> > Does this mean the rng_fn is working out of the box?
>
> Yes. As per @patel-zeel's comment two messages above, the shape of the random samples does not match the shape obtained from the initial moments.

Just to be clear, I don't think we should limit the MvStudentT just because the get_moment is not giving the right shape. We could limit it on the grounds that rv.eval().shape disagrees with rv.shape.eval() or that our logp/logcdf do not work properly with vector nu.

Member

The shape inference does seem to fail:

import numpy as np
import pymc as pm
x = pm.MvStudentT.dist(nu=[3, 4], mu=np.ones(2), cov=np.eye(2), size=2)
tuple(x.shape.eval()) == tuple(x.eval().shape)  # False

Contributor Author

@patel-zeel patel-zeel Nov 20, 2021

Thanks for the suggestions @ricardoV94 and @Sayam753. I have disabled nu broadcasting for now. I hope I am covering everything else with the current test cases.

@ricardoV94
Member

ricardoV94 commented Nov 20, 2021

So we don't want to support vector nu (we should open an issue to make sure we enforce that in the MvStudentT dist or RV), but what about a 3D sigma?

Is that allowed by the RV? I thought we had tests for this case in the MvNormal moment, but we don't.

@patel-zeel
Contributor Author

That was an interesting thought @ricardoV94. But I think cov is enforced to be 2d in quaddist_matrix. Thanks for approving the PR :)

@ricardoV94
Member

> That was an interesting thought @ricardoV94. But I think cov is enforced to be 2d in quaddist_matrix. Thanks for approving the PR :)

Yeah, we might want to revisit that limitation another time. Thanks a lot for your help. I'll leave some time in case @Sayam753 wants to have a look before merging

@patel-zeel
Contributor Author

Sure. No problem at all. It was a good learning experience for me.

Member

@Sayam753 Sayam753 left a comment

LGTM 🚀

@ricardoV94 ricardoV94 merged commit dc92865 into pymc-devs:main Nov 21, 2021