Adding CUDA graph capture feature for supervised learning. #98
base: master
Conversation
Signed-off-by: Josh Romero <joshr@nvidia.com>
Force-pushed from 434417a to 76b058e
/build_and_test
🚀 Build workflow triggered!
✅ Build workflow passed!
azrael417
left a comment
Looks good, thanks a lot. Much cleaner, but I still have a few comments.
private:
  // Input signature for validating consistent inputs
  struct InputSignature {
This structure is the same for all the states; wouldn't it be better to move it outside?
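A minimal sketch of the suggested refactor, assuming the struct tracks tensor shapes and data pointers (the member names below are illustrative, not taken from the PR):

```cpp
// Hypothetical sketch: hoist InputSignature to namespace scope so all state
// types share a single definition. Members are assumptions for illustration.
#include <cstdint>
#include <vector>

namespace torchfort {

struct InputSignature {
  std::vector<std::vector<int64_t>> shapes; // per-tensor shapes must not change between iterations
  std::vector<const void*> data_ptrs;       // CUDA graphs require stable input addresses

  bool operator==(const InputSignature& other) const {
    return shapes == other.shapes && data_ptrs == other.data_ptrs;
  }
};

} // namespace torchfort
```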
namespace torchfort {

// Action to take for current iteration
enum class GraphAction {
Do we need this outside the ENABLE_GPU context?
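For reference, a sketch of how it could be fenced off, assuming ENABLE_GPU is the preprocessor guard used for CUDA-only code in this codebase (only CAPTURE and REPLAY are mentioned in this PR; the EAGER member is a placeholder):

```cpp
namespace torchfort {

#ifdef ENABLE_GPU
// Action to take for current iteration; only meaningful with CUDA support.
enum class GraphAction {
  CAPTURE, // record the graph this iteration
  REPLAY,  // launch the previously captured graph
  EAGER    // placeholder: fall back to normal eager execution
};
#endif // ENABLE_GPU

} // namespace torchfort
```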
void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); }

// Get static loss (valid after CAPTURE or REPLAY)
const torch::Tensor& get_loss() const { return static_loss_; }
Shall we add asserts to return an error when not fully initialized?
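A sketch of such a guard, assuming a boolean member (hypothetical name captured_) is set once capture completes, and that the project exposes an error macro for invalid usage (THROW_INVALID_USAGE here is an assumption):

```cpp
// Hypothetical guard: reject use of the graph before capture has completed.
// captured_ and THROW_INVALID_USAGE are assumptions for illustration.
void launch(cudaStream_t stream) {
  if (!captured_) {
    THROW_INVALID_USAGE("launch() called before graph capture completed");
  }
  CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream));
}

// Get static loss (valid only after CAPTURE or REPLAY)
const torch::Tensor& get_loss() const {
  if (!captured_) {
    THROW_INVALID_USAGE("get_loss() called before graph capture completed");
  }
  return static_loss_;
}
```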
}

// Extract loss value
*loss_val = loss.item<float>();
Does this work? .item copies the loss back to the CPU, and allreduce needs a tensor, right? Couldn't we just clone the loss tensor, call allreduce on it, and then extract the scalar with .item?
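The suggested flow, sketched with an assumed comm object exposing an on-device allreduce over torch::Tensor (the actual communication API in the codebase may differ):

```cpp
// Sketch of the clone -> allreduce -> item() ordering; `comm` and its
// allreduce signature are assumptions for illustration.
torch::Tensor loss_avg = loss.clone();        // keep the static/captured loss buffer intact
comm->allreduce(loss_avg, /*average=*/true);  // reduce across ranks while still on the device
*loss_val = loss_avg.item<float>();           // single device-to-host copy of the reduced scalar
```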
This PR adds CUDA graphs support for supervised learning problems. The feature is enabled via a new general configuration entry, enable_cuda_graphs; see the updated documentation for details, and the sketch below.
Since we are targeting high-performance use cases, this functionality is intentionally minimal in terms of features. In particular, we do not maintain internal static entry points to the captured graphs, allow graph recapture for dynamic shapes, etc. Instead, we expect users to provide consistent input data (memory locations, shapes) compatible with the CUDA graphs operating model.
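For illustration, the entry would sit in the YAML configuration roughly as follows (a sketch; only enable_cuda_graphs is introduced by this PR, the surrounding layout is an assumption):

```yaml
general:
  enable_cuda_graphs: true
```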
Marking this as a draft for now as I still need to implement some tests.