Combining torch.compile with cueq #1175
Conversation
@ilyes319, there is still more work to do for compilation, but I wanted to check that the choices for the methods here look OK to you.
Mmm, good question @mariogeiger, are these the intended settings?
mariogeiger left a comment
Did I answer the questions, @ilyes319?
    shared_weights=shared_weights,
    internal_weights=internal_weights,
    use_fallback=True,
    method="naive",
Good. Just a question: is this for the skipTP? Is one of the inputs always a one-hot vector? If so, why not index the weights instead of contracting with a one-hot?
Yeah, indeed. Does cueq provide the option to pass the weights in the forward pass?
Hi @ilyes319, since 0.6.0 we support indexing in the linear: you just need to provide the number of weight_classes at init and pass the weight_indices in the forward.
If your indices are sorted you can also use the experimental "indexed_linear" method, but even in the "naive" case this would first index into the weights and then compute a linear for each element, which should still be faster.
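The point about indexing versus one-hot contraction can be sketched in plain NumPy. The shapes and names below are illustrative, not the cueq API: both routes compute the same per-element linear map, but direct indexing skips the contraction over the class dimension entirely.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, d_in, d_out = 4, 8, 3

# One weight matrix per class (e.g. per chemical element in MACE).
weights = rng.normal(size=(num_classes, d_in, d_out))
x = rng.normal(size=(5, d_in))               # batch of 5 inputs
idx = rng.integers(0, num_classes, size=5)   # class index per input

# Route 1: contract with a one-hot selector.
one_hot = np.eye(num_classes)[idx]                   # (5, num_classes)
w_sel = np.einsum("bc,cio->bio", one_hot, weights)   # (5, d_in, d_out)
y_onehot = np.einsum("bi,bio->bo", x, w_sel)

# Route 2: index the weights directly; same result, no contraction.
y_indexed = np.einsum("bi,bio->bo", x, weights[idx])

assert np.allclose(y_onehot, y_indexed)
```

The two routes are numerically identical; indexing simply avoids materialising and contracting the one-hot tensor.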
@mariogeiger Yes, thank you!
The latest commit on this PR works in my limited testing, but really just changes the scope of the symbolic-tracing workaround. @mariogeiger, it would be good to have your take on this: the cueq torch ops perform a number of runtime checks on input tensors. This is obviously a great thing for developer UX, but it breaks the symbolic tracing that is currently used in MACE. This is a broader design choice, since runtime checks introduce data-dependent control flow that is incompatible with the static-control-flow requirement of symbolic tracing. A couple of choices are possible here:
Keen to hear thoughts or suggestions!
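A minimal illustration of why such runtime checks break symbolic tracing (a toy sketch, not the actual cueq code): during `torch.fx` tracing the inputs are `Proxy` objects, and using a condition derived from them in an `if` statement raises a `TraceError`.

```python
import torch
import torch.fx

class CheckedOp(torch.nn.Module):
    """Toy stand-in for an op that validates its inputs at runtime
    (hypothetical, for illustration only)."""
    def forward(self, x):
        # Fine in eager mode, but during fx symbolic tracing `x` is a
        # Proxy, so this comparison cannot be used as control flow.
        if x.shape[-1] != 8:
            raise ValueError("expected last dimension of size 8")
        return x * 2.0

try:
    torch.fx.symbolic_trace(CheckedOp())
    trace_failed = False
except torch.fx.proxy.TraceError:
    trace_failed = True

print("symbolic tracing failed:", trace_failed)
```

In eager mode the same check runs without issue; it is only the tracer's static view of the program that cannot represent the branch.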
Since you are using the argument "method=...", you should maybe update the version requirement in your setup.cfg ->
Hi, we were discussing this: in the past we guarded some of these checks with a wrapper. Would something like this, with the appropriate mode, work to disable the checks for compiled models? Also, can you tell me exactly what kind of compilation command you're trying to enable here? We are testing some similar ones.
At the moment, MACE is first passed through
Hey @hatemhelal, curious what the status of this PR is. Should it be merged? Is there anything else to work on? I can think of the following points to use this most effectively in MD:
@ilyes319 I think this should be OK to merge; the follow-on discussion is a tangential direction for making this more robust. Also recall that @ThomasWarford was looking to add benchmarks for this in #1184. I agree with all the follow-on points you suggest, just not sure when I might get to them, so I can't commit at the moment!
OK, I merged it for now. @ThomasWarford, tell me if you are up for trying to implement these.
I tried implementing it today and I hit errors when running it. The first was fixed; here's the first of the remaining errors: I believe #1184 should work if this is resolved, but I do still get another error with keys (which I posted there), rather than this error. This is Python 3.12, CUDA 12.6, PyTorch 2.9.0+cu126.
This PR provides a simple fix to enable the use of torch.compile in conjunction with cueq. The main change is to provide a tighter scope for the symbolic tracing that was previously applied to the NonLinearReadoutBlock, and instead just simplify the activation constructor. Also updated tests to parameterise over enabling cueq.
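A toy sketch of the narrower-scope idea described in the PR summary (names and structure are hypothetical, not MACE's actual NonLinearReadoutBlock): only the small activation sub-module is symbolically traced, so compilation never has to trace through ops that carry runtime checks.

```python
import torch
import torch.fx

class Readout(torch.nn.Module):
    # Hypothetical readout block: the linear layer stands in for a cueq
    # op with runtime checks and is NOT fx-traced; only the activation
    # sub-module is symbolically traced.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.activation = torch.fx.symbolic_trace(torch.nn.SiLU())

    def forward(self, x):
        return self.activation(self.linear(x))

# backend="eager" keeps the example portable (no codegen backend needed).
model = torch.compile(Readout(), backend="eager")
y = model(torch.randn(2, 8))
```

Scoping the trace this way keeps the fx graph confined to code that has no data-dependent control flow, while the rest of the module is handled by torch.compile as-is.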