Add MvStudentT and MatrixNormal moment #5173
Codecov Report
@@ Coverage Diff @@
## main #5173 +/- ##
==========================================
- Coverage 77.98% 77.96% -0.02%
==========================================
Files 88 88
Lines 14227 14239 +12
==========================================
+ Hits 11095 11102 +7
- Misses 3132 3137 +5
Found an issue with the MvStudentT moment
(2, np.ones(1), np.eye(1), None, np.ones(1)),
(2, rand1d, np.eye(2), None, rand1d),
(2, rand1d, np.eye(2), 2, np.full((2, 2), rand1d)),
(2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
This test combination seems to be failing, because the moment is ignoring the broadcasting induced by nu.
(2, rand2d, np.eye(3), 2, np.full((2, 2, 3), rand2d)),
(np.array([3, 4]), np.ones(2), np.eye(2), None, np.ones((2, 2))),
Tried to fix it. I suppose nu and size are not going well together. Is the following test correct? If not, why?
(np.array([3, 4]), np.ones(2), np.eye(2), 2, np.ones((2, 2, 2))),
@Sayam753 Can you confirm this is the expected behavior with nu broadcasting?
This is the first time I am seeing nu not being a scalar, and it is not so straightforward.
In the case
(np.array([3, 4]), np.ones(2), np.eye(2), 2, np.ones((2, 2, 2)))
size = 2
nu = np.array([3, 4])
mean = np.ones(2)
Due to the dimensions of nu, we form two independent (batched) MvStudentT distributions. Each random sample from this batch will be of shape (2, 2): the first dimension is implied by nu and the second by cov.shape[1]. Then size comes into play, and we thus have the shape (2, 2, 2).
For the case
size = 2
nu = np.array([3, 4])
mean = np.ones((2, 2))
the resulting shape would still be (2, 2, 2), not (2, 2, 2, 2). This is because mean.shape[:-1] and nu.shape broadcast with each other. This case is not handled by the current implementation of the get_moment method.
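The two shape cases described above can be checked with plain NumPy. This is only a sketch of the shape arithmetic in this comment, not PyMC or Aesara code; the helper name expected_shape is hypothetical, and it assumes size is a single integer and that np.broadcast_shapes is available (NumPy 1.20 or later).

```python
import numpy as np

def expected_shape(size, nu, mean, cov):
    """Hypothetical helper: shape implied by size, nu, mean and cov."""
    nu, mean, cov = np.asarray(nu), np.asarray(mean), np.asarray(cov)
    # The event dimension comes from the last axis of cov.
    event_shape = cov.shape[-1:]
    # nu and the leading dimensions of mean broadcast into the batch shape.
    batch_shape = np.broadcast_shapes(nu.shape, mean.shape[:-1])
    return (size,) + batch_shape + event_shape

# First case: vector nu, 1-d mean.
print(expected_shape(2, [3, 4], np.ones(2), np.eye(2)))       # (2, 2, 2)
# Second case: vector nu, 2-d mean whose batch dimension broadcasts with nu.
print(expected_shape(2, [3, 4], np.ones((2, 2)), np.eye(2)))  # (2, 2, 2)
```

Both calls give (2, 2, 2) because the batch dimension of mean and the dimension of nu collapse into one axis instead of stacking.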
Some pseudocode to fix this would be:
event_shape = cov.shape[-1:]
batch_shape = broadcast_shapes(nu.shape, mu.shape[:-1])
output_shape = size + batch_shape + event_shape
at.fill(output_shape, mu)
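A runnable NumPy analogue of this pseudocode might look like the sketch below. It is not the get_moment implementation: the function name moment_sketch is hypothetical, np.broadcast_to stands in for at.fill, and size is assumed to be None, an int, or a tuple of ints.

```python
import numpy as np

def moment_sketch(size, nu, mu, cov):
    """Hypothetical NumPy stand-in for the pseudocode: mu broadcast to the output shape."""
    nu, mu, cov = np.asarray(nu), np.asarray(mu), np.asarray(cov)
    # Normalize size to a tuple (assumption: None, int, or tuple of ints).
    if size is None:
        size = ()
    elif isinstance(size, int):
        size = (size,)
    else:
        size = tuple(size)
    event_shape = cov.shape[-1:]
    batch_shape = np.broadcast_shapes(nu.shape, mu.shape[:-1])
    output_shape = size + batch_shape + event_shape
    # Copy so the result is a writable array rather than a read-only view.
    return np.broadcast_to(mu, output_shape).copy()

# The two vector-nu test combinations discussed in this thread.
print(moment_sketch(None, [3, 4], np.ones(2), np.eye(2)).shape)  # (2, 2)
print(moment_sketch(2, [3, 4], np.ones(2), np.eye(2)).shape)     # (2, 2, 2)
```

Under these assumptions the results match the expected values np.ones((2, 2)) and np.ones((2, 2, 2)) from the test tuples above.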
it should imply batch dimensions and rng_fn method should work out of the box
@Sayam753, does this mean the rng_fn is working out of the box? We can tweak the shape inference if that's not giving the right results. I am not sure, but I think Aesara will not broadcast any parameters before calling rng_fn; broadcasting is only done internally when trying to infer the shape of the RV, but I would need to double check.
If neither the rng_fn nor the Aesara shape inference is working for vector nu, we can also force it to be a scalar in make_node and raise otherwise.
Also, we would need to check whether the logp/logcdf behave properly with vector nu, so maybe forcing it to be a scalar is the safest course of action for now? We can open an issue to investigate whether things can work, or already work, properly with vector nu.
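As a rough illustration of the "force it to be a scalar and raise otherwise" idea, the check itself could look like the NumPy sketch below. It does not show how make_node is written in Aesara; the function name require_scalar_nu and the error message are hypothetical.

```python
import numpy as np

def require_scalar_nu(nu):
    """Hypothetical guard: return nu as a 0-d array, or raise if it has dimensions."""
    nu = np.asarray(nu)
    if nu.ndim != 0:
        raise ValueError(f"nu must be a scalar, got shape {nu.shape}")
    return nu

require_scalar_nu(2)  # accepted
try:
    require_scalar_nu([3, 4])  # rejected: vector nu
except ValueError as err:
    print(err)
```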
Does this mean the rng_fn is working out of the box?
Yes. As per @patel-zeel's comment two messages above, the shape of the random samples does not match the shape obtained from the initial moments.
so maybe forcing it to be a scalar is the safest course of action for now?
Agree. @patel-zeel, please revert the nu broadcasting logic, and let's only deal with mu for moments this time.
Does this mean the rng_fn is working out of the box?
Yes. As per @patel-zeel's comment two messages above, the shape of the random samples does not match the shape obtained from the initial moments.
Just to be clear, I don't think we should limit the MvStudentT just because get_moment is not giving the right shape. We could limit it on the grounds that rv.eval().shape disagrees with rv.shape.eval() or that our logp/logcdf do not work properly with vector nu.
The shape inference does seem to fail:
import numpy as np
import pymc as pm

x = pm.MvStudentT.dist(nu=[3, 4], mu=np.ones(2), cov=np.eye(2), size=2)
tuple(x.shape.eval()) == tuple(x.eval().shape)  # False
Thanks for the suggestions @ricardoV94 and @Sayam753. I have disabled nu broadcasting for now. I hope I am covering everything else with the current test cases.
So we don't want to support vector Is that allowed by the RV?
That was an interesting thought @ricardoV94. But I think
Yeah, we might want to revisit that limitation another time. Thanks a lot for your help. I'll leave some time in case @Sayam753 wants to have a look before merging.
Sure. No problem at all. It was a good learning experience for me.
LGTM 🚀
Add moments and tests for the below distributions as part of #5078
pymc.distributions.multivariate.MvStudentT
pymc.distributions.multivariate.MatrixNormal