Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -3799,7 +3799,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
yield from super().modify_tensors(data_torch, merged_name, bid)
else:
yield from super().modify_tensors(data_torch, name, bid)
yield from ModelBase.modify_tensors(self, data_torch, name, bid)

def prepare_tensors(self):
super().prepare_tensors()
Expand Down Expand Up @@ -6153,7 +6153,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
if new_name.endswith("conv_stem.conv.bias") or new_name.endswith("layer_scale.gamma"):
data_torch = data_torch.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) # [1, C, 1, 1]

yield from super().modify_tensors(data_torch, new_name, bid)
yield from ModelBase.modify_tensors(self, data_torch, new_name, bid)


@ModelBase.register("Gemma3nForCausalLM", "Gemma3nForConditionalGeneration")
Expand Down Expand Up @@ -6253,7 +6253,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

# Continue with normal processing
name = name.replace("language_model.", "")
yield from super().modify_tensors(data_torch, name, bid)
yield from ModelBase.modify_tensors(self, data_torch, name, bid)
return

if "altup_unembed_projections" in name:
Expand All @@ -6270,7 +6270,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_unembd)
if out is not None:
yield from super().modify_tensors(out, "model.altup_unembed_projections.weight", bid)
yield from ModelBase.modify_tensors(self, out, "model.altup_unembed_projections.weight", bid)
return
else:
return
Expand All @@ -6287,7 +6287,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
raise ValueError(f"Unknown name: {name}")
out = self._stack_matrices(self._altup_proj)
if out is not None:
yield from super().modify_tensors(out, "model.altup_projections.weight", bid)
yield from ModelBase.modify_tensors(self, out, "model.altup_projections.weight", bid)
return
else:
return
Expand Down Expand Up @@ -8803,8 +8803,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
ffn_dim = self.hparams["intermediate_size"]
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
gate, up = data_torch.split(ffn_dim, dim=-2)
yield from super().modify_tensors(gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid)
yield from super().modify_tensors(up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid)
yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid)
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid)

has_experts = bool(self.hparams.get('num_local_experts'))

Expand All @@ -8813,15 +8813,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
gate, up = data_torch.split(ffn_dim, dim=-2)
if has_experts:
yield from super().modify_tensors(gate,self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), bid)
yield from super().modify_tensors(up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), bid)
yield from ModelBase.modify_tensors(self, gate,self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), bid)
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), bid)
return
yield from super().modify_tensors(gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), bid)
yield from super().modify_tensors(up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), bid)
yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), bid)
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), bid)
return

if not has_experts and name.endswith("shared_mlp.output_linear.weight"):
yield from super().modify_tensors(data_torch, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), bid)
yield from ModelBase.modify_tensors(self, data_torch, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), bid)
return

yield from super().modify_tensors(data_torch, name, bid)
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn-common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ void launch_fattn(
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);

const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
}

const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);

const int cc = ggml_cuda_info().devices[device].cc;

Expand Down
8 changes: 4 additions & 4 deletions src/llama-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
}

const uint32_t n_embd_out = model.hparams.get_n_embd_out();
const uint32_t n_embd_out = model.hparams.n_embd_out();
return embd + j*n_embd_out;
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
Expand Down Expand Up @@ -1279,7 +1279,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
const uint32_t n_embd_out = hparams.get_n_embd_out();
const uint32_t n_embd_out = hparams.n_embd_out();

GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
Expand Down Expand Up @@ -1688,7 +1688,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);
const uint32_t n_embd_out = hparams.get_n_embd_out();
const uint32_t n_embd_out = hparams.n_embd_out();
float * embd_out = embd + n_outputs_prev*n_embd_out;

if (n_outputs) {
Expand Down Expand Up @@ -1821,7 +1821,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba

const auto n_batch = cparams.n_batch;
const auto n_vocab = vocab.n_tokens();
const auto n_embd_out = hparams.get_n_embd_out();
const auto n_embd_out = hparams.n_embd_out();

bool has_logits = true;
bool has_embd = cparams.embeddings;
Expand Down
117 changes: 111 additions & 6 deletions src/llama-graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
return res;
}

void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
mctx->set_input_k_idxs(self_k_idxs, ubatch);

mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
}

bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);

this->mctx = mctx;

bool res = true;

res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;

res &= self_kq_mask->ne[0] == mctx->get_n_kv();
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;

return res;
}

void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
Expand Down Expand Up @@ -1596,11 +1617,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
v = ggml_transpose(ctx0, v);
}

// TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
if (v_mla) {
v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
}

// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
Expand Down Expand Up @@ -1823,9 +1839,11 @@ ggml_tensor * llm_graph_context::build_attn(
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
ggml_tensor * v_mla, // TODO: remove
float kq_scale,
int il) const {
GGML_ASSERT(v_mla == nullptr);

// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
Expand Down Expand Up @@ -1868,6 +1886,93 @@ ggml_tensor * llm_graph_context::build_attn(
return cur;
}

static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
ggml_context * ctx0,
const llama_ubatch & ubatch,
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_context * mctx_cur) {

auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);

{
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");

const auto n_kv = mctx_cur->get_n_kv();
const auto n_tokens = ubatch.n_tokens;
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);

inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
ggml_set_input(inp->self_kq_mask);

inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
}

return inp;
}

llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);

auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);

return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur,
ggml_tensor * k_cur,
ggml_tensor * v_cur,
ggml_tensor * kq_b,
ggml_tensor * sinks,
ggml_tensor * v_mla,
float kq_scale,
int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
// expand k later to enable rope fusion which directly writes into k-v cache
ggml_build_forward_expand(gf, q_cur);
ggml_build_forward_expand(gf, v_cur);
ggml_build_forward_expand(gf, k_cur);

const auto * mctx_cur = inp->mctx;

// store to KV cache
{
const auto & k_idxs = inp->get_k_idxs();

ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
}

const auto & kq_mask = inp->get_kq_mask();

ggml_tensor * q = q_cur;
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);

ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
cb(cur, "kqv_out", il);

if (wo) {
cur = build_lora_mm(wo, cur);
if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
// GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
}
}

if (wo_b) {
cur = ggml_add(ctx0, cur, wo_b);
}

return cur;
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_kv_iswa * inp,
ggml_tensor * wo,
Expand Down
48 changes: 48 additions & 0 deletions src/llama-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,39 @@ class llm_graph_input_attn_kv : public llm_graph_input_i {
const llama_kv_cache_context * mctx;
};

// V-less input for the KV cache
// ref: https://github.com/ggml-org/llama.cpp/pull/19067
class llm_graph_input_attn_k : public llm_graph_input_i {
public:
llm_graph_input_attn_k(
const llama_hparams & hparams,
const llama_cparams & cparams,
const llama_kv_cache_context * mctx) :
hparams(hparams),
cparams(cparams),
mctx(mctx) {
}
~llm_graph_input_attn_k() = default;

void set_input(const llama_ubatch * ubatch) override;

bool can_reuse(const llm_graph_params & params) override;

ggml_tensor * get_k_idxs() const { return self_k_idxs; }

ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]

ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]

const llama_hparams hparams;
const llama_cparams cparams;

const llama_kv_cache_context * mctx;
};

class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
public:
llm_graph_input_attn_kv_iswa(
Expand Down Expand Up @@ -833,6 +866,21 @@ struct llm_graph_context {
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
float kq_scale,
int il) const;

llm_graph_input_attn_k * build_attn_inp_k() const;

ggml_tensor * build_attn(
llm_graph_input_attn_k * inp,
ggml_tensor * wo,
ggml_tensor * wo_b,
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
ggml_tensor * kq_b,
ggml_tensor * sinks, // [n_head_q]
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale,
int il) const;
Expand Down
19 changes: 17 additions & 2 deletions src/llama-hparams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ uint32_t llama_hparams::n_embd_inp() const {
return n_embd_inp;
}

uint32_t llama_hparams::get_n_embd_out() const {
return n_embd_out > 0 ? n_embd_out : n_embd;
uint32_t llama_hparams::n_embd_out() const {
return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
}

uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
Expand Down Expand Up @@ -175,6 +175,21 @@ bool llama_hparams::is_swa(uint32_t il) const {
GGML_ABORT("fatal error");
}

bool llama_hparams::is_mla() const {
assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) ||
(n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0));

return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0;
}

uint32_t llama_hparams::n_embd_head_k_mla() const {
return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
}

uint32_t llama_hparams::n_embd_head_v_mla() const {
return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
}

bool llama_hparams::has_kv(uint32_t il) const {
if (n_layer_kv_from_start >= 0) {
if (il < (uint32_t) n_layer_kv_from_start) {
Expand Down
Loading
Loading