Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1327,10 +1327,44 @@ struct ggml_backend_cuda_context {
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

std::unique_ptr<ggml_cuda_graph> cuda_graph;

int curr_stream_no = 0;

#ifdef USE_CUDA_GRAPH
// Map from first_node_ptr to cuda_graph - allows multiple graphs per context
// when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;

// Look up (or lazily create) the per-graph CUDA graph state keyed by the
// first node pointer of the ggml graph. Returns a non-owning pointer; the
// map retains ownership for the lifetime of the context.
ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
    // try_emplace does a single hash lookup: it default-constructs an empty
    // unique_ptr only when the key is absent (the original find + operator[]
    // path performed three lookups on a miss).
    auto [it, inserted] = cuda_graphs.try_emplace(first_node_ptr);
    if (inserted) {
        it->second = std::make_unique<ggml_cuda_graph>();
    }
    return it->second.get();
}

// Returns true if at least one CUDA graph registered in this context is
// enabled. Used by kernels that only need to know whether graphs are in use,
// without access to the specific graph key.
bool any_cuda_graph_enabled() const {
    for (auto it = cuda_graphs.begin(); it != cuda_graphs.end(); ++it) {
        const auto & graph = it->second;
        if (graph && graph->is_enabled()) {
            return true;
        }
    }
    return false;
}

// Returns true if at least one CUDA graph in this context has already been
// instantiated (i.e. holds a non-null executable graph instance).
bool any_cuda_graph_has_instance() const {
    bool found = false;
    for (const auto & entry : cuda_graphs) {
        const auto & graph = entry.second;
        if (graph != nullptr && graph->instance != nullptr) {
            found = true;
            break;
        }
    }
    return found;
}
#endif // USE_CUDA_GRAPH

explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
Expand Down
95 changes: 57 additions & 38 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2969,56 +2969,64 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
return true;
}

// The address of the first node of the ggml graph serves as the lookup key
// for the per-graph CUDA graph state held by the backend context.
static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
    const void * key = cgraph->nodes[0];
    return key;
}

static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {

bool res = false;

if (cuda_ctx->cuda_graph->instance == nullptr) {
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);

if (graph->instance == nullptr) {
res = true;
}

// Check if the graph size has changed
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
res = true;
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
}

// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool props_match = true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
}
if (!props_match) {
res = true;
}
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
}

for (int i = 0; i < cgraph->n_leafs; i++) {
bool props_match= true;
bool props_match = true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
}
if (!props_match) {
res = true;
}
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
}

return res;
}

static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);

#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
#else
cudaGraphNode_t errorNode;
cudaGraphExecUpdateResult result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000

if (stat == cudaErrorGraphExecUpdateFailure) {
Expand All @@ -3029,14 +3037,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clear error and re-instantiate
(void)cudaGetLastError();
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
cuda_ctx->cuda_graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
} else {
GGML_ASSERT(stat == cudaSuccess);
}
}
#endif
#endif // USE_CUDA_GRAPH

static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
const ggml_tensor * view,
Expand Down Expand Up @@ -3241,7 +3249,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}

static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
bool graph_evaluated_or_captured = false;

// flag used to determine whether it is an integrated_gpu
Expand Down Expand Up @@ -3695,13 +3703,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
}

#ifdef USE_CUDA_GRAPH
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
if (cuda_ctx->cuda_graph->graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
cuda_ctx->cuda_graph->graph = nullptr;
if (graph->graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(graph->graph));
graph->graph = nullptr;
}

CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
graph_evaluated_or_captured = true; // CUDA graph has been captured

std::lock_guard<std::mutex> lock(ggml_cuda_lock);
Expand All @@ -3714,40 +3723,39 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
}

if (use_cuda_graph) {
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (graph->instance == nullptr) { // Create executable graph from captured graph.
CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
}
if (cuda_graph_update_required) { // Update graph executable
ggml_cuda_graph_update_executable(cuda_ctx);
ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
}
// Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
#else
graph_evaluated_or_captured = true;
#endif // USE_CUDA_GRAPH
}
}

static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {

#ifdef USE_CUDA_GRAPH
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);

if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}

if (cuda_ctx->cuda_graph->graph == nullptr) {
if (graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) {
if (!graph->disable_due_to_gpu_arch) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
}
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
graph->disable_due_to_gpu_arch = true;
}
}

return cuda_ctx->cuda_graph->is_enabled();
return graph->is_enabled();
#else
GGML_UNUSED(cuda_ctx);
GGML_UNUSED(graph_key);
return false;
#endif // USE_CUDA_GRAPH
}
Expand All @@ -3759,15 +3767,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,

bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
const void * graph_key = nullptr;

#ifdef USE_CUDA_GRAPH
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
graph_key = ggml_cuda_graph_get_key(cgraph);

use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);

if (cuda_ctx->cuda_graph->is_enabled()) {
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (graph->is_enabled()) {
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph);

cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
graph->record_update(use_cuda_graph, cuda_graph_update_required);
}
#endif // USE_CUDA_GRAPH

Expand All @@ -3781,7 +3793,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}

ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);

return GGML_STATUS_SUCCESS;
}
Expand Down Expand Up @@ -3814,7 +3826,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;

const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
#ifdef USE_CUDA_GRAPH
const void * graph_key = ggml_cuda_graph_get_key(cgraph);
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
#else
const bool use_cuda_graph = false;
GGML_UNUSED(cuda_ctx);
GGML_UNUSED(cgraph);
#endif

static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT");
Expand Down
17 changes: 9 additions & 8 deletions ggml/src/ggml-cuda/mean.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
#endif // USE_CUDA_GRAPH
if ((nrows == 1) &&
#ifdef USE_CUDA_GRAPH
// CUDA_GRAPHS_DISABLED
((ncols > 65536) &&
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled())) ||
// CUDA_GRAPHS ENABLED
((ncols > 32768) &&
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled()))) {
// Determine if CUDA graphs are effectively disabled for this context
// (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
(((ncols > 65536) &&
(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
ctx.any_cuda_graph_enabled())) ||
// CUDA graphs are enabled - use lower threshold
((ncols > 32768) &&
!(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
ctx.any_cuda_graph_enabled())))) {
#else
(ncols > 65536)) {
#endif // USE_CUDA_GRAPH
Expand Down
58 changes: 34 additions & 24 deletions ggml/src/ggml-hexagon/htp/flash-attn-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"

#include <assert.h>
#include <HAP_farf.h>
#include <HAP_perf.h>

#include <math.h>
#include <string.h>

Expand Down Expand Up @@ -111,7 +111,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
hvx_vec_store_u(r, 4, rsum);
}

// MAD: y (F32) += x (F16) * v (float)
// MAD: y (F32) += x (F16) * s (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
Expand Down Expand Up @@ -318,9 +318,12 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
uint32_t ic = 0;

// Process in blocks of 32 (VLEN_FP32)
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
HVX_Vector_x4 scores_x4;
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE];
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
Expand Down Expand Up @@ -356,36 +359,43 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
scores = Q6_Vsf_equals_Vqf32(scores);
}

scores_x4.v[iv] = scores;
v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max);
}

{
// 4. Online Softmax Update
HVX_Vector v_max = hvx_vec_reduce_max_f32(scores);
v_max = hvx_vec_reduce_max_f32(v_max);
float m_block = hvx_vec_get_f32(v_max);

float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;

float ms = expf(M_old - M_new);

const float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
S = S * ms;

HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));

HVX_Vector p_sum_vec = hvx_vec_reduce_sum_f32(P);
float p_sum = hvx_vec_get_f32(p_sum_vec);
S += p_sum;

// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;

for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f);
for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) {
HVX_Vector scores = scores_x4.v[iv];
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted));

p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P));

// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;

for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic2 + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
}
}

p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec);
S = S * ms + hvx_vec_get_f32(p_sum_vec);
}

// Leftover
Expand Down
Loading
Loading