
Passing dims=(dim1, dim2) in gp.prior leads to shape inconsistency #6758

Open
@ferrine

Description


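Users can't safely pass a broadcast dim to create a multi-output GP whose outputs share the same parameters. The current code fixes size to np.shape(X)[0] while forwarding dims through **kwargs, so the two can disagree: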

pymc/pymc/gp/gp.py

Lines 152 to 155 in 261862d

if reparameterize:
    size = np.shape(X)[0]
    v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
    f = pm.Deterministic(name, mu + cholesky(cov).dot(v), dims=kwargs.get("dims", None))

Possible implementation

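def _build_prior(self, name, X, reparameterize=True, jitter=JITTER_DEFAULT, **kwargs):
    mu = self.mean_func(X)
    cov = stabilize(self.cov_func(X), jitter)
    if reparameterize:
        if "dims" not in kwargs:
            # No dims given: use an explicit size, defaulting to one value per row of X.
            # pop (not get) so that size is not passed twice to pm.Normal below.
            size = kwargs.pop("size", np.shape(X)[0])
        else:
            # dims given: let them determine the shape of the rotated variable.
            size = None
        v = pm.Normal(name + "_rotated_", mu=0.0, sigma=1.0, size=size, **kwargs)
        # Right-multiply by the transposed Cholesky factor so that any leading
        # (broadcast) dimension of v is kept.
        f = pm.Deterministic(name, mu + v.dot(cholesky(cov).T), dims=kwargs.get("dims", None))
    else:
        f = pm.MvNormal(name, mu=mu, cov=cov, **kwargs)
    return f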
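
The switch from cholesky(cov).dot(v) to v.dot(cholesky(cov).T) can be illustrated with a small NumPy-only sketch. It is an analogy for the shape logic, not PyMC code: the sizes n_outputs and n_points, the toy kernel, and the variable names are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes standing in for dims=(dim1, dim2).
n_outputs, n_points = 3, 5

# Toy positive-definite covariance over n_points inputs (illustrative only).
X = np.linspace(0.0, 1.0, n_points)[:, None]
cov = np.exp(-0.5 * ((X - X.T) / 0.3) ** 2) + 1e-6 * np.eye(n_points)
L = np.linalg.cholesky(cov)  # lower triangular, shape (n_points, n_points)

# Batched standard-normal draws: one row per output.
v = rng.standard_normal((n_outputs, n_points))

# Left-multiplication assumes v is a length-n_points vector (or has n_points rows),
# so it rejects the batched shape (n_outputs, n_points).
try:
    L.dot(v)
    left_ok = True
except ValueError:
    left_ok = False

# Right-multiplication by L.T applies L to every row and keeps the leading dimension.
f = v.dot(L.T)

# Each row agrees with the unbatched computation L @ v_i.
rows_match = all(np.allclose(f[i], L.dot(v[i])) for i in range(n_outputs))

print(left_ok, f.shape, rows_match)
```

In this sketch the batch dimension comes first and the GP input dimension last, which is the ordering the proposed implementation appears to rely on.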

Labels: GP (Gaussian Process)
