-
Notifications
You must be signed in to change notification settings - Fork 69
Graph kernel consistency with other kernels #560
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Graph kernel consistency with other kernels #560
Conversation
… kernel to allow SVGP shapes and dtypes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just one suggested comment. A sign of our poor testing perhaps, but can you also just ensure that the graph kernel notebook runs identically to the current prod version?
In your revision, you may also bump the patch version in init.py.
gpjax/kernels/non_euclidean/utils.py
Outdated
| return params[tuple_indices] | ||
|
|
||
|
|
||
| def calculate_S(kernel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rename this function to something more informative and add types please?
gpjax/kernels/non_euclidean/utils.py
Outdated
| return params[tuple_indices] | ||
|
|
||
|
|
||
| def calculate_S(kernel): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def calculate_S(kernel): | |
| def calculate_heat_semigroup(kernel: GraphKernel) -> Float[Array, "N N"]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried out having a GraphKernel type but that leads to circular import errors. A way to seperate out the calculate_heat_semigroup function in utils file would be to have a base Graph class and the GraphKernel inherit the base class. Then calculate_heat_semigroup could just check for the base type. I am not sure if that is necessary at this point or if many graph kernels actually use eigenvalues. Thus I put the calculate_heat_semigroup function in the graph file itself for now.
I checked the notebook as well and it seems all is fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To bypass a circular import only for the needs of types, you can do
from __future__ import annotations
from import beartype.typing as tp
if tp.TYPE_CHECKING:
from gpjax.kernels.non_euclidean.graph import GraphKernel
Overall this looks good. Can you please bump the version in gpjax/__init__.py?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're a genius! Worked like a charm. Function is back in utils and bumped init
Checklist
uv run poe formatbefore committing.Description
This is in reference to the issue #558 . I have created another helper function to calculate S and moved it out to utils instead. I have kept the eigen computation bits so that the tests don't fail.
Not sure if this is the preferred way. I did think of moving the S calculations to the
_call_function itself but it maybe required for other kernels in the future?Issue Number: #558