From 3d26a09dc7b1a7c13da57fdd26d1cf22efa81229 Mon Sep 17 00:00:00 2001 From: R Date: Tue, 6 Jan 2026 16:17:13 +0100 Subject: [PATCH 1/8] server : add thinking content blocks to Anthropic Messages API (#18551) * server : add thinking content blocks to Anthropic Messages API Add support for returning reasoning/thinking content in Anthropic API responses when using models with --reasoning-format deepseek and the thinking parameter enabled. - Non-streaming: adds thinking block before text in content array - Streaming: emits thinking_delta events with correct block indices - Partial streaming: tracks reasoning state across chunks via anthropic_has_reasoning member variable Tested with bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF model. * server : fix Anthropic API streaming for thinking content blocks Add signature field and fix duplicate content_block_start events in Anthropic Messages API streaming responses for reasoning models. * server: refactor Anthropic streaming state to avoid raw pointer Replace raw pointer to task_result_state with direct field copies: - Copy state fields in update() before processing chunk - Use local copies in to_json_anthropic() instead of dereferencing - Pre-compute state updates for next chunk in update() This makes the data flow clearer and avoids unsafe pointer patterns. --- tools/server/server-task.cpp | 143 ++++++++++++++++-- tools/server/server-task.h | 26 ++++ .../tests/unit/test_compat_anthropic.py | 89 +++++++++++ 3 files changed, 243 insertions(+), 15 deletions(-) diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 6d374131e3b..ed4f6546ea3 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -814,6 +814,15 @@ json server_task_result_cmpl_final::to_json_anthropic() { msg.content = content; } + // thinking block comes first (Anthropic extended thinking format) + if (!msg.reasoning_content.empty()) { + content_blocks.push_back({ + {"type", "thinking"}, + {"thinking", msg.reasoning_content}, + {"signature", ""} // empty signature for local models (no cryptographic verification) + }); + } + if (!msg.content.empty()) { content_blocks.push_back({ {"type", "text"}, @@ -862,20 +871,57 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; } - bool has_text = !oaicompat_msg.content.empty(); + bool has_thinking = !oaicompat_msg.reasoning_content.empty(); + bool has_text = !oaicompat_msg.content.empty(); size_t num_tool_calls = oaicompat_msg.tool_calls.size(); - bool text_block_started = false; + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + size_t text_block_index = has_thinking ? 1 : 0; + + bool thinking_block_started = false; + bool text_block_started = false; std::unordered_set tool_calls_started; for (const auto & diff : oaicompat_msg_diffs) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", thinking_block_index}, + {"content_block", { + {"type", "thinking"}, + {"thinking", ""} + }} + }} + }); + thinking_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content if (!diff.content_delta.empty()) { if (!text_block_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", text_block_index}, {"content_block", { {"type", "text"}, {"text", ""} @@ -889,7 +935,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -898,8 +944,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + diff.tool_call_index; if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; @@ -935,18 +982,42 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { } } + // close content blocks in order + if (has_thinking) { + // Anthropic API requires a signature_delta before closing thinking blocks + // We use an empty signature since we can't generate a cryptographic signature for local models + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "signature_delta"}, + {"signature", ""} + }} + }} + }); + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", thinking_block_index} + }} + }); + } + if (has_text) { events.push_back({ {"event", "content_block_stop"}, {"data", { {"type", "content_block_stop"}, - {"index", 0} + {"index", text_block_index} }} }); } for (size_t i = 0; i < num_tool_calls; i++) { - size_t content_block_index = (has_text ? 1 : 0) + i; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + i; events.push_back({ {"event", "content_block_stop"}, {"data", { @@ -1154,11 +1225,10 @@ json server_task_result_rerank::to_json() { json server_task_result_cmpl_partial::to_json_anthropic() { json events = json::array(); bool first = (n_decoded == 1); - bool text_block_started = false; + // use member variables to track block state across streaming calls + // (anthropic_thinking_block_started, anthropic_text_block_started) if (first) { - text_block_started = false; - events.push_back({ {"event", "message_start"}, {"data", { @@ -1180,28 +1250,69 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + // use anthropic_has_reasoning (set in update()) to know if ANY reasoning was generated + size_t text_block_index = anthropic_has_reasoning ? 1 : 0; + + // use local copies of streaming state (copied from task_result_state in update()) + // these reflect the state BEFORE this chunk was processed + bool thinking_started = anthropic_thinking_block_started; + bool text_started = anthropic_text_block_started; + for (const auto & diff : oaicompat_msg_diffs) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", thinking_block_index}, + {"content_block", { + {"type", "thinking"}, + {"thinking", ""} + }} + }} + }); + thinking_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content if (!diff.content_delta.empty()) { - if (!text_block_started) { + if (!text_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", text_block_index}, {"content_block", { {"type", "text"}, {"text", ""} }} }} }); - text_block_started = true; + text_started = true; } events.push_back({ {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -1210,8 +1321,10 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index; + // use anthropic_has_reasoning for thinking block count (persists across calls) + size_t content_block_index = (anthropic_has_reasoning ? 1 : 0) + (text_started ? 1 : 0) + diff.tool_call_index; if (!diff.tool_call_delta.name.empty()) { events.push_back({ diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 687770de5e9..ead14911821 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -96,6 +96,10 @@ struct task_result_state { std::string generated_text; // append new chunks of generated text here std::vector generated_tool_call_ids; + // for Anthropic API streaming: track content block state across chunks + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + task_result_state(const common_chat_syntax & oaicompat_chat_syntax) : oaicompat_chat_syntax(oaicompat_chat_syntax) {} @@ -337,6 +341,12 @@ struct server_task_result_cmpl_partial : server_task_result { std::vector oaicompat_msg_diffs; // to be populated by update() bool is_updated = false; + // for Anthropic API: track if any reasoning content has been generated + bool anthropic_has_reasoning = false; + // Streaming state copied from task_result_state for this chunk + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + virtual bool is_stop() override { return false; // in stream mode, partial responses are not considered stop } @@ -346,6 +356,22 @@ struct server_task_result_cmpl_partial : server_task_result { virtual void update(task_result_state & state) override { is_updated = true; state.update_chat_msg(content, true, oaicompat_msg_diffs); + // track if the accumulated message has any reasoning content + anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty(); + + // Copy current state for use in to_json_anthropic() (reflects state BEFORE this chunk) + anthropic_thinking_block_started = state.anthropic_thinking_block_started; + anthropic_text_block_started = state.anthropic_text_block_started; + + // Pre-compute state updates based on diffs (for next chunk) + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.reasoning_content_delta.empty() && !state.anthropic_thinking_block_started) { + state.anthropic_thinking_block_started = true; + } + if (!diff.content_delta.empty() && !state.anthropic_text_block_started) { + state.anthropic_text_block_started = true; + } + } } json to_json_non_oaicompat(); diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index e0a003557e7..e16e0235c64 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -805,3 +805,92 @@ def test_anthropic_vs_openai_different_response_format(): assert "input_tokens" in anthropic_res.body["usage"] assert "completion_tokens" in openai_res.body["usage"] assert "output_tokens" in anthropic_res.body["usage"] + + +# Extended thinking tests with reasoning models + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [False, True]) +def test_anthropic_thinking_with_reasoning_model(stream): + """Test that thinking content blocks are properly returned for reasoning models""" + global server + server = ServerProcess() + server.model_hf_repo = "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF" + server.model_hf_file = "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf" + server.reasoning_format = "deepseek" + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 1024 + server.server_port = 8084 + server.start(timeout_seconds=600) # large model needs time to download + + if stream: + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ], + "stream": True + }) + + events = list(res) + + # should have thinking content block events + thinking_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "thinking"] + assert len(thinking_starts) > 0, "Should have thinking content_block_start event" + assert thinking_starts[0]["index"] == 0, "Thinking block should be at index 0" + + # should have thinking_delta events + thinking_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "thinking_delta"] + assert len(thinking_deltas) > 0, "Should have thinking_delta events" + + # should have signature_delta event before thinking block closes (Anthropic API requirement) + signature_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "signature_delta"] + assert len(signature_deltas) > 0, "Should have signature_delta event for thinking block" + + # should have text block after thinking + text_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "text"] + assert len(text_starts) > 0, "Should have text content_block_start event" + assert text_starts[0]["index"] == 1, "Text block should be at index 1 (after thinking)" + else: + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + content = res.body["content"] + assert len(content) >= 2, "Should have at least thinking and text blocks" + + # first block should be thinking + thinking_blocks = [b for b in content if b.get("type") == "thinking"] + assert len(thinking_blocks) > 0, "Should have thinking content block" + assert "thinking" in thinking_blocks[0], "Thinking block should have 'thinking' field" + assert len(thinking_blocks[0]["thinking"]) > 0, "Thinking content should not be empty" + assert "signature" in thinking_blocks[0], "Thinking block should have 'signature' field (Anthropic API requirement)" + + # should also have text block + text_blocks = [b for b in content if b.get("type") == "text"] + assert len(text_blocks) > 0, "Should have text content block" From 968929528c6a05e10249366fbe5f0330ad9af678 Mon Sep 17 00:00:00 2001 From: Beinsezii <39478211+Beinsezii@users.noreply.github.com> Date: Tue, 6 Jan 2026 07:26:07 -0800 Subject: [PATCH 2/8] mmq.cu: tune mmq/rocblas switching for RDNA (#18537) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Patch perf regression for mmq kernels in ROCm recover performance regression for https://github.com/ggml-org/llama.cpp/issues/17917 * add n_experts branch like the cdna path * mmq.cu: tune mmq/wmma switching for RDNA * mmq.cu: move amd wmma mmq/wmma switching behind IS_RDNA3 * Update ggml/src/ggml-cuda/mmq.cu Co-authored-by: Johannes Gäßler --------- Co-authored-by: Jiacheng (Jason) Chen <76919340+jiachengjason@users.noreply.github.com> Co-authored-by: jiachengjason Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/mmq.cu | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 85692d45430..ceb95758d20 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -333,6 +333,28 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } if (amd_wmma_available(cc)) { + // RDNA 4 is consistently worse on rocblas + // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301 + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + // High expert counts almost always better on MMQ + // due to a large amount of graph splits + // https://github.com/ggml-org/llama.cpp/pull/18202 + if (n_experts >= 64) { + return true; + } + + switch (type) { + // These quants are really bad on MMQ + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q6_K: + // These quants are usually worse but not always + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + return ne11 <= 128; + default: + return true; + } + } return true; } From 090b137e56a80b189dbced7d31e637951f3e123f Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Tue, 6 Jan 2026 23:48:45 +0800 Subject: [PATCH 3/8] ggml-cuda: refactor cuda graph usage (#18637) * ggml-cuda: refactor cuda graph usage * use is_enabled() instead of enabled --- ggml/src/ggml-cuda/common.cuh | 24 ++++-- ggml/src/ggml-cuda/ggml-cuda.cu | 138 ++++++++++++-------------------- ggml/src/ggml-cuda/mean.cu | 6 +- 3 files changed, 72 insertions(+), 96 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 995b774c207..9516d8ec8f9 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1036,7 +1036,7 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_graph_node_properties { +struct ggml_cuda_graph_node_properties { void * node_address; ggml_op node_op; int64_t ne[GGML_MAX_DIMS]; @@ -1061,11 +1061,25 @@ struct ggml_cuda_graph { std::vector nodes; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; - bool cuda_graphs_enabled = false; - std::vector ggml_graph_properties; - std::vector extraneous_srcs_properties; + std::vector props; + + void record_update(bool use_graph, bool update_required) { + if (use_graph && update_required) { + number_consecutive_updates++; + } else { + number_consecutive_updates = 0; + } + if (number_consecutive_updates >= 4) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); + disable_due_to_too_many_updates = true; + } + } + + bool is_enabled() const { + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + } #endif }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 75269170c34..bac69cdd1c8 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2853,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility(ggml_cgraph * cgraph, - bool use_cuda_graph) { +static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { + bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; @@ -2915,41 +2915,41 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, return use_cuda_graph; } -static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - graph_node_properties->node_address = node->data; - graph_node_properties->node_op = node->op; +static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { + props->node_address = node->data; + props->node_op = node->op; for (int i = 0; i < GGML_MAX_DIMS; i++) { - graph_node_properties->ne[i] = node->ne[i]; - graph_node_properties->nb[i] = node->nb[i]; + props->ne[i] = node->ne[i]; + props->nb[i] = node->nb[i]; } for (int i = 0; i < GGML_MAX_SRC; i++) { - graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } - memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); + memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); } -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - if (node->data != graph_node_properties->node_address && +static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { + if (node->data != props->node_address && node->op != GGML_OP_VIEW) { return false; } - if (node->op != graph_node_properties->node_op) { + if (node->op != props->node_op) { return false; } for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != graph_node_properties->ne[i]) { + if (node->ne[i] != props->ne[i]) { return false; } - if (node->nb[i] != graph_node_properties->nb[i]) { + if (node->nb[i] != props->nb[i]) { return false; } } for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && + node->src[i]->data != props->src_address[i] && node->op != GGML_OP_VIEW ) { return false; @@ -2957,56 +2957,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - bool cuda_graph_update_required = false; + bool res = false; if (cuda_ctx->cuda_graph->instance == nullptr) { - cuda_graph_update_required = true; + res = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { - cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes + cgraph->n_leafs); + if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + res = true; + cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + bool props_match = true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); } for (int i = 0; i < cgraph->n_leafs; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]); + bool props_match= true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->leafs[i], &cuda_ctx->cuda_graph->ggml_graph_properties[cgraph->n_nodes + i]); + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); } - return cuda_graph_update_required; + return res; } -static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; @@ -3237,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { + bool graph_evaluated_or_captured = false; + // flag used to determine whether it is an integrated_gpu - const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); bool is_concurrent_event_active = false; @@ -3710,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - update_cuda_graph_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx); } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); @@ -3720,43 +3720,25 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } -static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) { +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { #ifdef USE_CUDA_GRAPH - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - // Objects required for CUDA Graph if (cuda_ctx->cuda_graph == nullptr) { cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); } - bool use_cuda_graph = true; - if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; -#ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); -#endif } } - // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, - // or previous graph capture failure. - // Also disable for multi-gpu for now. TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->cuda_graph->disable_due_to_gpu_arch - || cuda_ctx->cuda_graph->disable_due_to_too_many_updates - || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { - use_cuda_graph = false; - } - - cuda_ctx->cuda_graph->cuda_graphs_enabled = use_cuda_graph; + return cuda_ctx->cuda_graph->is_enabled(); #else - bool use_cuda_graph = false; + return false; #endif // USE_CUDA_GRAPH - - return use_cuda_graph; } static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { @@ -3767,30 +3749,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool use_cuda_graph = false; bool cuda_graph_update_required = false; - // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called) - // we call it here instead. #ifdef USE_CUDA_GRAPH - use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx); - - if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); - - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; - } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; - } + if (cuda_ctx->cuda_graph->is_enabled()) { + cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; - cuda_ctx->cuda_graph->cuda_graphs_enabled = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); -#endif - } + cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required); } #endif // USE_CUDA_GRAPH @@ -3804,9 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - bool graph_evaluated_or_captured = false; - - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } @@ -3839,7 +3803,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; - const bool use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); static bool enable_graph_optimization = [] { const char * env = getenv("GGML_CUDA_GRAPH_OPT"); diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 691d8dcb148..60542fc19dd 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // CUDA_GRAPHS_DISABLED ((ncols > 65536) && ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture)) || + ctx.cuda_graph->is_enabled())) || // CUDA_GRAPHS ENABLED ((ncols > 32768) && !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture))) { + ctx.cuda_graph->is_enabled()))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH From ea13cba85092fa17cae4d7bd064e5476a86ea53c Mon Sep 17 00:00:00 2001 From: Jeff Bolz Date: Tue, 6 Jan 2026 10:37:07 -0600 Subject: [PATCH 4/8] vulkan: support buffer_from_host_ptr (#18467) * vulkan: support buffer_from_host_ptr * hacky use of buffer_from_host_ptr for directio * disable buffer_from_host_ptr cap * use external memory for ggml_vk_host_malloc, revert model loader changes * disable external_memory_host for MoltenVK * take buffer memory types into account * don't use external_memory_host for ggml_vk_host_malloc --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 173 +++++++++++++++++++++++---- 1 file changed, 148 insertions(+), 25 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 502a4deebc9..3c13777b8aa 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -550,6 +550,8 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t max_buffer_size; uint64_t suballocation_block_size; + uint64_t min_imported_host_pointer_alignment; + bool external_memory_host {}; bool fp16; bool bf16; bool pipeline_robustness; @@ -2410,7 +2412,8 @@ static std::vector ggml_vk_find_memory_properties(const vk::PhysicalDe return indices; } -static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list, + void *import_ptr = nullptr) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); if (size > device->max_buffer_size) { throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); @@ -2439,6 +2442,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std nullptr, }; + vk::ExternalMemoryBufferCreateInfo external_memory_bci; + if (import_ptr) { + external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + buffer_create_info.setPNext(&external_memory_bci); + } + buf->buffer = device->device.createBuffer(buffer_create_info); vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); @@ -2453,35 +2462,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std mem_flags_info.setPNext(&mem_priority_info); } - for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { - const auto & req_flags = *it; + if (import_ptr) { + vk::MemoryHostPointerPropertiesEXT host_pointer_props; + try { + host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what()); + device->device.destroyBuffer(buf->buffer); + return {}; + } + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); + uint32_t memory_type_idx; + vk::MemoryPropertyFlags property_flags = *req_flags_list.begin(); + for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) { + if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } + if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } - if (memory_type_indices.empty()) { - continue; + vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx]; + // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed + if ((memory_type.propertyFlags & property_flags) == property_flags) { + property_flags = memory_type.propertyFlags; + break; + } + } + if (memory_type_idx == 32) { + GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n"); + device->device.destroyBuffer(buf->buffer); + return {}; } - buf->memory_property_flags = req_flags; - bool done = false; + buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags; + try { + vk::ImportMemoryHostPointerInfoEXT import_info; + import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + import_info.pHostPointer = import_ptr; + import_info.setPNext(&mem_flags_info); + buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info }); + } catch (const vk::SystemError& e) { + } + } else { + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; - for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); - done = true; - break; - } catch (const vk::SystemError& e) { - // loop and retry - // during last attempt throw the exception - if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { - device->device.destroyBuffer(buf->buffer); - throw e; + const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_indices.empty()) { + continue; + } + buf->memory_property_flags = req_flags; + + bool done = false; + + for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); + done = true; + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } } } - } - if (done) { - break; + if (done) { + break; + } } } @@ -2492,8 +2546,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->ptr = nullptr; - if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { - buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + if (import_ptr) { + buf->ptr = import_ptr; + } else { + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } } device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); @@ -4447,6 +4505,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { device->memory_priority = true; + } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) { + device->external_memory_host = true; } } @@ -4461,6 +4521,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props; props2.pNext = &props3; props3.pNext = &subgroup_props; @@ -4500,11 +4561,22 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; } + if (device->external_memory_host) { + last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props; + last_struct = (VkBaseOutStructure *)&external_memory_host_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; device->driver_id = driver_props.driverID; + if (device->driver_id == vk::DriverId::eMoltenvk) { + // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622 + // is available in the Vulkan SDK. + device->external_memory_host = false; + } + // Implementing the async backend interfaces seems broken on older Intel HW, // see https://github.com/ggml-org/llama.cpp/issues/17302. device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL || @@ -4586,6 +4658,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); @@ -4717,6 +4791,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_pipeline_executable_properties"); } + if (device->external_memory_host) { + device_extensions.push_back("VK_EXT_external_memory_host"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->pipeline_executable_properties_support = pipeline_executable_properties_support; @@ -14773,6 +14851,51 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); } +static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) { + if (!device->external_memory_host) { + return {}; + } + + uintptr_t uptr = reinterpret_cast(ptr); + if (uptr & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + if (size & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + + const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached; + + vk_buffer buf {}; + try { + buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what()); + } + + return buf; +} + +static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")"); + GGML_UNUSED(max_tensor_size); + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + + vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size); + + if (!buf) { + return {}; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name); + + ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size); + + return ret; +} + static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .get_name = */ ggml_backend_vk_device_get_name, /* .get_description = */ ggml_backend_vk_device_get_description, @@ -14782,7 +14905,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .init_backend = */ ggml_backend_vk_device_init, /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, - /* .buffer_from_host_ptr = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr, /* .supports_op = */ ggml_backend_vk_device_supports_op, /* .supports_buft = */ ggml_backend_vk_device_supports_buft, /* .offload_op = */ ggml_backend_vk_device_offload_op, From 07fbe19f1fbcfa09abca7cccc62eaf82c1567b7e Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Tue, 6 Jan 2026 17:51:08 +0100 Subject: [PATCH 5/8] arg: use CSV escape style for multiple-value args (#18643) * arg: use CSV escape style for multiple-value args * add test --- common/arg.cpp | 107 ++++++++++++++++++++++++-------------- tests/test-arg-parser.cpp | 9 ++++ 2 files changed, 76 insertions(+), 40 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index b52b3e70b78..c3610d262b3 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -854,6 +854,54 @@ bool common_arg_utils::is_autoy(const std::string & value) { return value == "auto" || value == "-1"; } +// Simple CSV parser that handles quoted fields and escaped quotes +// example: +// input: value1,"value, with, commas","value with ""escaped"" quotes",value4 +// output: [value1] [value, with, commas] [value with "escaped" quotes] [value4] +static std::vector parse_csv_row(const std::string& input) { + std::vector fields; + std::string field; + bool in_quotes = false; + + for (size_t i = 0; i < input.length(); ++i) { + char ch = input[i]; + + if (ch == '"') { + if (!in_quotes) { + // start of quoted field (only valid if at beginning of field) + if (!field.empty()) { + // quote appeared in middle of unquoted field, treat as literal + field += '"'; + } else { + in_quotes = true; // start + } + } else { + if (i + 1 < input.length() && input[i + 1] == '"') { + // escaped quote: "" + field += '"'; + ++i; // skip the next quote + } else { + in_quotes = false; // end + } + } + } else if (ch == ',') { + if (in_quotes) { + field += ','; + } else { + fields.push_back(std::move(field)); + field.clear(); + } + } else { + field += ch; + } + } + + // Add the last field + fields.push_back(std::move(field)); + + return fields; +} + common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { // per-example default params // we define here to make sure it's included in llama-gen-docs @@ -1250,7 +1298,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--in-file"}, "FNAME", "an input file (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -2002,7 +2050,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--image", "--audio"}, "FILE", "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.image.emplace_back(item); } } @@ -2259,37 +2307,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"--override-kv"}, "KEY=TYPE:VALUE,...", - "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n" + "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated values.\n" "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false", [](common_params & params, const std::string & value) { - std::vector kv_overrides; - - std::string current; - bool escaping = false; - - for (const char c : value) { - if (escaping) { - current.push_back(c); - escaping = false; - } else if (c == '\\') { - escaping = true; - } else if (c == ',') { - kv_overrides.push_back(current); - current.clear(); - } else { - current.push_back(c); - } - } - - if (escaping) { - current.push_back('\\'); - } - - kv_overrides.push_back(current); - - for (const auto & kv_override : kv_overrides) { - if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) { - throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str())); + for (const auto & item : parse_csv_row(value)) { + if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) { + throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str())); } } } @@ -2306,7 +2329,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora"}, "FNAME", "path to LoRA adapter (use comma-separated values to load multiple adapters)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.lora_adapters.push_back({ item, 1.0, "", "", nullptr }); } } @@ -2317,7 +2340,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n" "note: use comma-separated values", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("lora-scaled format: FNAME:SCALE"); @@ -2331,7 +2354,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--control-vector"}, "FNAME", "add a control vector\nnote: use comma-separated values to add multiple control vectors", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.control_vectors.push_back({ 1.0f, item, }); } } @@ -2341,7 +2364,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "add a control vector with user defined scaling SCALE\n" "note: use comma-separated values (format: FNAME:SCALE,...)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE"); @@ -2439,7 +2462,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--context-file"}, "FNAME", "file to load context from (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item, std::ios::binary); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -2675,9 +2698,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(common_arg( {"--api-key"}, "KEY", - "API key to use for authentication (default: none)", + "API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)", [](common_params & params, const std::string & value) { - params.api_keys.push_back(value); + for (const auto & key : parse_csv_row(value)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); add_opt(common_arg( @@ -2691,7 +2718,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::string key; while (std::getline(key_file, key)) { if (!key.empty()) { - params.api_keys.push_back(key); + params.api_keys.push_back(key); } } key_file.close(); @@ -2713,7 +2740,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); add_opt(common_arg( {"--chat-template-kwargs"}, "STRING", - string_format("sets additional params for the json template parser"), + "sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'", [](common_params & params, const std::string & value) { auto parsed = json::parse(value); for (const auto & item : parsed.items()) { diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 1bbb745e784..e995974a2e7 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -127,6 +127,15 @@ int main(void) { assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE)); assert(params.speculative.n_max == 123); + // multi-value args (CSV) + argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"}; + assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); + assert(params.lora_adapters.size() == 4); + assert(params.lora_adapters[0].path == "file1.gguf"); + assert(params.lora_adapters[1].path == "file2,2.gguf"); + assert(params.lora_adapters[2].path == "file3\"3\".gguf"); + assert(params.lora_adapters[3].path == "file4\".gguf"); + // skip this part on windows, because setenv is not supported #ifdef _WIN32 printf("test-arg-parser: skip on windows build\n"); From 24af22fc365ea6ef8e37875108a83658aa16fc8a Mon Sep 17 00:00:00 2001 From: Aadeshveer Singh <24b0926@iitb.ac.in> Date: Tue, 6 Jan 2026 23:54:34 +0530 Subject: [PATCH 6/8] ggml : optimize cuda ssm_scan using warp-level reduction (#18505) * ggml : optimize cuda ssm_scan using warp-level reduction * ggml : apply code review suggestions (style, const, constexpr) * ggml : add TODO regarding stride consistency --- ggml/src/ggml-cuda/ssm-scan.cu | 133 ++++++++++++--------------------- 1 file changed, 49 insertions(+), 84 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 6b424381df5..c1d4e2bc8df 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1) #endif // __clang__ // assumes as many threads as d_state -template +template __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, @@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1) const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { - const int head_idx = (blockIdx.x * splitH) / d_head; - const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); - const int seq_idx = blockIdx.y; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int warp_idx = blockIdx.x * c_factor + warp; + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); - const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); - const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); - const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); - const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); - const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); - float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; - float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); // strides across n_seq_tokens const int stride_x = src1_nb2 / sizeof(float); @@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1) const int stride_C = src5_nb2 / sizeof(float); const int stride_y = n_head * d_head; - float state[splitH]; - // for the parallel accumulation - __shared__ float stateC[splitH * d_state]; + float state[c_factor]; + float state_sum = 0.0f; #pragma unroll - for (int j = 0; j < splitH; j++) { - state[j] = s0_block[j * d_state + threadIdx.x]; + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; } for (int64_t i = 0; i < n_tok; i++) { - // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements - // TODO: only calculate B and C once per head group - // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. - float dt_soft_plus = dt_block[i * stride_dt]; - if (dt_soft_plus <= 20.0f) { - dt_soft_plus = log1pf(expf(dt_soft_plus)); - } - const float dA = expf(dt_soft_plus * A_block[0]); - const float B = B_block[i * stride_B + threadIdx.x]; - const float C = C_block[i * stride_C + threadIdx.x]; + // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here. + // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead. + const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]); - // across d_head + state_sum = 0.0f; + const float dA = expf(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; #pragma unroll - for (int j = 0; j < splitH; j++) { - const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; - - state[j] = (state[j] * dA) + (B * x_dt); - - stateC[j * d_state + threadIdx.x] = state[j] * C; + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; } - __syncthreads(); - - // parallel accumulation for stateC - // TODO: simplify - { - static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); - static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); - - // reduce until w matches the warp size - // TODO: does this work even when the physical warp size is 64? -#pragma unroll - for (int w = d_state; w > WARP_SIZE; w >>= 1) { - // (assuming there are d_state threads) -#pragma unroll - for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { - // TODO: check for bank conflicts - const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); - stateC[k] += stateC[k + (w >> 1)]; - - } - __syncthreads(); - } - - static_assert(splitH >= d_state / WARP_SIZE); + // parallel accumulation for output + state_sum = warp_reduce_sum(state_sum); -#pragma unroll - for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { - float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; - y = warp_reduce_sum(y); - - // store the above accumulations - if (threadIdx.x % WARP_SIZE == 0) { - const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); - y_block[i * stride_y + k] = y; - } - } + if (lane == 0) { + y_warp[i * stride_y] = state_sum; } } // write back the state #pragma unroll - for (int j = 0; j < splitH; j++) { - s_block[j * d_state + threadIdx.x] = state[j]; + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; } } @@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { - const int threads = 128; // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! if (src3_nb1 == sizeof(float)) { // Mamba-2 if (d_state == 128) { - GGML_ASSERT(d_state % threads == 0); - // NOTE: can be any power of two between 4 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 128><<>>( + constexpr int threads = 128; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); } else if (d_state == 256) { // Falcon-H1 - const int threads = 256; - // NOTE: can be any power of two between 8 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 256><<>>( + constexpr int threads = 256; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa } } else { // Mamba-1 + constexpr int threads = 128; GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); From 68b4d516c305325d31e698c4673b691d2a9d879f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 6 Jan 2026 20:02:30 +0100 Subject: [PATCH 7/8] llama-params-fit: fix last devices with low VRAM (#18494) --- src/llama.cpp | 66 +++++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 98fb770844c..0162ae8d58c 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -359,6 +359,11 @@ static void llama_params_fit_impl( // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: layer_fraction_t overflow_type = LAYER_FRACTION_MOE; + + uint32_t n_full() const { + assert(n_layer >= n_part); + return n_layer - n_part; + } }; const size_t ntbo = llama_max_tensor_buft_overrides(); @@ -382,7 +387,7 @@ static void llama_params_fit_impl( size_t itbo = 0; for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part; + il0 += ngl_per_device[id].n_full(); for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { if (itbo + 1 >= ntbo) { tensor_buft_overrides[itbo].pattern = nullptr; @@ -393,7 +398,7 @@ static void llama_params_fit_impl( + std::to_string(ntbo) + " is insufficient for model"); } tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = overflow_bufts[id]; + tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); itbo++; } il0 += ngl_per_device[id].n_part; @@ -468,20 +473,14 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } - std::vector overflow_bufts; // which bufts the partial layers of a device overflow to: + std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd - 1; ++id) { - overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1])); + for (size_t id = 0; id < nd; id++) { + overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); } - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); std::vector ngl_per_device(nd); std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - if (hp_nex > 0) { - for (size_t id = 0; id < nd; id++) { - ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE; - } - } // optimize the number of layers per device using the method of false position: // - ngl_per_device has 0 layers for each device, lower bound @@ -512,9 +511,6 @@ static void llama_params_fit_impl( if (mem_high[id] > targets[id]) { assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - if (hp_nex > 0 && size_t(id) == nd - 1) { - delta--; - } LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); @@ -524,7 +520,8 @@ static void llama_params_fit_impl( std::vector ngl_per_device_test = ngl_per_device; ngl_per_device_test[id].n_layer += step_size; if (hp_nex) { - ngl_per_device_test[id].n_part += step_size; + ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? + step_size - 1 : step_size; // the first layer is the output layer which must always be full } const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); @@ -573,7 +570,7 @@ static void llama_params_fit_impl( assert(id_dense_start < nd); LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start; id++) { + for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { std::vector ngl_per_device_high = ngl_per_device; for (size_t jd = id_dense_start; jd < nd; jd++) { const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; @@ -585,12 +582,8 @@ static void llama_params_fit_impl( std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part); - assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part); - assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part); - uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); step_size = std::max(step_size, uint32_t(1)); @@ -606,7 +599,7 @@ static void llama_params_fit_impl( ngl_per_device_test[id].n_layer += n_convert_jd; n_converted_test += n_convert_jd; - if (ngl_per_device_test[id_dense_start_test].n_layer > 0) { + if (ngl_per_device_test[id_dense_start_test].n_part > 0) { break; } } @@ -625,8 +618,8 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); } - delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); } } else { ngl_per_device = ngl_per_device_high; @@ -644,14 +637,19 @@ static void llama_params_fit_impl( ngl_per_device_test[id_dense_start_test].n_part--; ngl_per_device_test[id].n_layer++; ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_layer == 0) { + if (ngl_per_device_test[id_dense_start_test].n_part == 0) { id_dense_start_test++; } ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; + std::vector overflow_bufts_test = overflow_bufts; + if (id < nd - 1) { + overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); + } LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", @@ -659,9 +657,10 @@ static void llama_params_fit_impl( ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", @@ -670,9 +669,10 @@ static void llama_params_fit_impl( } else { ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", @@ -687,6 +687,14 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); } + // print info for devices that were not changed during the conversion from dense only to full layers: + for (size_t id = id_dense_start + 1; id < nd; id++) { + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LLAMA_LOG_INFO( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); } From ccbc84a5374bab7a01f68b129411772ddd8e7c79 Mon Sep 17 00:00:00 2001 From: Tarek Dakhran Date: Tue, 6 Jan 2026 21:00:29 +0100 Subject: [PATCH 8/8] mtmd: mtmd_audio_streaming_istft (#18645) Change is decoupled from https://github.com/ggml-org/llama.cpp/pull/18641. [LFM2.5-Audio-1.5B](https://huggingface.co/LiquidAI/LFM2.5-Audio-1.5B) needs streaming istft for generating output audio. * add streaming ISTFT class (`mtmd_audio_streaming_istft`) with overlap-add for audio reconstruction * replace global audio cache with per-instance cache, the model requires two independent caches, for preprocessing (audio input) and for istft (audio output). * unified templated FFT/IFFT implementation supporting both forward and inverse transforms --- tools/mtmd/mtmd-audio.cpp | 576 +++++++++++++++++++++++--------------- tools/mtmd/mtmd-audio.h | 73 +++++ 2 files changed, 431 insertions(+), 218 deletions(-) diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index e99101184b1..e8eef035ff5 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -9,207 +9,250 @@ #include #include -// most of the code here is copied from whisper.cpp +// some of the code here is copied from whisper.cpp constexpr bool DEBUG = false; -struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; - - std::vector data; -}; - -// note: this global cache is shared among all preprocessors -// if we want to use multiple preprocessors at the same time, -// we will need to enclose it in the preprocessor class in the future -static struct mtmd_audio_global_cache { - // precomputed sin/cos table for FFT - std::vector sin_vals; - std::vector cos_vals; - - // hann window - std::vector hann_window; - - // mel filter bank - mtmd_audio_mel_filters filters; - - void fill_sin_cos_table(int n) { - sin_vals.resize(n); - cos_vals.resize(n); - for (int i = 0; i < n; i++) { - double theta = (2 * M_PI * i) / n; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); - } +void mtmd_audio_cache::fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); } +} - void fill_hann_window(int length, bool periodic) { - hann_window.resize(length); - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } +void mtmd_audio_cache::fill_hann_window(int length, bool periodic) { + hann_window.resize(length); + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); } +} - // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. - // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix( - int n_mel, - int n_fft, - int sample_rate, // e.g. 16000 - float fmin = 0.0f, // e.g. 0.0 - float fmax = -1.0f, // e.g. sr/2; pass -1 for auto - bool slaney_area_norm = true, - float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code - ) { - GGML_ASSERT(n_mel > 0 && n_fft > 1); - if (fmax <= 0.0f) { - fmax = 0.5f * sample_rate; - } +void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, + float fmin, + float fmax, + bool slaney_area_norm, + float scale) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; + } - // Slaney scale (matches librosa default) - const double min_log_hz = 1000.0; - const double lin_slope = 3 / 200.; - const double min_log_mel = min_log_hz * lin_slope; - const double log_step = log(6.4) / 27.0; - auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { - return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; - }; - auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { - return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); - }; - - // infer N_fft from n_fft_bins - const double bin_hz_step = double(sample_rate) / double(n_fft); - - // mel grid: n_mel + 2 edges - const double m_lo = hz_to_mel(fmin); - const double m_hi = hz_to_mel(fmax); - std::vector mel_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); - } + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); + } - // convert to Hz - std::vector hz_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - hz_pts[i] = mel_to_hz(mel_pts[i]); - } + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } - const int n_fft_bins = n_fft / 2 + 1; - - // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { - const double f_left = hz_pts[m]; - const double f_center = hz_pts[m + 1]; - const double f_right = hz_pts[m + 2]; - - const double denom_l = std::max(1e-30, f_center - f_left); - const double denom_r = std::max(1e-30, f_right - f_center); - const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - - for (int k = 0; k < n_fft_bins; ++k) { - const double f = k * bin_hz_step; - double w = 0.0; - if (f >= f_left && f <= f_center) { - w = (f - f_left) / denom_l; - } else if (f > f_center && f <= f_right) { - w = (f_right - f) / denom_r; - } - out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + const int n_fft_bins = n_fft / 2 + 1; + + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; + + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; + + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); } + } - filters.n_mel = n_mel; - filters.n_fft = n_fft; - filters.data = std::move(out); + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); - if (DEBUG) { // debug - for (size_t i = 0; i < filters.data.size(); ++i) { - if (filters.data[i] != 0.0f) { - printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); - } + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); } } } -} g_cache; +} -// naive Discrete Fourier Transform -// input is real-valued -// output is complex-valued -static void dft(const float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); - const int sin_cos_step = n_sin_cos_vals / N; +// Unified DFT implementation for both forward and inverse transforms +// Template parameters: +// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling +// true = IDFT with exp(+2πi·k·n/N), scales by 1/N +// RealInput: true = input is real-valued (stride 1), avoids imaginary computations +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + const float scale = Inverse ? (1.0f / N) : 1.0f; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N - re += in[n] * g_cache.cos_vals[idx]; // cos(t) - im -= in[n] * g_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % n_sin_cos_vals; + float cos_val = cache.cos_vals[idx]; + float sin_val = cache.sin_vals[idx]; + + if constexpr (RealInput) { + // Real input: in_im = 0, simplifies to: + // re += in_re * cos_val + // im += sign * in_re * sin_val + float in_re = in[n]; + re += in_re * cos_val; + im += sign * in_re * sin_val; + } else { + float in_re = in[n * 2 + 0]; + float in_im = in[n * 2 + 1]; + // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i + re += in_re * cos_val - sign * in_im * sin_val; + im += sign * in_re * sin_val + in_im * cos_val; + } } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re * scale; + out[k * 2 + 1] = im * scale; } } -// Cooley-Tukey FFT -// poor man's implementation - use something better -// input is real-valued -// output is complex-valued -static void fft(float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); +// Cooley-Tukey FFT/IFFT unified implementation +// Template parameters: +// Inverse: false = FFT with exp(-2πi·k/N), no scaling +// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level +// RealInput: true = input is real-valued (stride 1) +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + if (N == 1) { out[0] = in[0]; - out[1] = 0; + if constexpr (RealInput) { + out[1] = 0.0f; + } else { + out[1] = in[1]; + } return; } const int half_N = N / 2; - if (N - half_N*2 == 1) { - dft(in, N, out); + if (N - half_N * 2 == 1) { + // Odd N: fall back to DFT + dft_impl(cache, in, N, out); return; } - float* even = in + N; - for (int i = 0; i < half_N; ++i) { - even[i]= in[2*i]; - } - float* even_fft = out + 2 * N; - fft(even, half_N, even_fft); + // Split into even and odd + if constexpr (RealInput) { + // Real input: stride is 1, copy only real values + float * even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i] = in[2 * i]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); + + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2 * i + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); + } else { + // Complex input: stride is 2, copy complex pairs + float * even = in + N * 2; + for (int i = 0; i < half_N; ++i) { + even[i * 2 + 0] = in[2 * i * 2 + 0]; + even[i * 2 + 1] = in[2 * i * 2 + 1]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); - float* odd = even; - for (int i = 0; i < half_N; ++i) { - odd[i] = in[2*i + 1]; + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0]; + odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); } - float* odd_fft = even_fft + N; - fft(odd, half_N, odd_fft); + + float * even_fft = out + 2 * N; + float * odd_fft = even_fft + N; const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + constexpr float scale = Inverse ? 0.5f : 1.0f; + for (int k = 0; k < half_N; k++) { - int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = g_cache.cos_vals[idx]; // cos(t) - float im = -g_cache.sin_vals[idx]; // sin(t) + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; + float im = sign * cache.sin_vals[idx]; - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd); + out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd); - out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd); + out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd); } } +// Forward FFT for real input (used by mel spectrogram) +static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + +// Inverse FFT for complex input +static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + struct filter_params { int32_t n_mel; int32_t n_fft_bins; @@ -222,20 +265,27 @@ struct filter_params { bool norm_per_feature = false; }; -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, - int n_samples, int frame_size, int frame_step, int n_threads, - const filter_params & params, mtmd_audio_mel & out) { +static void log_mel_spectrogram_worker_thread(int ith, + const float * hann, + const std::vector & samples, + int n_samples, + int frame_size, + int frame_step, + int n_threads, + const filter_params & params, + const mtmd_audio_cache & cache, + mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); int n_fft_bins = params.n_fft_bins; int i = ith; - const auto & filters = g_cache.filters; + const auto & filters = cache.filters; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); - GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); + GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; @@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } // FFT - fft(fft_in.data(), frame_size, fft_out.data()); + fft(cache, fft_in.data(), frame_size, fft_out.data()); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. @@ -298,6 +348,7 @@ static bool log_mel_spectrogram( const int n_samples_in, const int n_threads, const filter_params & params, + const mtmd_audio_cache & cache, mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); @@ -305,9 +356,9 @@ static bool log_mel_spectrogram( int n_samples = n_samples_in; // Hann window - const float * hann = g_cache.hann_window.data(); - const int frame_size = (params.n_fft_bins - 1) * 2; - const int frame_step = params.hop_length; + const float * hann = cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; // Padding std::vector samples_padded; @@ -335,9 +386,9 @@ static bool log_mel_spectrogram( // preemphasis if (params.preemph) { - const int pad_amount = frame_size / 2; + const int pad_amount = frame_size / 2; const float preemph = 0.97f; - float prev = samples_padded[pad_amount]; + float prev = samples_padded[pad_amount]; for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { float cur = samples_padded[i]; samples_padded[i] = cur - preemph * prev; @@ -372,14 +423,14 @@ static bool log_mel_spectrogram( { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples, frame_size, frame_step, n_threads, - std::cref(params), std::ref(out)); + workers[iw] = + std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples, + frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, + cache, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } @@ -404,7 +455,7 @@ static bool log_mel_spectrogram( for (int j = 0; j < effective_n_len; ++j) { auto &value = out.data[i * out.n_len + j]; - value = (value - mean) / mstd; + value = (value - mean) / mstd; } // pad the rest with zeros @@ -450,18 +501,14 @@ static bool log_mel_spectrogram( // void mtmd_audio_preprocessor_whisper::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_whisper::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { if (n_samples == 0) { // empty audio return false; @@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // if input is too short, pad with zeros // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram // TODO: maybe handle this better - size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin if (n_samples < min_samples) { smpl.resize(min_samples, 0.0f); std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); @@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess( params.hop_length = hparams.audio_hop_len; params.sample_rate = hparams.audio_sample_rate; params.center_padding = false; - params.preemph = 0.0f; // disabled + params.preemph = 0.0f; // disabled params.use_natural_log = false; params.norm_per_feature = false; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess( printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); } const size_t frames_per_chunk = 3000; - GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); - for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off); - if ((size_t)n_len < frames_per_chunk) { - break; // last uncomplete chunk will always be a padded chunk, safe to ignore + GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); + for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { + int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); + if ((size_t) n_len < frames_per_chunk) { + break; // last uncomplete chunk will always be a padded chunk, safe to ignore } mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; - out_chunk.n_len_org = out_full.n_mel; // unused + out_chunk.n_len_org = out_full.n_mel; // unused out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i*out_full.n_len + off; + auto src = out_full.data.begin() + i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // void mtmd_audio_preprocessor_conformer::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_conformer::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { // empty audio if (n_samples == 0) { return false; @@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess( params.use_natural_log = true; params.norm_per_feature = true; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess( output.push_back(std::move(out_full)); return true; } + +// +// mtmd_audio_streaming_istft implementation +// + +mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) : + n_fft(n_fft), + hop_length(hop_length), + n_fft_bins(n_fft / 2 + 1), + overlap_buffer(n_fft, 0.0f), + window_sum_buffer(n_fft, 0.0f), + padding_to_remove((n_fft - hop_length) / 2), + ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT + ifft_out(n_fft * 2 * 4, 0.0f) { + cache.fill_sin_cos_table(n_fft); + cache.fill_hann_window(n_fft, true); +} + +void mtmd_audio_streaming_istft::reset() { + std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f); + std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f); + padding_to_remove = (n_fft - hop_length) / 2; +} + +std::vector mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) { + std::vector output(hop_length); + + // copy frequencies + for (int j = 0; j < n_fft_bins; j++) { + ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0]; + ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1]; + } + + // mirror negative frequencies + for (int j = 1; j < n_fft_bins - 1; j++) { + int mirror_idx = n_fft - j; + ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0]; + ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate + } + + ifft(cache, ifft_in.data(), n_fft, ifft_out.data()); + + // update window sum and overlap buffer + for (int j = 0; j < n_fft; j++) { + window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j]; + overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j]; + } + + // extract hop_length samples with normalization + for (int i = 0; i < hop_length; i++) { + if (window_sum_buffer[i] > 1e-8f) { + output[i] = overlap_buffer[i] / window_sum_buffer[i]; + } else { + output[i] = overlap_buffer[i]; + } + } + + // shift buffers left by hop_length + std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f); + + // Remove padding if needed + int to_remove = std::min(padding_to_remove, (int) output.size()); + padding_to_remove -= to_remove; + output.erase(output.begin(), output.begin() + to_remove); + + return output; +} + +std::vector mtmd_audio_streaming_istft::flush() { + std::vector output; + + // Extract remaining samples from overlap buffer + // Continue until we've extracted all meaningful samples + int remaining = n_fft - hop_length; + while (remaining > 0) { + int chunk_size = std::min(remaining, hop_length); + + for (int i = 0; i < chunk_size; i++) { + float sample; + if (window_sum_buffer[i] > 1e-8f) { + sample = overlap_buffer[i] / window_sum_buffer[i]; + } else { + sample = overlap_buffer[i]; + } + output.push_back(sample); + } + + // Shift buffers + std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f); + + remaining -= chunk_size; + } + + return output; +} diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index d484c9d0301..016c7392e4f 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -17,6 +17,38 @@ struct mtmd_audio_mel { std::vector data; }; +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// cache for audio processing, each processor instance owns its own cache +struct mtmd_audio_cache { + std::vector sin_vals; + std::vector cos_vals; + + std::vector hann_window; + + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n); + + void fill_hann_window(int length, bool periodic); + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling + ); +}; + struct mtmd_audio_preprocessor { const clip_hparams & hparams; @@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; }; struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor { mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; +}; + +// +// streaming ISTFT - converts spectrogram frames back to audio one frame at a time +// +struct mtmd_audio_streaming_istft { + mtmd_audio_streaming_istft(int n_fft, int hop_length); + + // reset streaming state + void reset(); + + // process a single STFT frame (streaming) + // frame_spectrum: [n_fft_bins x 2] interleaved real/imag + // returns: up to hop_length samples + std::vector process_frame(const float * frame_spectrum); + + // flush remaining samples at end of stream + std::vector flush(); + + private: + int n_fft; + int hop_length; + int n_fft_bins; + + // Own cache for output processing + mtmd_audio_cache cache; + + // Streaming state + std::vector overlap_buffer; + std::vector window_sum_buffer; + int padding_to_remove; + + // Working buffers for IFFT + std::vector ifft_in; + std::vector ifft_out; };