Description
Due to a (big) misunderstanding, my PR #4625 broke the way `size` parametrization works in v4.
The misunderstanding was specifically about the notion of implied dimensions and what it means for univariate vs. multivariate RVs:
⚠ The pre-#4625 notion was that `size` specifies the dimensions that come in addition to the support dimensions.
⚠ The post-#4625 notion was that `size` specifies the dimensions that come in addition to everything implied by parameters other than `size`/`shape`/`dims`.
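To make the two notions concrete, here is a minimal plain-NumPy sketch that computes the resulting shape under each rule. The helper names `shape_pre_4625`/`shape_post_4625` and the split of a distribution's dimensions into `param_batch_shape` and `support_shape` are hypothetical, chosen only for illustration; they are not PyMC functions. `np.broadcast_shapes` requires NumPy 1.20 or newer.

```python
import numpy as np


def shape_pre_4625(size, param_batch_shape, support_shape):
    # Hypothetical helper: `size` is the full shape minus the support dims,
    # so the parameters' batch shape must broadcast into `size`.
    size = tuple(size)
    if np.broadcast_shapes(size, tuple(param_batch_shape)) != size:
        raise ValueError(f"batch shape {param_batch_shape} does not fit into size={size}")
    return size + tuple(support_shape)


def shape_post_4625(size, param_batch_shape, support_shape):
    # Hypothetical helper: `size` is prepended to everything the parameters imply.
    return tuple(size) + tuple(param_batch_shape) + tuple(support_shape)


# MvNormal(mu=ones(7), cov=eye(7)): no batch dims, support shape (7,)
print(shape_pre_4625((2, 3), (), (7,)), shape_post_4625((2, 3), (), (7,)))  # (2, 3, 7) (2, 3, 7)
# Normal(mu=[1, 2, 3]): batch shape (3,), scalar support
print(shape_pre_4625((2, 3), (3,), ()), shape_post_4625((2, 3), (3,), ()))  # (2, 3) (2, 3, 3)
```

The two rules only disagree when a parameter carries batch dimensions, which is exactly the univariate `Normal` case.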
The difference is subtle and maybe best explained with the following example:
```python
from numpy import eye, ones

# v4 pre-4625 (45cb4ebf36500e502481bdced6980dd9e630acca)
MvNormal.dist(
    cov=eye(7), mu=ones(7), size=(2, 3)  # MvNormal is multivariate with ndim_support=1,
).eval().shape == (2, 3, 7)  # therefore (7,) is implied/required and does not count toward `size`.
Normal.dist(
    mu=[1, 2, 3], size=(2, 3)
).eval().shape == (2, 3)  # Normal is univariate, so `mu` does not count as a support dim

# v4 post-4625 (e9f2e9616394275ccf7587a4818fe21251d51328)
MvNormal.dist(
    cov=eye(7), mu=ones(7), size=(2, 3)
).eval().shape == (2, 3, 7)
Normal.dist(
    mu=[1, 2, 3], size=(2, 3)
).eval().shape == (2, 3, 3)  # the last dimension of length 3 was implied by `mu` and does not count toward `size`
```
With the changes from #4625, specifying `shape=(1, 2, 3, ...)` and specifying `size=(1, 2, 3)` produce identical results.
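The equivalence can be sketched with a hypothetical resolver that follows the post-#4625 rule. `resolve_post_4625` and its `implied` argument (the dimensions implied by the parameters and the RV support) are illustrative assumptions, not PyMC code.

```python
def resolve_post_4625(shape=None, size=None, implied=()):
    # Hypothetical post-#4625 resolver; `implied` holds the dims implied by
    # the parameters and the RV support.
    if size is not None:
        # `size` is prepended to the implied dims.
        return tuple(size) + tuple(implied)
    if shape is None:
        # Neither given: the result is just what the parameters imply.
        return tuple(implied)
    if shape and shape[-1] is Ellipsis:
        # A trailing Ellipsis is replaced by the implied dims.
        return tuple(shape[:-1]) + tuple(implied)
    # A fully specified shape is taken as-is.
    return tuple(shape)


# Normal(mu=[1, 2, 3]) implies one dimension of length 3
implied = (3,)
print(resolve_post_4625(size=(1, 2, 3), implied=implied))  # (1, 2, 3, 3)
print(resolve_post_4625(shape=(1, 2, 3, ...), implied=implied))  # (1, 2, 3, 3)
```

Under this rule, `size` is effectively shorthand for "shape with a trailing Ellipsis", which is the redundancy that motivated the revert.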
After some discussion about the advantages and disadvantages of either API flavor, we decided to go back to the pre-4625 flavor, where `size` is essentially `shape` without the support dimensions.
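Under the restored rule, converting between `size` and `shape` only requires knowing how many support dimensions the distribution has. A minimal sketch with hypothetical helper names, where `ndim_supp` is assumed to be 0 for univariate distributions and 1 for a distribution like MvNormal:

```python
def size_from_shape(shape, ndim_supp):
    # Hypothetical helper: drop the trailing `ndim_supp` support dims.
    shape = tuple(shape)
    return shape[: len(shape) - ndim_supp]


def shape_from_size(size, support_shape):
    # Hypothetical helper: append the support dims to `size`.
    return tuple(size) + tuple(support_shape)


print(size_from_shape((2, 3, 7), ndim_supp=1))  # (2, 3): MvNormal with 7-dim support
print(size_from_shape((2, 3), ndim_supp=0))  # (2, 3): univariate Normal, size == shape
print(shape_from_size((2, 3), (7,)))  # (2, 3, 7)
```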
This is also how NumPy handles the dimensionality of multivariate distributions:
```python
import numpy as np

np.random.multivariate_normal(
    mean=np.ones(7), cov=np.eye(7), size=(2, 3)
).shape == (2, 3, 7)
```
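For univariate distributions NumPy follows the same convention: `size` is the full output shape, and array-valued parameters broadcast into it rather than adding dimensions. A short check using NumPy's `Generator` API (the seed is arbitrary):

```python
import numpy as np

rng = np.random.default_rng(42)
# `loc` has shape (3,) and broadcasts against size=(2, 3); no extra dim is added.
draws = rng.normal(loc=[1.0, 2.0, 3.0], size=(2, 3))
print(draws.shape)  # (2, 3)
```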
The flexibility added by #4625, namely the ability to omit dimensions that are implied by the RV support or by parameters, will continue to work through the `shape`-with-Ellipsis API:
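```python
import aesara
import numpy as np

# Use an ndarray so that aesara creates a tensor shared variable.
mu = aesara.shared(np.array([1.0, 2.0, 3.0]))
rv = Normal.dist(
    mu=mu, shape=(7, 5, ...)  # only the additional dimensions are specified explicitly
)
assert rv.eval().shape == (7, 5, 3)

# Now change the parameter-implied dimensions:
mu.set_value(np.array([1.0, 2.0, 3.0, 4.0]))
assert rv.eval().shape == (7, 5, 4)
```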
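The key property here is that the implied dimensions are resolved when the graph is evaluated, not when the RV is created. As a rough NumPy-only analogue, the hypothetical function `draw_normal` below only mimics this behavior; it is not how PyMC implements it.

```python
import numpy as np


def draw_normal(rng, mu, shape):
    # Hypothetical stand-in for evaluating Normal.dist(mu=mu, shape=shape):
    # a trailing Ellipsis is replaced by mu's shape at draw time.
    mu = np.asarray(mu, dtype=float)
    if shape and shape[-1] is Ellipsis:
        shape = tuple(shape[:-1]) + mu.shape
    return rng.normal(loc=mu, size=shape)


rng = np.random.default_rng(0)
mu = np.array([1.0, 2.0, 3.0])
first = draw_normal(rng, mu, (7, 5, ...))
mu = np.array([1.0, 2.0, 3.0, 4.0])  # change the parameter-implied dimension
second = draw_normal(rng, mu, (7, 5, ...))
print(first.shape, second.shape)  # (7, 5, 3) (7, 5, 4)
```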