
Restore pre-4625 way how the size kwarg works #4662

Closed
@michaelosthege

Description


Due to a (big) misunderstanding, my PR #4625 broke how the size parametrization works in v4.

The misunderstanding was specifically about the notion of implied dimensions and what it means for univariate vs. multivariate RVs:

⚠ The pre-#4625 notion was that size is in addition to the support dimensions.

⚠ The post-#4625 notion was that size is in addition to what's implied by parameters other than size/shape/dims.

The difference is subtle and maybe best explained with the following example:

```python
from numpy import eye, ones
from pymc3 import MvNormal, Normal

# v4 pre-4625 (45cb4ebf36500e502481bdced6980dd9e630acca)
MvNormal.dist(
    cov=eye(7), mu=ones(7), size=(2,3)    # MvNormal is multivariate with ndim_support=1,
).eval().shape == (2, 3, 7)               # therefore (7,) is implied/required and does not count toward `size`.
Normal.dist(
    mu=[1,2,3], size=(2,3)
).eval().shape == (2,3)                   # Normal is univariate, so `mu` does not count as a support dim

# v4 post-4625 (e9f2e9616394275ccf7587a4818fe21251d51328)
MvNormal.dist(
    cov=eye(7), mu=ones(7), size=(2,3)
).eval().shape == (2, 3, 7)
Normal.dist(
    mu=[1,2,3], size=(2,3)
).eval().shape == (2, 3, 3)         # the last dimension of length 3 was implied by mu and does not count toward `size`
```
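The two readings of size can be contrasted with a small pure-NumPy sketch. The helpers shape_pre_4625 and shape_post_4625 are hypothetical (not PyMC API), and they assume the parameters have already been split into batch dimensions and a support_shape (empty for univariate RVs), which is a simplification of how shapes are actually inferred. np.broadcast_shapes requires NumPy 1.20 or later.

```python
import numpy as np


def shape_pre_4625(size, batch_shape, support_shape):
    """Hypothetical helper: `size` covers all batch dims; support dims are appended."""
    size = tuple(size)
    # Parameter batch dims must broadcast *into* `size` without enlarging it.
    if np.broadcast_shapes(size, tuple(batch_shape)) != size:
        raise ValueError(f"batch shape {batch_shape} does not fit into size {size}")
    return size + tuple(support_shape)


def shape_post_4625(size, batch_shape, support_shape):
    """Hypothetical helper: `size` is prepended to everything the parameters imply."""
    return tuple(size) + tuple(batch_shape) + tuple(support_shape)


# Normal(mu=[1, 2, 3]): univariate, so mu contributes a batch dim of length 3.
print(shape_pre_4625((2, 3), (3,), ()))    # (2, 3)
print(shape_post_4625((2, 3), (3,), ()))   # (2, 3, 3)

# MvNormal(mu=ones(7), cov=eye(7)): ndim_support=1, so (7,) is the support shape.
print(shape_pre_4625((2, 3), (), (7,)))    # (2, 3, 7)
print(shape_post_4625((2, 3), (), (7,)))   # (2, 3, 7)
```

Under the pre-4625 reading a parameter whose batch dims do not fit into size is an error, whereas the post-4625 reading silently grows the output shape.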

With the changes from #4625, specifying shape=(1,2,3, ...) gives the same result as specifying size=(1,2,3).

After some discussion of the advantages and disadvantages of each API flavor, we decided to go back to the pre-4625 flavor, where size is essentially shape without the support dimensions.
This is also how NumPy handles the dimensionality of multivariate distributions:

```python
import numpy as np

np.random.multivariate_normal(
    cov=np.eye(7), mean=np.ones(7), size=(2, 3)
).shape == (2, 3, 7)
```
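NumPy's newer Generator API follows the same convention, which a short check can illustrate. This sketch uses numpy.random.default_rng instead of the legacy function above, and the seed is arbitrary. For the multivariate Dirichlet the support dimension is appended after size, while for the univariate normal size is the full output shape and the parameters must broadcast into it, matching the pre-4625 Normal example.

```python
import numpy as np

rng = np.random.default_rng(42)

# Multivariate: the support dimension (length of alpha) is appended after `size`.
dirichlet_draws = rng.dirichlet(alpha=np.ones(4), size=(2, 3))
print(dirichlet_draws.shape)  # (2, 3, 4)

# Univariate: `size` is the full output shape; loc's length 3 broadcasts into it.
normal_draws = rng.normal(loc=[1.0, 2.0, 3.0], size=(2, 3))
print(normal_draws.shape)  # (2, 3)

# A loc that cannot broadcast into `size` is rejected instead of extending the shape.
try:
    rng.normal(loc=[1.0, 2.0, 3.0], size=(2, 4))
except ValueError as err:
    print("ValueError:", err)
```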

The flexibility added by #4625, namely the ability to omit dimensions that are implied by the RV support or its parameters, will remain available through the shape API with an Ellipsis:

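```python
import aesara
import numpy as np
from pymc3 import Normal

# Use an ndarray so that aesara.shared creates a tensor (not a generic) shared variable
mu = aesara.shared(np.array([1.0, 2.0, 3.0]))

rv = Normal.dist(
    mu=mu, shape=(7, 5, ...)   # only the additional dimensions are specified explicitly
)
assert rv.eval().shape == (7, 5, 3)

# Now change the parameter-implied dimensions:
mu.set_value(np.array([1.0, 2.0, 3.0, 4.0]))
assert rv.eval().shape == (7, 5, 4)
```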
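The Ellipsis semantics can be mimicked outside PyMC to make the idea concrete. In this sketch, draw_with_ellipsis is a hypothetical helper, not part of PyMC or Aesara: it replaces a trailing ... with the shape implied by the current value of mu, so the drawn shape follows mu when it changes, analogous to the set_value example above.

```python
import numpy as np

rng = np.random.default_rng(0)


def draw_with_ellipsis(mu, shape):
    """Hypothetical helper: draw Normal(mu, 1) samples for a shape that may end in `...`."""
    shape = tuple(shape)
    if shape and shape[-1] is Ellipsis:
        # Fill the trailing Ellipsis with the dims implied by mu's current shape.
        shape = shape[:-1] + np.shape(mu)
    return rng.normal(loc=mu, size=shape)


mu = np.array([1.0, 2.0, 3.0])
print(draw_with_ellipsis(mu, (7, 5, ...)).shape)  # (7, 5, 3)

mu = np.array([1.0, 2.0, 3.0, 4.0])  # change the parameter-implied dimensions
print(draw_with_ellipsis(mu, (7, 5, ...)).shape)  # (7, 5, 4)
```

Without a trailing Ellipsis the helper passes the shape straight through, so a fully specified shape behaves like NumPy's size for a univariate draw.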
