-
Notifications
You must be signed in to change notification settings - Fork 360
Isolate dtypes #1009
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Isolate dtypes #1009
Conversation
| layout=cueq_config.layout, | ||
| shared_weights=shared_weights, | ||
| internal_weights=internal_weights, | ||
| dtype=torch.get_default_dtype(), | ||
| math_dtype=torch.get_default_dtype(), | ||
| ) | ||
| if ( | ||
| OEQ_AVAILABLE |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pretty sure the dtype is necessary here, as the kernels might be different depending on dtype
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PASSIONLab/OpenEquivariance#118 are implementing necessary features now hence why I left it for oeq, I did just assume that it would already be in cueq so removed, will add a test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this does turn out to be problematic which is maybe poor design from cuequivariance (especially given PASSIONLab/OpenEquivariance#118 was responsive and implemented a working .to method that does update the kernels according to maintainers)
|
This is blocked by getting a solution to NVIDIA/cuEquivariance#124 and then updating the values in various tests that assert against seeded answers. |
|
|
f6fd344 to
6685858
Compare
fix cueq and add test for Dielectric MACE
minor fixes
8dee8d3 to
c934aff
Compare
This is a fairly hefty set of changes that remove the use of global dtype setting from MACE codebase. This is done by adding custom
tomethods to change the dtypes of buffers.dtypeis also made into a kwarg for all dataset objects.All tests pass apart from those that have expected values that I believe are seed dependent and something must have been affected here that changes those values.
A bunch of minor code consistency changes are also included in this PR. Notable things are using
tmp_pathfixtures rather than manually creating and cleaning tmp folders, using the cli commands for the cli tests, adding some directories to the tests to make it more clear which tests are cli tests and which are core, adding concurrency to tests to reduce ci minute waste on PRs pushed in quick succession, using consistent docstring styles, usingnpt.assert_allcloseconsistentlyThe one major remaining source of confusion is in the
test_calculators.pywhere using float32 causes everything to return very different/wrong values (perhaps this failure is also related to what we see in TorchSim/torch-sim#93?).Related: #877, #328