-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow Stationary covariance functions to use user defined distance functions #6965
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
… pytensor distance functions
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6965 +/- ##
==========================================
- Coverage 92.12% 87.78% -4.35%
==========================================
Files 100 100
Lines 16859 16892 +33
==========================================
- Hits 15531 14828 -703
- Misses 1328 2064 +736
|
cov2 = pm.gp.cov.Matern32(1, ls=1, square_dist=square_dist) | ||
K1 = cov1(X).eval() | ||
K2 = cov2(X, X).eval() | ||
npt.assert_allclose(K1, K2, atol=1e-5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe also test you can get something different from the default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good call
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @bwengals. I'm not sure if the PR, as it stands, actually customizes the square_dist
callable? I might have missed something, but I didn't see where it got attached to the object.
There is one theoretical thing that we need to be aware of and we could add into the docstring of square_dist
: not all distance metrics yield valid covariance kernels. There's a short summary of this here in section 3.1.3. I fell into this trap once when I tried to implement a covariance kernel using the Haversine distance on the surface of a sphere, until I later found out that was wrong.
@@ -518,6 +521,7 @@ def __init__( | |||
ls=None, | |||
ls_inv=None, | |||
active_dims: Optional[IntSequence] = None, | |||
square_dist: Optional[Callable] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't see where the square_dist
argument is being used later. Do you want to attach it to the covariance instance at the end of __init__
? As it stands now, it looks like this isn't doing anything. Furthermore, even if you attach it to the object, would that be called instead of the __class__
square_dist
accessor?
sqd = -2.0 * pt.dot(X, pt.transpose(Xs)) + ( | ||
pt.reshape(X2, (-1, 1)) + pt.reshape(Xs2, (1, -1)) | ||
) | ||
Xs = pt.mul(Xs, 1.0 / ls) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you take out the is None
case? Can't you just do Xs = X if Xs is None else Xs
and have a single branch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved it down to line 564 (and your one-liner is nicer). I moved it because I thought it'd be easier for someone making a different distance function if the inputs didn't have defaults.
cov2 = pm.gp.cov.Matern32(1, ls=1, square_dist=square_dist) | ||
K1 = cov1(X).eval() | ||
K2 = cov2(X, X).eval() | ||
npt.assert_allclose(K1, K2, atol=1e-5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with this
No.. Sorry for the sloppiness on my end! The test you and @ricardoV94 mentioned would have caught this issue. |
This PR adds an optional argument to any covariance function that inherits from
Stationary
calledsquare_dist
. New feature and existing model code shouldn't be affected.Checklist
Major / Breaking Changes
New features
Bugfixes
Documentation
Maintenance
📚 Documentation preview 📚: https://pymc--6965.org.readthedocs.build/en/6965/