Conversation

@romerojosh (Collaborator)
This PR adds CUDA graphs support for supervised learning problems. The feature is enabled via a new general configuration entry, enable_cuda_graphs; see the updated documentation.

Since we are targeting high-performance use cases, this functionality is kept fairly minimal in terms of features. In particular, we do not maintain internal static entry points to the captured graphs, support graph recapture for dynamic shapes, etc. Instead, we expect users to provide consistent input data (memory locations, shapes) compatible with the CUDA graphs operating model.

Marking this as a draft for now as I still need to implement some tests.

Signed-off-by: Josh Romero <joshr@nvidia.com>
@romerojosh (Collaborator, Author)

/build_and_test

@github-actions

github-actions bot commented Dec 3, 2025

🚀 Build workflow triggered! View run

@github-actions

github-actions bot commented Dec 4, 2025

✅ Build workflow passed! View run

Signed-off-by: Josh Romero <joshr@nvidia.com>
@azrael417 (Collaborator) left a comment

Looks good, thanks a lot. Much cleaner, but I still have a few comments.


private:
// Input signature for validating consistent inputs
struct InputSignature {

This structure is the same for all the states, wouldn't it be better to move it outside?

namespace torchfort {

// Action to take for current iteration
enum class GraphAction {

do we need this outside the ENABLE_GPU context?

void launch(cudaStream_t stream) { CHECK_CUDA(cudaGraphLaunch(graph_exec_.get(), stream)); }

// Get static loss (valid after CAPTURE or REPLAY)
const torch::Tensor& get_loss() const { return static_loss_; }

shall we add asserts to return error when not fully initialized?

}

// Extract loss value
*loss_val = loss.item<float>();

Does this work? .item() copies the loss back to the CPU, but the all-reduce needs a tensor, right? We could just clone the loss tensor, call all-reduce on it, and then extract the scalar with .item().
