diff --git a/README.md b/README.md
index e59612f7aed..0d9d1ef6b44 100644
--- a/README.md
+++ b/README.md
@@ -200,6 +200,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 *(to have a project listed here, it should clearly state that it depends on `llama.cpp`)*
 
 - [AI Sublime Text plugin](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (MIT)
+- [BonzAI App](https://apps.apple.com/us/app/bonzai-your-local-ai-agent/id6752847988) (proprietary)
 - [cztomsik/ava](https://github.com/cztomsik/ava) (MIT)
 - [Dot](https://github.com/alexpinel/Dot) (GPL)
 - [eva](https://github.com/ylsdamxssjxxdd/eva) (MIT)
diff --git a/SECURITY.md b/SECURITY.md
index ae496f4e3da..dd3a78d2909 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -1,12 +1,48 @@
 # Security Policy
 
+- [**Reporting a vulnerability**](#reporting-a-vulnerability)
+  - [**Requirements**](#requirements)
+  - [**Covered Topics**](#covered-topics)
 - [**Using llama.cpp securely**](#using-llamacpp-securely)
   - [Untrusted models](#untrusted-models)
   - [Untrusted inputs](#untrusted-inputs)
   - [Data privacy](#data-privacy)
   - [Untrusted environments or networks](#untrusted-environments-or-networks)
   - [Multi-Tenant environments](#multi-tenant-environments)
-- [**Reporting a vulnerability**](#reporting-a-vulnerability)
+
+## Reporting a vulnerability
+
+If you have discovered a security vulnerability in this project that falls within the [covered topics](#covered-topics), please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released.
+
+Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new).
+
+This project is maintained by a team of volunteers on a reasonable-effort basis. As such, please give us at least 90 days to work on a fix before public exposure.
+
+> [!IMPORTANT]
+> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
+
+## Requirements
+
+Before submitting your report, ensure that you meet the following requirements:
+
+- You have read this policy and fully understand it.
+- AI is only permitted in an assistive capacity, as stated in [AGENTS.md](AGENTS.md). We do not accept reports written exclusively by AI.
+- Your report must include a working proof of concept in the form of a script and/or attached files.
+
+Maintainers reserve the right to close reports that do not fulfill these requirements.
+
+## Covered Topics
+
+Only vulnerabilities in the following parts of the project are considered valid. For problems outside this list, please report them as regular issues.
+
+- `src/**/*`
+- `ggml/**/*`
+- `gguf-py/**/*`
+- `tools/server/*` (note: the Web UI is not covered)
+
+Note that none of the topics under [Using llama.cpp securely](#using-llamacpp-securely) are considered vulnerabilities in LLaMA C++.
+
+For vulnerabilities in the `vendor` directory, please report them directly to the third-party project.
 
 ## Using llama.cpp securely
 
@@ -55,19 +91,3 @@ If you intend to run multiple models in parallel with shared memory, it is your
 3. Model Sharing: In a multitenant model sharing design, tenants and users must understand the security risks of running code provided by others. Since there are no reliable methods to detect malicious models, sandboxing the model execution is the recommended approach to mitigate the risk.
 4. Hardware Attacks: GPUs or TPUs can also be attacked. [Researches](https://scholar.google.com/scholar?q=gpu+side+channel) has shown that side channel attacks on GPUs are possible, which can make data leak from other models or processes running on the same system at the same time.
-
-## Reporting a vulnerability
-
-Beware that none of the topics under [Using llama.cpp securely](#using-llamacpp-securely) are considered vulnerabilities of LLaMA C++.
-
-
-However, If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released.
-
-Please disclose it as a private [security advisory](https://github.com/ggml-org/llama.cpp/security/advisories/new).
-
-Please note that using AI to identify vulnerabilities and generate reports is permitted. However, you must (1) explicitly disclose how AI was used and (2) conduct a thorough manual review before submitting the report.
-
-A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
-
-> [!IMPORTANT]
-> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index ead180523c8..cc5e3691c36 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -4367,7 +4367,37 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
             data_torch = data_torch + 1
 
-        yield from super().modify_tensors(data_torch, name, bid)
+        if "in_proj_qkvz.weight" in name:
+            # original order:  [q, k, v, z] * head_count
+            # corrected order: [q * head_count, k * head_count, v * head_count, z * head_count]
+            head_k_dim = self.hparams["linear_key_head_dim"]
+            head_v_dim = self.hparams["linear_value_head_dim"]
+            num_v_heads = self.hparams["linear_num_value_heads"]
+            num_k_heads = self.hparams["linear_num_key_heads"]
+            hidden_size = self.hparams["hidden_size"]
+            split_arg_list_qkvz = [
+                head_k_dim,                                 # q partition
+                head_k_dim,                                 # k partition
+                (num_v_heads // num_k_heads * head_v_dim),  # v partition
+                (num_v_heads // num_k_heads * head_v_dim),  # z partition
+            ]
+            # view as (n_embd, head_count, [q + k + v + z])
+            data_torch = data_torch.permute(1, 0).contiguous()
+            data_torch = data_torch.view(-1, num_k_heads, sum(split_arg_list_qkvz))
+            # split into q, k, v, z
+            q, k, v, z = torch.split(data_torch, split_arg_list_qkvz, dim=-1)
+            # flatten each partition back to (hidden_size, head_count * partition_size)
+            q = q.contiguous().view(hidden_size, -1)
+            k = k.contiguous().view(hidden_size, -1)
+            v = v.contiguous().view(hidden_size, -1)
+            z = z.contiguous().view(hidden_size, -1)
+            # stack back
+            qkv = torch.cat([q, k, v], dim=-1).permute(1, 0).contiguous()
+            z = z.permute(1, 0).contiguous()
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_QKV, bid, ".weight"), qkv)
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_GATE, bid, ".weight"), z)
+        else:
+            yield from super().modify_tensors(data_torch, name, bid)
 
 
 @ModelBase.register("RND1")
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index b240e8e4a6b..404af1ef03f 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -1738,6 +1738,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ATTN_OUT,
         MODEL_TENSOR.ATTN_POST_NORM,
         MODEL_TENSOR.ATTN_GATE,
+        MODEL_TENSOR.ATTN_QKV,
         MODEL_TENSOR.FFN_GATE_INP,
         MODEL_TENSOR.FFN_GATE_INP_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 2ead965469a..f736ee67050 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -950,6 +950,8 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
             LLM_TENSOR_ATTN_K_NORM,
             LLM_TENSOR_ATTN_V,
             LLM_TENSOR_ATTN_OUT,
+            LLM_TENSOR_ATTN_QKV,
+            LLM_TENSOR_ATTN_GATE,
             LLM_TENSOR_FFN_NORM,
             LLM_TENSOR_FFN_GATE_INP,
             LLM_TENSOR_FFN_GATE_EXPS,
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 5de6493b9e9..f6cea8f8db4 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -6763,7 +6763,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } else {
                     // Linear attention (gated delta net) specific tensors
                     // Create tensors with calculated dimensions
-                    layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
+                    // note: ssm_in is only present in legacy GGUF files
+                    layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED);
+                    layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED);
+                    layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED);
                     layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
                     layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
                     layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0);
diff --git a/src/models/models.h b/src/models/models.h
index 72b2b760c69..6c40f48042b 100644
--- a/src/models/models.h
+++ b/src/models/models.h
@@ -466,7 +466,8 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * cur,
         int il);
 
-    ggml_tensor * build_delta_net_chunking(
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -478,7 +479,8 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * diag_mask,
         int il);
 
-    ggml_tensor * build_delta_net_autoregressive(
+    // returns pair of output and new state
+    std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive(
         ggml_tensor * q,
         ggml_tensor * k,
         ggml_tensor * v,
@@ -493,6 +495,11 @@ struct llm_build_qwen3next : public llm_graph_context_mamba {
         ggml_tensor * gate,
         int layer);
 
+    // returns pair of (qkv, z)
+    std::pair<ggml_tensor *, ggml_tensor *> build_qkvz(
+        ggml_tensor * input,
+        int il);
+
     const llama_model & model;
 };
diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp
index 775b3135d35..0e4fe7ebdc3 100644
--- a/src/models/qwen3next.cpp
+++ b/src/models/qwen3next.cpp
@@ -86,7 +86,15 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr
     ggml_build_forward_expand(gf, cur);
 }
 
-ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
+// utility to get one slice from the third dimension
+// input dim:  [x, y, c, b]
+// output dim: [x, y, 1, b]
+static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) {
+    return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3],
+                        t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c);
+}
+
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking(
     ggml_tensor * q,
     ggml_tensor * k,
     ggml_tensor * v,
@@ -187,18 +195,16 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs);
 
     ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g);
+    cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
 
-    cb(g_cumsum, "g_cumsum", il);
-
-    ggml_tensor * gcs_i = ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs);
+    ggml_tensor * gcs_i = g_cumsum; // already (chunk_size, 1, n_chunks, H_v * n_seqs), no reshape needed
 
     ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs);
 
     ggml_tensor * gcs_j_broadcast = ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs);
 
     ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i);
-
-    cb(decay_mask, "decay_mask", il);
+    cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
     decay_mask = ggml_mul(ctx0, decay_mask, diag_mask);
     decay_mask = ggml_exp(ctx0, decay_mask);
@@ -208,8 +214,7 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask);
 
     ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask));
-
-    cb(attn, "attn_pre_solve", il);
+    cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
     ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask);
     ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower);
@@ -217,8 +222,7 @@
     ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false);
 
     attn = ggml_mul(ctx0, lin_solve, causal_mask);
     attn = ggml_add(ctx0, attn, identity);
-
-    cb(attn, "attn_solved", il);
+    cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
     v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn);
@@ -226,116 +230,126 @@ ggml_tensor * llm_build_qwen3next::build_delta_net_chunking(
     ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t);
 
     ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp);
-
-    cb(kbeta_gexp, "kbeta_gexp", il);
+    cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
 
     ggml_tensor * k_cumdecay = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp)))));
+    cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
-    cb(k_cumdecay, "k_cumdecay", il);
+    ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q);
+    attn_kq = ggml_mul(ctx0, attn_kq, decay_mask);
+    attn_kq = ggml_mul(ctx0, attn_kq, diag_mask);
+    cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs)
 
-    ggml_tensor * core_attn_out = nullptr;
-    ggml_tensor * new_state = ggml_dup(ctx0, state);
-    cb(new_state, "new_state", il);
+    // vectorized calculation of key_gdiff, hoisted out of the per-chunk loop
+    // reference (PyTorch):
+    //   g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
+    //   g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
+    //   key_gdiff = key * g_diff.unsqueeze(-1)
+    //   kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
+    //   last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
 
-    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
-        auto chunkify = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
+    // get last element in g_cumsum along chunk_size dimension (ne0)
+    // example: [[x, y, z, ..., last], ...] -> [[last], ...]
+    ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3],
+        g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3],
+        (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum));
+    g_last = ggml_cont(ctx0, g_last);
+    cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
 
-        auto chunkify_g = [=](ggml_tensor * t) {
-            return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, t->ne[1], 1, t->ne[3],
-                t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk));
-        };
+    ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last);
+    cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs)
+
+    ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last));
+    cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
+
+    ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
+
+
+    // state to be updated per chunk
+    ggml_tensor * new_state = state; // no ggml_dup(ctx0, state) needed
+    cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs)
 
-        ggml_tensor * k_chunk = chunkify(k);
-        ggml_tensor * q_chunk = chunkify(q);
-        ggml_tensor * v_chunk = chunkify(v);
+    // shape after the loop over chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs)
+    ggml_tensor * core_attn_out = nullptr;
+
+    for (int64_t chunk = 0; chunk < n_chunks; chunk++) {
+        // shape: (S_k, chunk_size, 1, H_k * n_seqs)
+        ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul
 
-        ggml_tensor * g_cs_chunk = chunkify_g(g_cumsum);
-        ggml_tensor * g_cs_chunk_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cs_chunk));
+        // shape: (S_v, chunk_size, 1, H_v * n_seqs)
+        ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat
 
-        ggml_tensor * decay_mask_chunk = chunkify(decay_mask);
-        ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay);
+        // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
+        ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul
 
-        ggml_tensor * gexp_chunk = ggml_exp(ctx0, g_cs_chunk_t);
+        // shape: (chunk_size, 1, H_v * n_seqs)
+        ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat
 
         // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
-        attn = ggml_mul_mat(ctx0, k_chunk, q_chunk);
-        attn = ggml_mul(ctx0, attn, decay_mask_chunk);
-        attn = ggml_mul(ctx0, attn, diag_mask);
+        // replaced by a slice of the precomputed attn_kq
+        ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk);
+        cb(attn_chunk, "attn_chunk", il);
 
         ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs);
 
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk);
+        cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs)
 
         // v_new = v_i - v_prime
         ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime);
         ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new));
+        cb(v_new, "v_new_chunk", il);
 
         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
         ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk);
         ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp);
+        cb(attn_inter, "attn_inter_chunk", il);
 
         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
-        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn);
+        ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk);
+        cb(v_attn, "v_attn_chunk", il);
 
         ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn);
+        cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs)
 
-        core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1);
+        core_attn_out = core_attn_out == nullptr
+            ? core_attn_out_chunk
+            : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
 
-        // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
-        // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp()
-        // key_gdiff = key * g_diff.unsqueeze(-1)
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
-
-        ggml_tensor * g_cum_last =
-            ggml_cont(ctx0, ggml_view_4d(ctx0, g_cs_chunk_t, g_cs_chunk_t->ne[0], 1, g_cs_chunk_t->ne[2], g_cs_chunk_t->ne[3],
-                g_cs_chunk_t->nb[1], g_cs_chunk_t->nb[2], g_cs_chunk_t->nb[3],
-                g_cs_chunk_t->nb[0] * (g_cs_chunk_t->ne[1] - 1)));
-
-        ggml_tensor * gexp_last =
-            ggml_reshape_4d(ctx0, ggml_exp(ctx0, g_cum_last), 1, 1, g_cum_last->ne[0] * g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cum_last_3d =
-            ggml_reshape_3d(ctx0, g_cum_last, g_cum_last->ne[0], g_cum_last->ne[2], g_cum_last->ne[3]);
-
-        ggml_tensor * g_cumsum_3d = ggml_reshape_3d(ctx0, g_cs_chunk, g_cs_chunk->ne[0], g_cs_chunk->ne[2], g_cs_chunk->ne[3]);
-
-        ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum_3d, g_cum_last_3d));
-
-        ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-
-        ggml_tensor * key_gdiff = ggml_mul(ctx0, k_chunk,
-            ggml_reshape_4d(ctx0, g_diff_exp, 1, g_diff_exp->ne[0], g_diff_exp->ne[1],
-                g_diff_exp->ne[2] * g_diff_exp->ne[3]));
-
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)));
+        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        // note: ggml_mul_mat(ctx0, k_gdiff, v_new) would avoid the transpose, but is slower on Metal for unknown reasons
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
 
+        // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
+        ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));
         new_state = ggml_add(ctx0,
-            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last, gexp_last->ne[0], gexp_last->ne[1], H_v, n_seqs)),
+            ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)),
             ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs));
     }
 
-    core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs);
-
-    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, S_v, n_tokens, H_v, n_seqs, core_attn_out->nb[1], core_attn_out->nb[2], core_attn_out->nb[3], 0);
+    // truncate padded tokens
+    ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out,
+        S_v, n_tokens, H_v, n_seqs,
+        ggml_row_size(core_attn_out->type, S_v),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks),
+        ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0);
+    output_tokens = ggml_cont(ctx0, output_tokens);
     cb(output_tokens, "output_tokens", il);
 
-    // flatten output
-    ggml_tensor * flat_output =
-        ggml_cont_1d(ctx0, ggml_permute(ctx0, output_tokens, 0, 2, 1, 3), S_v * H_v * n_tokens * n_seqs);
-
-    ggml_tensor * flat_state = ggml_cont_1d(ctx0, new_state, S_v * S_v * H_v * n_seqs);
+    // permute back to (S_v, H_v, n_tokens, n_seqs)
+    output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3);
+    output_tokens = ggml_cont(ctx0, output_tokens);
 
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {output_tokens, new_state};
 }
 
-ggml_tensor * llm_build_qwen3next::build_delta_net_autoregressive(
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive(
     ggml_tensor * q,
     ggml_tensor * k,
     ggml_tensor * v,
@@ -419,11 +433,7 @@
     cb(core_attn_out, "output_tokens", il);
     cb(state, "new_state", il);
 
-    // flatten output, no need to permute since n_tokens is 1 so [S_v, 1, H_v, n_seqs] and [S_v, H_v, 1, n_seqs] are equivalent memory-layout wise
-    ggml_tensor * flat_output = ggml_reshape_1d(ctx0, core_attn_out, S_v * H_v * n_tokens * n_seqs);
-    ggml_tensor * flat_state = ggml_reshape_1d(ctx0, state, S_v * S_v * H_v * n_seqs);
-
-    return ggml_concat(ctx0, flat_output, flat_state, 0);
+    return {core_attn_out, state};
 }
 
 ggml_tensor * llm_build_qwen3next::build_norm_gated(
@@ -523,6 +533,87 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn(
     return cur;
 }
 
+std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz(
+    ggml_tensor * input,
+    int il) {
+    const int64_t d_inner = hparams.ssm_d_inner;
+    const int64_t n_seqs = ubatch.n_seqs;
+    const int64_t head_k_dim = hparams.ssm_d_state;
+    const int64_t num_k_heads = hparams.ssm_n_group;
+    const int64_t num_v_heads = hparams.ssm_dt_rank;
+    const int64_t head_v_dim = d_inner / num_v_heads;
+    const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+
+    if (model.layers[il].wqkv) {
+        // optimized path
+        ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input);
+        qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs);
+        cb(qkv_mixed, "linear_attn_qkv_mixed", il);
+
+        ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input);
+        cb(z, "z", il);
+
+        return { qkv_mixed, z };
+
+    } else {
+        // legacy (slower) path
+        ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input);
+        cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+
+        int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
+        ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
+
+        // Split mixed_qkvz into query, key, value, z
+        int64_t split_sizes_qkvz[4] = {
+            head_k_dim,                             // query size
+            head_k_dim,                             // key size
+            head_v_dim * num_v_heads / num_k_heads, // value size
+            head_v_dim * num_v_heads / num_k_heads  // z size
+        };
+
+        ggml_tensor * query =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
+                mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
+        cb(query, "q", il);
+
+        ggml_tensor * key =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
+                mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                split_sizes_qkvz[0] * ggml_element_size(mixed_qkvz_reshaped));
+        cb(key, "k", il);
+
+        ggml_tensor * value =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
+                mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * ggml_element_size(mixed_qkvz_reshaped));
+        cb(value, "v", il);
+
+        ggml_tensor * z =
+            ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
+                mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
+                (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * ggml_element_size(mixed_qkvz_reshaped));
+        cb(z, "z", il);
+
+        // After creating query, key, and value, reshape each to flatten the head dimensions
+        // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+        ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+        cb(query_flat, "query_flat", il);
+
+        // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
+        ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
+        cb(key_flat, "key_flat", il);
+
+        // value: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
+        ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
+        cb(value_flat, "value_flat", il);
+
+        // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
+        ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
+        qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
+        cb(qkv_mixed, "qkv_mixed", il);
+
+        return { qkv_mixed, z };
+    }
+}
+
 ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     llm_graph_input_rs * inp,
     ggml_tensor * cur,
@@ -547,15 +638,13 @@
     GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
 
     // Input projections
-    ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, cur);
-    cb(mixed_qkvz, "linear_attn_mixed_qkvz", il);
+    auto qkvz = build_qkvz(cur, il);
+    ggml_tensor * qkv_mixed = qkvz.first;
+    ggml_tensor * z = qkvz.second;
 
     ggml_tensor * mixed_ba = build_lora_mm(model.layers[il].ssm_beta_alpha, cur);
     cb(mixed_ba, "linear_attn_mixed_ba", il);
 
-    int64_t qkvz_new_dim = 2 * head_k_dim + 2 * head_v_dim * (num_v_heads / num_k_heads);
-    ggml_tensor * mixed_qkvz_reshaped = ggml_reshape_4d(ctx0, mixed_qkvz, qkvz_new_dim, num_k_heads, n_seq_tokens, n_seqs);
-
     // Reshape mixed_ba: [batch, seq_len, hidden_size] -> [batch, seq_len, num_k_heads, 2*num_v_heads/num_k_heads]
     int64_t ba_new_dim = 2 * num_v_heads / num_k_heads;
     ggml_tensor * mixed_ba_reshaped = ggml_reshape_4d(ctx0, mixed_ba, ba_new_dim, num_k_heads, n_seq_tokens, n_seqs);
@@ -575,8 +664,9 @@
         split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped));
     cb(a, "a", il);
 
-    // Reshape b and a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
-    ggml_tensor * beta = ggml_cont_3d(ctx0, b, num_v_heads, n_seq_tokens, n_seqs);
+    ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
+
+    // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads]
     ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs);
 
     ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt);
@@ -585,48 +675,6 @@
     ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus
     cb(gate, "gate", il);
 
-    // Split mixed_qkvz into query, key, value, z
-    int64_t split_sizes_qkvz[4] = {
-        head_k_dim,                             // query size
-        head_k_dim,                             // key size
-        head_v_dim * num_v_heads / num_k_heads, // value size
-        head_v_dim * num_v_heads / num_k_heads  // z size
-    };
-
-    ggml_tensor * query =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[0], num_k_heads, n_seq_tokens, n_seqs,
-            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3], 0);
-    cb(query, "q", il);
-
-    ggml_tensor * key = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[1], num_k_heads, n_seq_tokens, n_seqs,
-        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-        split_sizes_qkvz[0] * sizeof(float));
-    cb(key, "k", il);
-
-    ggml_tensor * value =
-        ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[2], num_k_heads, n_seq_tokens, n_seqs,
-            mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-            (split_sizes_qkvz[0] + split_sizes_qkvz[1]) * sizeof(float));
-    cb(value, "v", il);
-
-    ggml_tensor * z = ggml_view_4d(ctx0, mixed_qkvz_reshaped, split_sizes_qkvz[3], num_k_heads, n_seq_tokens, n_seqs,
-        mixed_qkvz_reshaped->nb[1], mixed_qkvz_reshaped->nb[2], mixed_qkvz_reshaped->nb[3],
-        (split_sizes_qkvz[0] + split_sizes_qkvz[1] + split_sizes_qkvz[2]) * sizeof(float));
-    cb(z, "z", il);
-
-    // After creating query, key, and value_reshaped, reshape each to flatten the head dimensions
-    // query: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * query_flat = ggml_cont_3d(ctx0, query, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
-    cb(query_flat, "query_flat", il);
-
-    // key: [head_k_dim, num_k_heads, n_tokens, n_seqs] -> [head_k_dim * num_k_heads, n_tokens, n_seqs]
-    ggml_tensor * key_flat = ggml_cont_3d(ctx0, key, head_k_dim * num_k_heads, n_seq_tokens, n_seqs);
-    cb(key_flat, "key_flat", il);
-
-    // value_reshaped: [head_v_dim, num_v_heads, n_tokens, n_seqs] -> [head_v_dim * num_v_heads, n_tokens, n_seqs]
-    ggml_tensor * value_flat = ggml_cont_3d(ctx0, value, head_v_dim * num_v_heads, n_seq_tokens, n_seqs);
-    cb(value_flat, "value_flat", il);
-
     // Get convolution states from cache
     ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
     ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
@@ -637,17 +685,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear(
     ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs);
     cb(conv_states, "conv_states", il);
 
-    // Now concatenate along the feature dimension (dim 0) to get [conv_dim, n_tokens, n_seqs]
-    ggml_tensor * qkv_mixed = ggml_concat(ctx0, query_flat, key_flat, 0);
-    qkv_mixed = ggml_concat(ctx0, qkv_mixed, value_flat, 0);
-    cb(qkv_mixed, "qkv_mixed", il);
-
-    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
-    cb(qkv_mixed, "qkv_mixed_permuted", il);
-
-    // Calculate the total conv dimension
-    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
-
     // Calculate convolution kernel size
     ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d;
     const int64_t conv_kernel_size = conv_kernel->ne[0];
@@ -655,6 +692,9 @@
     conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs);
     cb(conv_states, "conv_states_reshaped", il);
 
+    qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3);
+    cb(qkv_mixed, "qkv_mixed_permuted", il);
+
     ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0);
     cb(conv_input, "conv_input", il);
@@ -677,26 +717,25 @@
     ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel);
     cb(conv_output_proper, "conv_output_raw", il);
 
-    conv_output_proper = ggml_cont(ctx0, ggml_transpose(ctx0, conv_output_proper));
-    cb(conv_output_proper, "conv_output_pre_silu", il);
-
     ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper);
     cb(conv_output_silu, "conv_output_silu", il);
 
-    ggml_tensor * conv_qkv_mix =
-        ggml_cont_2d(ctx0, ggml_transpose(ctx0, conv_output_silu), qkv_dim, n_seq_tokens * n_seqs);
-    cb(conv_qkv_mix, "conv_qkv_mix", il);
+    ggml_tensor * conv_qkv_mix = conv_output_silu;
+
+    // Calculate the total conv dimension
+    int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads;
+    int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim);
 
     // Extract the convolved Q, K, V from conv_output
     ggml_tensor * q_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1], 0);
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0);
     cb(q_conv, "q_conv", il);
 
     ggml_tensor * k_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
+        ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv,
             head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(k_conv, "k_conv", il);
 
     ggml_tensor * v_conv =
-        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, conv_qkv_mix->nb[1],
+        ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv,
             2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix));
     cb(v_conv, "v_conv", il);
@@ -705,8 +744,6 @@
     k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs);
     v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
 
-    beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs);
-
     ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs);
     state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs);
     cb(state, "state_predelta", il);
@@ -738,45 +775,29 @@
     cb(v_conv, "v_conv_predelta", il);
 
     // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens
-    ggml_tensor * attn_out;
+    std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state)
     if (n_seq_tokens == 1) {
         attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il);
     } else {
         attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il);
     }
-    cb(attn_out, "attn_out", il);
-
-    // The tensors were concatenated 1d, so we need to extract them 1d as well
-    const int64_t output_flat_size = head_v_dim * num_v_heads * n_seq_tokens * n_seqs;
-    ggml_tensor * attn_out_1d = ggml_view_1d(ctx0, attn_out, output_flat_size, 0);
-    cb(attn_out_1d, "attn_out_1d", il);
-
-    ggml_tensor * attn_out_final = ggml_cont_4d(ctx0, attn_out_1d, head_v_dim, num_v_heads, n_seq_tokens, n_seqs);
-    cb(attn_out_final, "attn_out_reshaped", il);
-
-    // Extract the state part (second part of the concatenated tensor)
-    // State starts after n_tokens elements along dimension 1
-    const int64_t state_flat_size = head_v_dim * head_v_dim * num_v_heads * n_seqs;
-
-    ggml_tensor * state_1d =
-        ggml_view_1d(ctx0, attn_out, state_flat_size, output_flat_size * ggml_element_size(attn_out));
-    cb(state_1d, "state_1d", il);
+    ggml_tensor * output = attn_out.first;
+    ggml_tensor * new_state = attn_out.second;
+    cb(output, "attn_output", il);
+    cb(new_state, "new_state", il);
 
     // Update the recurrent states
     ggml_build_forward_expand(gf,
-        ggml_cpy(ctx0, state_1d,
+        ggml_cpy(ctx0, new_state,
             ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs,
                 kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all))));
 
-    GGML_ASSERT(ggml_nelements(attn_out_1d) + ggml_nelements(state_1d) == ggml_nelements(attn_out));
-
     // Reshape both attn_out_final and z to 2D tensors for normalization
     // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * attn_out_2d_final =
-        ggml_cont_2d(ctx0, attn_out_final, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim]
-    ggml_tensor * z_2d = ggml_cont_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
+    ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs);
 
     // Apply gated normalization: self.norm(core_attn_out, z)
     ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il);
@@ -828,12 +849,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
     shared_gate = ggml_sigmoid(ctx0, shared_gate);
     cb(shared_gate, "shared_expert_gate_sigmoid", il);
 
-    // The gate needs to be broadcast to match the dimensions of ffn_shexp
-    // ffn_shexp is [n_embd, n_tokens, 1, 1] and shared_gate is [1, n_tokens, 1, 1]
-    // We need to repeat the gate along the feature dimension
-    shared_gate = ggml_repeat(ctx0, shared_gate, ffn_shexp);
-    cb(shared_gate, "shared_expert_gate_broadcast", il);
-
     // Apply the gate to the shared expert output
     ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate);
    cb(ffn_shexp, "ffn_shexp_gated", il);
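
---

For reference, the `in_proj_qkvz.weight` reordering in the `convert_hf_to_gguf.py` hunk can be checked in isolation. Below is a minimal PyTorch sketch of the same interleaved-to-blocked permutation; the toy sizes are made up for illustration and only the reshuffling logic mirrors the diff.

```python
import torch

head_k_dim, head_v_dim = 4, 4
num_k_heads, num_v_heads = 2, 4
hidden_size = 8

splits = [
    head_k_dim,                               # q partition
    head_k_dim,                               # k partition
    num_v_heads // num_k_heads * head_v_dim,  # v partition
    num_v_heads // num_k_heads * head_v_dim,  # z partition
]

# weight stored as (out_features, hidden_size) with [q, k, v, z] interleaved per key head
w = torch.randn(num_k_heads * sum(splits), hidden_size)

# group the output rows per key head, then split each group into q/k/v/z
wt = w.permute(1, 0).contiguous().view(hidden_size, num_k_heads, sum(splits))
q, k, v, z = torch.split(wt, splits, dim=-1)

# flatten back so all q rows come first, then all k rows, then all v rows
qkv = torch.cat([t.contiguous().view(hidden_size, -1) for t in (q, k, v)], dim=-1)
qkv = qkv.permute(1, 0).contiguous()  # blocked: [q * heads | k * heads | v * heads]
z = z.contiguous().view(hidden_size, -1).permute(1, 0).contiguous()

assert qkv.shape == (num_k_heads * (2 * head_k_dim + splits[2]), hidden_size)
assert z.shape == (num_k_heads * splits[3], hidden_size)
```

This is why the loader can now map the result onto the fused `ATTN_QKV` tensor and a separate `ATTN_GATE` tensor instead of the interleaved `ssm_in`.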
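The `ggml_solve_tri` call in `build_delta_net_chunking` replaces an explicit inverse of `(I - attn_lower)`. The torch sketch below shows the same unit-lower-triangular solve pattern; the meaning of the boolean flags passed to `ggml_solve_tri` is an assumption here (lower-triangular, left-hand solve), so treat this as an illustration of the algebra rather than of the ggml API.

```python
import torch

chunk_size = 8
A = torch.randn(chunk_size, chunk_size).tril(-1)  # strictly lower, like attn * causal_mask
lhs = torch.eye(chunk_size) - A                   # lhs = I - attn_lower

# solve (I - attn_lower) @ x = A instead of forming the inverse explicitly
x = torch.linalg.solve_triangular(lhs, A, upper=False, unitriangular=True)

attn = x.tril(-1) + torch.eye(chunk_size)         # mask to lower, then add the identity

# sanity check against the explicit inverse
ref = (torch.linalg.inv(lhs) @ A).tril(-1) + torch.eye(chunk_size)
assert torch.allclose(attn, ref, atol=1e-5)
```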
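The PyTorch reference quoted in the comments of the chunking hunk can be run directly. The sketch below adapts it to explicit toy shapes (all sizes are made up, and `v_new` is random here even though in the real graph it is derived from the running state each chunk), which makes the decay bookkeeping that the diff vectorizes via `g_last` / `g_diff` / `key_gdiff` easy to verify.

```python
import torch

BH, n_chunks, chunk_size, S_k, S_v = 3, 4, 8, 16, 16

g_cum = torch.randn(BH, n_chunks, chunk_size).cumsum(-1)  # cumulative log-decay per chunk
key = torch.randn(BH, n_chunks, chunk_size, S_k)
v_new = torch.randn(BH, n_chunks, chunk_size, S_v)        # placeholder for the per-chunk v_new
state = torch.zeros(BH, S_k, S_v)                         # last_recurrent_state

for i in range(n_chunks):
    # decay from the start of the chunk to its last token
    g_last = torch.clamp(g_cum[:, i, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1)
    # remaining decay from each token to the end of the chunk
    g_diff = torch.clamp(g_cum[:, i, -1:] - g_cum[:, i], max=50.0).exp()
    key_gdiff = key[:, i] * g_diff.unsqueeze(-1)
    kgdmulvnew = key_gdiff.transpose(-1, -2) @ v_new[:, i]
    # last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
    state = state * g_last + kgdmulvnew
```

In the diff, `g_last_exp`, `g_diff_exp`, and `key_gdiff` are computed once for all chunks before the loop; only the state update itself stays inside it.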
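After the loop, the diff replaces the old flatten-and-concat output handling with a strided `ggml_view_4d` that drops the padded tail, followed by a permute. Below is a torch sketch of the same index arithmetic, with axes written in reverse (row-major) order relative to the ggml shape comments; the sizes are made up.

```python
import torch

S_v, chunk_size, n_chunks, H_v, n_seqs = 4, 8, 3, 2, 1
n_tokens = 20  # less than chunk_size * n_chunks, i.e. the last chunk is padded

# one (S_v, chunk_size, 1, H_v * n_seqs) chunk per iteration, concatenated on the chunk dim
chunks = [torch.randn(n_seqs * H_v, 1, chunk_size, S_v) for _ in range(n_chunks)]
core_attn_out = torch.cat(chunks, dim=1)  # like ggml_concat(..., 2)

# the ggml_view_4d + ggml_cont: read tokens contiguously across chunks, drop the padding
flat = core_attn_out.reshape(n_seqs, H_v, n_chunks * chunk_size, S_v)
output_tokens = flat[:, :, :n_tokens, :].contiguous()

# ggml_permute(ctx0, ..., 0, 2, 1, 3): back to (S_v, H_v, n_tokens, n_seqs) in ggml order
output_tokens = output_tokens.permute(0, 2, 1, 3).contiguous()
assert output_tokens.shape == (n_seqs, n_tokens, H_v, S_v)
```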
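Finally, the removed `ggml_repeat` in `build_layer_ffn` relies on `ggml_mul` broadcasting the `[1, n_tokens]` gate across the feature dimension, the same way the torch multiply below broadcasts. A toy check with made-up sizes:

```python
import torch

n_embd, n_tokens = 6, 5
ffn_shexp = torch.randn(n_tokens, n_embd)  # ggml shape [n_embd, n_tokens]
shared_gate = torch.rand(n_tokens, 1)      # ggml shape [1, n_tokens]

explicit = ffn_shexp * shared_gate.expand_as(ffn_shexp)  # old path: repeat, then mul
broadcast = ffn_shexp * shared_gate                      # new path: mul broadcasts
assert torch.equal(explicit, broadcast)
```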