Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion benchmarks/linear_programming/cuopt/run_mip.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */
Expand All @@ -12,6 +12,7 @@
#include <cuopt/linear_programming/mip/solver_solution.hpp>
#include <cuopt/linear_programming/optimization_problem.hpp>
#include <cuopt/linear_programming/solve.hpp>
#include <cuopt/utilities/user_interrupt_handler.hpp>
#include <mps_parser/parser.hpp>
#include <utilities/logger.hpp>

Expand Down Expand Up @@ -40,6 +41,14 @@

#include "initial_problem_check.hpp"

/**
 * @brief Termination callback backed by the process-wide user interrupt handler.
 *
 * check_termination() forwards the answer of
 * cuopt::user_interrupt_handler_t::termination_requested().
 */
class check_termination_callback_t : public cuopt::internals::check_termination_callback_t {
 public:
  /** @return Whatever the interrupt handler singleton currently reports. */
  bool check_termination() override
  {
    auto& interrupt_handler = cuopt::user_interrupt_handler_t::instance();
    return interrupt_handler.termination_requested();
  }
};

void merge_result_files(const std::string& out_dir,
const std::string& final_result_file,
int n_gpus,
Expand Down Expand Up @@ -197,6 +206,8 @@ int run_single_file(std::string file_path,
}
}

check_termination_callback_t termination_callback;
settings.set_mip_callback(&termination_callback);
settings.time_limit = time_limit;
settings.heuristics_only = heuristics_only;
settings.num_cpu_threads = num_cpu_threads;
Expand Down
16 changes: 15 additions & 1 deletion cpp/cuopt_cli.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */

#include <cuopt/linear_programming/mip/solver_settings.hpp>
#include <cuopt/linear_programming/optimization_problem.hpp>
#include <cuopt/linear_programming/solve.hpp>
#include <cuopt/utilities/user_interrupt_handler.hpp>
#include <mps_parser/parser.hpp>
#include <utilities/logger.hpp>

Expand All @@ -27,6 +28,14 @@

#include <cuopt/version_config.hpp>

/**
 * @brief Termination callback that consults the process-wide user interrupt handler.
 *
 * check_termination() returns the value of
 * cuopt::user_interrupt_handler_t::termination_requested() at the time of the call.
 */
class check_termination_callback_t : public cuopt::internals::check_termination_callback_t {
 public:
  /** @return The interrupt handler singleton's current termination flag. */
  bool check_termination() override
  {
    auto& interrupt_handler = cuopt::user_interrupt_handler_t::instance();
    return interrupt_handler.termination_requested();
  }
};

static char cuda_module_loading_env[] = "CUDA_MODULE_LOADING=EAGER";

/**
Expand Down Expand Up @@ -92,6 +101,11 @@ int run_single_file(const std::string& file_path,
const raft::handle_t handle_{};
cuopt::linear_programming::solver_settings_t<int, double> settings;

// Ctrl-C handler
check_termination_callback_t termination_callback;
settings.set_mip_callback(&termination_callback);
settings.set_lp_callback(&termination_callback);

try {
for (auto& [key, val] : settings_strings) {
settings.set_parameter_from_string(key, val);
Expand Down
15 changes: 14 additions & 1 deletion cpp/include/cuopt/linear_programming/pdlp/solver_settings.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */
Expand All @@ -9,6 +9,7 @@

#include <cuopt/linear_programming/constants.h>
#include <cuopt/linear_programming/pdlp/pdlp_warm_start_data.hpp>
#include <cuopt/linear_programming/utilities/internals.hpp>
#include <optional>
#include <raft/core/device_span.hpp>
#include <rmm/device_uvector.hpp>
Expand Down Expand Up @@ -170,6 +171,16 @@ class pdlp_solver_settings_t {
bool has_initial_primal_solution() const;
bool has_initial_dual_solution() const;

/**
 * @brief Get the callbacks that were added through set_lp_callback().
 * @return A copy of the stored vector of raw callback pointers, in the order they were added.
 */
const std::vector<internals::base_solution_callback_t*> get_lp_callbacks() const
{
return lp_callbacks_;
}

/**
 * @brief Append a callback to the list returned by get_lp_callbacks().
 *
 * Despite the "set" name, earlier callbacks are kept; each call adds one more entry.
 *
 * @param callback Raw pointer stored as-is: it is neither copied nor checked for nullptr here.
 *                 NOTE(review): presumably the pointee must outlive every solve that uses these
 *                 settings - confirm against the solver code.
 */
void set_lp_callback(internals::base_solution_callback_t* callback)
{
lp_callbacks_.push_back(callback);
}

struct tolerances_t {
f_t absolute_dual_tolerance = 1.0e-4;
f_t relative_dual_tolerance = 1.0e-4;
Expand Down Expand Up @@ -215,6 +226,8 @@ class pdlp_solver_settings_t {
std::atomic<int>* concurrent_halt{nullptr};
static constexpr f_t minimal_absolute_tolerance = 1.0e-12;

std::vector<internals::base_solution_callback_t*> lp_callbacks_;

private:
/** Initial primal solution */
std::shared_ptr<rmm::device_uvector<f_t>> initial_primal_solution_;
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/cuopt/linear_programming/solver_settings.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */
Expand Down Expand Up @@ -82,6 +82,8 @@ class solver_settings_t {
i_t size,
rmm::cuda_stream_view stream = rmm::cuda_stream_default);
void set_mip_callback(internals::base_solution_callback_t* callback = nullptr);
void set_lp_callback(internals::base_solution_callback_t* callback = nullptr);
const std::vector<internals::base_solution_callback_t*> get_lp_callbacks() const;
Comment on lines +85 to +86
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Add Doxygen documentation for new public API methods.

The new set_lp_callback and get_lp_callbacks methods lack documentation comments. Public APIs in headers under cpp/include/cuopt/ should include Doxygen-style documentation covering:

  • Purpose and usage
  • Parameter descriptions (e.g., callback purpose, nullptr behavior)
  • Return value semantics
  • Thread-safety considerations

As per coding guidelines for public header files.

📋 Suggested documentation template
+  /**
+   * @brief Set LP callback for monitoring solve progress and enabling user termination
+   * @param callback Pointer to callback implementation, or nullptr to clear callbacks
+   */
   void set_lp_callback(internals::base_solution_callback_t* callback = nullptr);
+  /**
+   * @brief Get all registered LP callbacks
+   * @return Vector of callback pointers currently registered
+   */
   const std::vector<internals::base_solution_callback_t*> get_lp_callbacks() const;
🤖 Prompt for AI Agents
In @cpp/include/cuopt/linear_programming/solver_settings.hpp around lines 85-86,
Add Doxygen-style comments above the public methods set_lp_callback and
get_lp_callbacks describing their purpose and usage, documenting the parameter
for set_lp_callback (that it accepts an internals::base_solution_callback_t* and
that nullptr clears/unsets the callback), documenting the return semantics for
get_lp_callbacks (returns a vector of pointers, ownership/copy/const
guarantees), and including thread-safety notes (whether these methods are safe
to call concurrently or require external synchronization); ensure comments
follow the project's Doxygen format, include @param, @return and @thread-safety
or equivalent tags, and place them directly above the declarations of
set_lp_callback and get_lp_callbacks in the header.


const pdlp_warm_start_data_view_t<i_t, f_t>& get_pdlp_warm_start_data_view() const noexcept;
const std::vector<internals::base_solution_callback_t*> get_mip_callbacks() const;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */
Expand Down Expand Up @@ -89,5 +89,18 @@ class default_set_solution_callback_t : public set_solution_callback_t {
PyObject* pyCallbackClass;
};

class default_check_termination_callback_t : public check_termination_callback_t {
public:
bool check_termination() override
{
PyObject* res = PyObject_CallMethod(this->pyCallbackClass, "check_termination", nullptr);
bool should_terminate = PyObject_IsTrue(res);
Py_DECREF(res);
return should_terminate;
}
Comment on lines +94 to +100
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Missing NULL check on Python API return value will cause crash on exception.

PyObject_CallMethod returns NULL when the Python callback raises an exception. Calling PyObject_IsTrue(NULL) is undefined behavior and Py_DECREF(NULL) will segfault.

🔎 Proposed fix
   bool check_termination() override
   {
     PyObject* res = PyObject_CallMethod(this->pyCallbackClass, "check_termination", nullptr);
+    if (res == nullptr) {
+      // Python exception occurred - treat as termination request
+      PyErr_Print();
+      return true;
+    }
     bool should_terminate = PyObject_IsTrue(res);
     Py_DECREF(res);
     return should_terminate;
   }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
bool check_termination() override
{
PyObject* res = PyObject_CallMethod(this->pyCallbackClass, "check_termination", nullptr);
bool should_terminate = PyObject_IsTrue(res);
Py_DECREF(res);
return should_terminate;
}
bool check_termination() override
{
PyObject* res = PyObject_CallMethod(this->pyCallbackClass, "check_termination", nullptr);
if (res == nullptr) {
// Python exception occurred - treat as termination request
PyErr_Print();
return true;
}
bool should_terminate = PyObject_IsTrue(res);
Py_DECREF(res);
return should_terminate;
}
🤖 Prompt for AI Agents
In @cpp/include/cuopt/linear_programming/utilities/callbacks_implems.hpp around
lines 94-100, The function check_termination calls PyObject_CallMethod on
this->pyCallbackClass and immediately uses PyObject_IsTrue and Py_DECREF without
checking for NULL; modify check_termination to test whether PyObject_CallMethod
returned NULL, handle the Python exception (e.g., call PyErr_Print() or
propagate/translate the error) and return a safe default (false) when NULL is
returned, otherwise proceed to call PyObject_IsTrue on res and Py_DECREF(res);
ensure references to pyCallbackClass, PyObject_CallMethod, PyObject_IsTrue, and
Py_DECREF are preserved and only used after the NULL check.


PyObject* pyCallbackClass;
};

} // namespace internals
} // namespace cuopt
12 changes: 10 additions & 2 deletions cpp/include/cuopt/linear_programming/utilities/internals.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */
Expand All @@ -20,7 +20,7 @@ class Callback {
virtual ~Callback() {}
};

enum class base_solution_callback_type { GET_SOLUTION, SET_SOLUTION };
enum class base_solution_callback_type { GET_SOLUTION, SET_SOLUTION, CHECK_TERMINATION };

class base_solution_callback_t : public Callback {
public:
Expand Down Expand Up @@ -56,6 +56,14 @@ class set_solution_callback_t : public base_solution_callback_t {
}
};

/**
 * @brief Callback interface polled to find out whether a solve should stop.
 *
 * Implementations override check_termination(); get_type() tags the object as a
 * CHECK_TERMINATION callback so it can be told apart from the other
 * base_solution_callback_t kinds.
 */
class check_termination_callback_t : public base_solution_callback_t {
public:
/** @return true when the implementation wants the solve to terminate. */
virtual bool check_termination() = 0;
/** @return Always base_solution_callback_type::CHECK_TERMINATION. */
base_solution_callback_type get_type() const override
{
return base_solution_callback_type::CHECK_TERMINATION;
}
};
} // namespace internals

namespace linear_programming {
Expand Down
125 changes: 125 additions & 0 deletions cpp/include/cuopt/utilities/user_interrupt_handler.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */

#pragma once

#include <unistd.h>

#include <atomic>
#include <chrono>
#include <csignal>
#include <cstddef>
#include <cstdlib>
#include <deque>
#include <functional>
#include <mutex>
#include <unordered_map>

namespace cuopt {

/**
* @brief Global singleton that handles SIGINT (Ctrl-C) and invokes registered callbacks.
*
* Components that want to respond to user interrupts register a callback via
* register_callback() and unregister via unregister_callback().
*
* Safety feature: If the user presses Ctrl-C 5 times within 5 seconds,
* the process is forcefully terminated.
*/
class user_interrupt_handler_t {
public:
/**
 * @brief Access the single handler object, constructing it on first use.
 * @return Reference to the function-local static instance.
 */
static user_interrupt_handler_t& instance()
{
  static user_interrupt_handler_t singleton;
  return singleton;
}

/**
 * @brief Report whether a SIGINT has been seen by this handler.
 * @return Current value of the flag that handle_signal() sets to true. Nothing in this
 *         class resets it, so once set it stays set.
 */
bool termination_requested() const { return terminate_signal_received_.load(); }

/**
 * @brief Add a callback that the SIGINT handler will invoke.
 *
 * The callback is stored under a fresh ID taken from an incrementing counter.
 * handle_signal() calls the stored callbacks itself, so they execute in
 * signal-handler context.
 *
 * @param callback Callable to run when an interrupt arrives.
 * @return ID to pass to unregister_callback() to remove the callback again.
 */
size_t register_callback(std::function<void()> callback)
{
  const std::lock_guard<std::mutex> guard(mutex_);
  const size_t assigned_id = next_id_;
  ++next_id_;
  callbacks_[assigned_id] = std::move(callback);
  return assigned_id;
}

/**
 * @brief Remove a callback added with register_callback().
 *
 * Erasing an ID that is not present leaves the registry unchanged.
 *
 * @param id Value previously returned by register_callback().
 */
void unregister_callback(size_t id)
{
  const std::lock_guard<std::mutex> guard(mutex_);
  callbacks_.erase(id);
}

// Non-copyable, non-movable
user_interrupt_handler_t(const user_interrupt_handler_t&) = delete;
user_interrupt_handler_t& operator=(const user_interrupt_handler_t&) = delete;
user_interrupt_handler_t(user_interrupt_handler_t&&) = delete;
user_interrupt_handler_t& operator=(user_interrupt_handler_t&&) = delete;

private:
static constexpr int force_quit_threshold = 5;
static constexpr int force_quit_window_seconds = 5;

using time_point = std::chrono::steady_clock::time_point;

/**
 * @brief Install handle_signal() as the SIGINT handler.
 *
 * Stores what std::signal returns (the handler that was installed before, or SIG_ERR
 * if installation failed) so the destructor can restore it.
 */
user_interrupt_handler_t() { previous_handler_ = std::signal(SIGINT, &handle_signal); }

/**
 * @brief Put back the SIGINT handler that was active before this object was built.
 *
 * Does nothing when the stored previous handler is SIG_ERR.
 */
~user_interrupt_handler_t()
{
  if (previous_handler_ == SIG_ERR) { return; }
  std::signal(SIGINT, previous_handler_);
}

/// Number of SIGINTs seen so far; also selects the ring slot for the next timestamp.
std::atomic<unsigned> interrupt_seq_{0};

/// steady_clock tick counts of the most recent interrupts, kept in a fixed-size ring so the
/// signal handler never has to allocate or take a lock to maintain the force-quit window.
std::atomic<std::chrono::steady_clock::rep> recent_interrupts_[force_quit_threshold]{};

/// Write a NUL-terminated string to stderr with write(2), which is async-signal-safe.
static void write_stderr(const char* text) noexcept
{
  std::size_t length = 0;
  while (text[length] != '\0') {
    ++length;
  }
  // Nothing useful can be done about a failed write from inside a signal handler.
  [[maybe_unused]] const auto written = ::write(STDERR_FILENO, text, length);
}

/// Write an int in decimal to stderr without going through stdio.
static void write_stderr_int(int value) noexcept
{
  char buffer[12];  // "-2147483648" plus the terminating NUL
  std::size_t pos = sizeof(buffer);
  buffer[--pos]   = '\0';
  unsigned magnitude =
    value < 0 ? 0u - static_cast<unsigned>(value) : static_cast<unsigned>(value);
  do {
    buffer[--pos] = static_cast<char>('0' + magnitude % 10u);
    magnitude /= 10u;
  } while (magnitude != 0u);
  if (value < 0) { buffer[--pos] = '-'; }
  write_stderr(buffer + pos);
}

/**
 * @brief SIGINT handler: flag the termination request, enforce force-quit, notify callbacks.
 *
 * Steps, in order:
 *  1. Set the termination flag before anything that could block, so code polling
 *     termination_requested() always sees the request.
 *  2. Record the interrupt time in a lock-free ring and count how many of the last
 *     `force_quit_threshold` interrupts fall inside the last `force_quit_window_seconds`.
 *     If all of them do, print a message and terminate with std::_Exit(128 + SIGINT).
 *  3. Invoke the registered callbacks, but only if the registry mutex can be taken without
 *     blocking. A blocking lock would deadlock when the signal interrupts a thread that is
 *     inside register_callback()/unregister_callback(). If the mutex is busy the callbacks
 *     are skipped for this interrupt; the flag from step 1 is still set.
 *  4. Print how many more interrupts are needed to force quit.
 *
 * Messages are emitted with write(2) instead of stdio, and no memory is allocated.
 *
 * NOTE(review): step 3 is best effort. Running std::function callbacks and touching a
 * std::mutex from a signal handler is still outside what the C++ standard guarantees;
 * callbacks must restrict themselves to async-signal-safe work. A dedicated dispatcher
 * thread (self-pipe / sigwait) would be the complete fix.
 */
static void handle_signal(int /*sig*/)
{
  auto& self = instance();

  self.terminate_signal_received_.store(true);

  using clock             = std::chrono::steady_clock;
  const auto now          = clock::now();
  const auto cutoff       = now - std::chrono::seconds(force_quit_window_seconds);
  const auto now_ticks    = now.time_since_epoch().count();
  const auto cutoff_ticks = cutoff.time_since_epoch().count();

  constexpr unsigned ring_size = static_cast<unsigned>(force_quit_threshold);
  const unsigned seq           = self.interrupt_seq_.fetch_add(1u);
  self.recent_interrupts_[seq % ring_size].store(now_ticks);

  // Only slots that have been written at least once hold real timestamps.
  const unsigned recorded = seq < ring_size ? seq + 1u : ring_size;
  int in_window           = 0;
  for (unsigned slot = 0; slot < recorded; ++slot) {
    if (self.recent_interrupts_[slot].load() >= cutoff_ticks) { ++in_window; }
  }

  // Force quit if too many interrupts arrived within the window.
  if (in_window >= force_quit_threshold) {
    write_stderr("Force quit: ");
    write_stderr_int(force_quit_threshold);
    write_stderr(" interrupts in ");
    write_stderr_int(force_quit_window_seconds);
    write_stderr(" seconds.\n");
    std::_Exit(128 + SIGINT);
  }

  {
    // Never block here: the interrupted thread may be the one holding mutex_.
    std::unique_lock<std::mutex> registry_lock(self.mutex_, std::try_to_lock);
    if (registry_lock.owns_lock()) {
      for (const auto& entry : self.callbacks_) {
        entry.second();
      }
    }
  }

  write_stderr("Interrupt received. Stopping solver... (press Ctrl-C ");
  write_stderr_int(force_quit_threshold - in_window);
  write_stderr(" more time(s) to force quit)\n");
}
Comment on lines +81 to +115
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Locking a mutex and invoking callbacks inside a signal handler is undefined behavior.

Signal handlers have strict requirements on what functions can be called safely. Per POSIX, only async-signal-safe functions are permitted in signal handlers. std::mutex::lock(), std::mutex::unlock(), and arbitrary user callbacks are not async-signal-safe.

This can cause deadlocks if SIGINT arrives while mutex_ is already held (e.g., during register_callback), or undefined behavior in other scenarios.

Consider one of these safer approaches:

  1. Use std::atomic_flag or sig_atomic_t to set a flag in the handler, then check it from normal code paths
  2. Use a self-pipe trick or signalfd (Linux) to handle signals in a dedicated thread
  3. At minimum, use std::atomic<bool> for terminate_signal_received_ (already done) and check it from solver loops rather than invoking callbacks from the handler
🔎 Suggested safer pattern
  static void handle_signal(int /*sig*/)
  {
    auto& self = instance();
-   std::lock_guard<std::mutex> lock(self.mutex_);
-
-   self.terminate_signal_received_ = true;
-
-   auto now    = std::chrono::steady_clock::now();
-   auto cutoff = now - std::chrono::seconds(force_quit_window_seconds);
-
-   // Remove timestamps older than the window
-   while (!self.interrupt_times_.empty() && self.interrupt_times_.front() < cutoff) {
-     self.interrupt_times_.pop_front();
-   }
-   self.interrupt_times_.push_back(now);
-
-   // Force quit if too many interrupts in the window
-   if (static_cast<int>(self.interrupt_times_.size()) >= force_quit_threshold) {
-     fprintf(stderr,
-             "Force quit: %d interrupts in %d seconds.",
-             force_quit_threshold,
-             force_quit_window_seconds);
-     std::_Exit(128 + SIGINT);
-   }
-
-   // Invoke all registered callbacks
-   for (const auto& [id, callback] : self.callbacks_) {
-     callback();
-   }
-
-   auto remaining = force_quit_threshold - static_cast<int>(self.interrupt_times_.size());
-   fprintf(stderr,
-           "Interrupt received. Stopping solver... (press Ctrl-C %d more time(s) to force quit)",
-           remaining);
+   // Only use async-signal-safe operations here
+   self.terminate_signal_received_.store(true, std::memory_order_relaxed);
+   
+   // Increment interrupt count atomically for force-quit detection
+   int count = self.interrupt_count_.fetch_add(1, std::memory_order_relaxed) + 1;
+   
+   if (count >= force_quit_threshold) {
+     // _Exit is async-signal-safe
+     std::_Exit(128 + SIGINT);
+   }
+   
+   // Use write() instead of fprintf - it's async-signal-safe
+   const char msg[] = "Interrupt received. Stopping solver...\n";
+   write(STDERR_FILENO, msg, sizeof(msg) - 1);
  }
+ 
+ // Add atomic counter for simpler force-quit detection
+ std::atomic<int> interrupt_count_{0};

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In @cpp/include/cuopt/utilities/user_interrupt_handler.hpp around lines 81-115,
The signal handler handle_signal must not take mutex_ or call callbacks_; change
it to only perform async-signal-safe operations: set the existing
std::atomic<bool> terminate_signal_received_ and atomically increment a new
std::atomic<int> interrupt_count_ (or write a single byte to a self-pipe /
pipe_fds to notify the main thread) and return; remove usage of
interrupt_times_, callbacks_, std::lock_guard<std::mutex> and any chrono
operations from handle_signal. Add a new public method like
process_pending_signals() (called from solver/main loop) that acquires mutex_,
updates interrupt_times_, checks force_quit_threshold/window, logs and calls
callbacks_ and calls std::_Exit if needed; or have the main thread read the
self-pipe and perform the same locked processing there. Update any declarations
(add interrupt_count_ or pipe fds and the new processing method) and ensure all
non-async-signal-safe work (locking, timestamping, invoking callbacks_) happens
only in process_pending_signals() on the main thread.


std::mutex mutex_;
std::unordered_map<size_t, std::function<void()>> callbacks_;
std::atomic<bool> terminate_signal_received_{false};
size_t next_id_{0};
std::deque<time_point> interrupt_times_;
void (*previous_handler_)(int) = SIG_ERR;
};

} // namespace cuopt
12 changes: 6 additions & 6 deletions cpp/src/dual_simplex/barrier.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/* clang-format off */
/*
* SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
/* clang-format on */
Expand Down Expand Up @@ -3688,13 +3688,13 @@ lp_status_t barrier_solver_t<i_t, f_t>::solve(f_t start_time,
data.cusparse_y_residual_ = data.cusparse_view_.create_vector(data.d_y_residual_);
data.restrict_u_.resize(num_upper_bounds);

if (toc(start_time) > settings.time_limit) {
if (settings.check_termination(start_time)) {
settings.log.printf("Barrier time limit exceeded\n");
return lp_status_t::TIME_LIMIT;
}

i_t initial_status = initial_point(data);
if (toc(start_time) > settings.time_limit) {
if (settings.check_termination(start_time)) {
settings.log.printf("Barrier time limit exceeded\n");
return lp_status_t::TIME_LIMIT;
}
Expand Down Expand Up @@ -3793,7 +3793,7 @@ lp_status_t barrier_solver_t<i_t, f_t>::solve(f_t start_time,
while (iter < iteration_limit) {
raft::common::nvtx::range fun_scope("Barrier: iteration");

if (toc(start_time) > settings.time_limit) {
if (settings.check_termination(start_time)) {
settings.log.printf("Barrier time limit exceeded\n");
return lp_status_t::TIME_LIMIT;
}
Expand Down Expand Up @@ -3829,7 +3829,7 @@ lp_status_t barrier_solver_t<i_t, f_t>::solve(f_t start_time,
relative_complementarity_residual,
solution);
}
if (toc(start_time) > settings.time_limit) {
if (settings.check_termination(start_time)) {
settings.log.printf("Barrier time limit exceeded\n");
return lp_status_t::TIME_LIMIT;
}
Expand Down Expand Up @@ -3869,7 +3869,7 @@ lp_status_t barrier_solver_t<i_t, f_t>::solve(f_t start_time,
}
data.has_factorization = false;
data.has_solve_info = false;
if (toc(start_time) > settings.time_limit) {
if (settings.check_termination(start_time)) {
settings.log.printf("Barrier time limit exceeded\n");
return lp_status_t::TIME_LIMIT;
}
Expand Down
Loading