diff --git a/convert-hf-to-powerinfer-gguf.py b/convert-hf-to-powerinfer-gguf.py
index 181fe972..0aa4632e 100644
--- a/convert-hf-to-powerinfer-gguf.py
+++ b/convert-hf-to-powerinfer-gguf.py
@@ -185,6 +185,8 @@ def from_model_architecture(model_architecture):
             return FalconModel
         if model_architecture == "LlamaForCausalLM":
             return LlamaModel
+        if model_architecture == "OPTForCausalLM":
+            return OptModel
         raise NotImplementedError(f'Architecture "{model_architecture}" not supported!')
 
 
@@ -218,6 +220,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
             return gguf.MODEL_ARCH.FALCON
         if arch == "RWForCausalLM" or arch == "LlamaForCausalLM":
            return gguf.MODEL_ARCH.LLAMA
+        if arch == "OPTForCausalLM":
+            return gguf.MODEL_ARCH.OPT
         raise NotImplementedError(f'Architecture "{arch}" not supported!')
 
 
@@ -513,7 +517,63 @@ def write_tensors(self):
             self.gguf_writer.add_tensor(new_name, data)
 
 
+class OptModel(Model):
+    def set_gguf_parameters(self, params: PredictorParams):
+        self.gguf_writer.add_name("opt")
+        self.gguf_writer.add_context_length(2050)  # not in config.json
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["ffn_dim"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        # self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+        if params.sparse_threshold is not None:
+            self.gguf_writer.add_sparse_threshold(params.sparse_threshold)
+
+    def write_tensors(self):
+        for name, data_torch in self.get_tensors():
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            data = data_torch.squeeze().numpy()
+
+            # map tensor names
+            new_name = self._translate_tensor_key(name)
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            # We transpose the weight matrices of the FFN down layers here to support the
+            # Axpy operation in PowerInfer, so they don't need to be transposed at runtime.
+            if "ffn_down" in new_name:
+                new_name = new_name.replace("ffn_down", "ffn_down_t")
+                data = data.T
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 values as-is? There should be no reason to store float16 as float32.
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert any float32 2-dim weight tensors to float16
+            if (
+                self.ftype == 1
+                and data_dtype == np.float32
+                and name.endswith(".weight")
+                and n_dims == 2
+            ):
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
 
 @dataclass
 class PredictorParams:
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index e82df27b..9459b477 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -90,6 +90,7 @@ class MODEL_ARCH(IntEnum):
     GPT2 = auto()
     GPTJ = auto()
     GPTNEOX = auto()
+    OPT = auto()
     MPT = auto()
     STARCODER = auto()
     PERSIMMON = auto()
@@ -135,6 +136,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GPT2: "gpt2",
     MODEL_ARCH.GPTJ: "gptj",
     MODEL_ARCH.GPTNEOX: "gptneox",
+    MODEL_ARCH.OPT: "opt",
     MODEL_ARCH.MPT: "mpt",
     MODEL_ARCH.STARCODER: "starcoder",
     MODEL_ARCH.PERSIMMON: "persimmon",
@@ -356,7 +358,20 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GPT2: [
         # TODO
     ],
-    # TODO
+    MODEL_ARCH.OPT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
 }
 
 # tensors that will not be serialized
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 2c813050..641b81f0 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -11,6 +11,7 @@ class TensorNameMap:
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in",            # gptneox
             "transformer.wte",              # gpt2 gpt-j mpt refact
+            "decoder.embed_tokens",         # opt
             "transformer.word_embeddings",  # falcon
             "word_embeddings",              # bloom
             "model.embed_tokens",           # llama-hf
@@ -33,6 +34,7 @@ class TensorNameMap:
         MODEL_TENSOR.POS_EMBD: (
             "transformer.wpe",                 # gpt2
             "embeddings.position_embeddings",  # bert
+            "decoder.embed_positions",         # opt
         ),
 
         # Output
@@ -47,6 +49,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",  # gptneox
             "transformer.ln_f",           # gpt2 gpt-j falcon
+            "decoder.final_layer_norm",   # opt
             "model.norm",                 # llama-hf baichuan
             "norm",                       # llama-pth
             "embeddings.LayerNorm",       # bert
@@ -66,6 +69,7 @@ class TensorNameMap:
         MODEL_TENSOR.ATTN_NORM: (
             "gpt_neox.layers.{bid}.input_layernorm",      # gptneox
             "transformer.h.{bid}.ln_1",                   # gpt2 gpt-j refact
+            "decoder.layers.{bid}.self_attn_layer_norm",  # opt
             "transformer.blocks.{bid}.norm_1",            # mpt
             "transformer.h.{bid}.input_layernorm",        # falcon7b
             "h.{bid}.input_layernorm",                    # bloom
@@ -98,6 +102,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wq",                 # llama-pth
             "encoder.layer.{bid}.attention.self.query",  # bert
             "transformer.h.{bid}.attn.q_proj",           # gpt-j
+            "decoder.layers.{bid}.self_attn.q_proj",     # opt
         ),
 
         # Attention key
@@ -106,6 +111,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wk",               # llama-pth
             "encoder.layer.{bid}.attention.self.key",  # bert
             "transformer.h.{bid}.attn.k_proj",         # gpt-j
+            "decoder.layers.{bid}.self_attn.k_proj",   # opt
         ),
 
         # Attention value
@@ -114,12 +120,14 @@ class TensorNameMap:
             "layers.{bid}.attention.wv",                 # llama-pth
             "encoder.layer.{bid}.attention.self.value",  # bert
             "transformer.h.{bid}.attn.v_proj",           # gpt-j
+            "decoder.layers.{bid}.self_attn.v_proj",     # opt
         ),
 
         # Attention output
         MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox "transformer.h.{bid}.attn.c_proj", # gpt2 refact + "decoder.layers.{bid}.self_attn.out_proj", # opt "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon "h.{bid}.self_attention.dense", # bloom @@ -140,6 +148,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_NORM: ( "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox "transformer.h.{bid}.ln_2", # gpt2 refact + "decoder.layers.{bid}.final_layer_norm", # opt "h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt "model.layers.{bid}.post_attention_layernorm", # llama-hf @@ -153,6 +162,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_UP: ( "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox "transformer.h.{bid}.mlp.c_fc", # gpt2 + "decoder.layers.{bid}.fc1", # opt "transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon "h.{bid}.mlp.dense_h_to_4h", # bloom @@ -173,6 +183,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_DOWN: ( "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox "transformer.h.{bid}.mlp.c_proj", # gpt2 refact + "decoder.layers.{bid}.fc2", # opt "transformer.blocks.{bid}.ffn.down_proj", # mpt "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon "h.{bid}.mlp.dense_4h_to_h", # bloom diff --git a/llama.cpp b/llama.cpp index 3ae9e946..ac52908a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -230,6 +230,7 @@ enum llm_arch { LLM_ARCH_GPT2, LLM_ARCH_GPTJ, LLM_ARCH_GPTNEOX, + LLM_ARCH_OPT, LLM_ARCH_MPT, LLM_ARCH_STARCODER, LLM_ARCH_PERSIMMON, @@ -246,6 +247,7 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_GPT2, "gpt2" }, { LLM_ARCH_GPTJ, "gptj" }, { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_OPT, "opt" }, { LLM_ARCH_MPT, "mpt" }, { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, @@ -483,6 +485,26 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_OPT, + { + {LLM_TENSOR_TOKEN_EMBD, "token_embd"}, + {LLM_TENSOR_POS_EMBD, "position_embd"}, + {LLM_TENSOR_OUTPUT_NORM, "output_norm"}, + {LLM_TENSOR_OUTPUT, "output"}, + {LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm"}, + {LLM_TENSOR_ATTN_Q, "blk.%d.attn_q"}, + {LLM_TENSOR_ATTN_K, "blk.%d.attn_k"}, + {LLM_TENSOR_ATTN_V, "blk.%d.attn_v"}, + {LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output"}, + {LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm"}, + {LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down"}, + {LLM_TENSOR_FFN_DOWN_T, "blk.%d.ffn_down_t"}, + {LLM_TENSOR_FFN_UP, "blk.%d.ffn_up"}, + { LLM_TENSOR_MLP_PRED_FC1, "blk.%d.fc1" }, + { LLM_TENSOR_MLP_PRED_FC2, "blk.%d.fc2" }, + }, + }, { LLM_ARCH_PERSIMMON, { @@ -1321,6 +1343,9 @@ struct llama_layer { struct ggml_tensor * wqkv; // attention bias + struct ggml_tensor * bq; + struct ggml_tensor * bk; + struct ggml_tensor * bv; struct ggml_tensor * bo; struct ggml_tensor * bqkv; @@ -2341,6 +2366,17 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OPT: + { + // TODO: GGUF_GET_KEY & support different model versions + hparams.n_ctx_train = 2050; // TODO: hard coded for now + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 40: model.type = e_model::MODEL_13B; break; + case 48: model.type = e_model::MODEL_30B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_FALCON: { GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); @@ -3229,6 +3265,50 @@ static void 
llm_load_sparse_model_tensors( layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); } } break; + case LLM_ARCH_OPT: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + model.pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}); + // output + { + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}); + // model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + } + + const uint32_t n_ff = hparams.n_ff; + model.layers.resize(n_layer); + + for (uint32_t &i = current_layer; i < n_layer; ++i) { + auto & layer = model.layers[i]; + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}); + + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}); + + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}); + + layer.ffn_down_t = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "weight", i), {n_embd, n_ff}); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_T, "bias", i), {n_embd}); + + layer.mlp_pre_w1 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC1, "weight", i), {n_embd, GGML_NE_WILDCARD}); + layer.mlp_pre_w2 = create_tensor(tn(LLM_TENSOR_MLP_PRED_FC2, "weight", i), {GGML_NE_WILDCARD, n_ff}); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}); + } + } break; case LLM_ARCH_FALCON: { model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -3482,6 +3562,81 @@ static void llm_load_tensors( } } } break; + case LLM_ARCH_OPT: + { + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.pos_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_ctx_train}, GGML_BACKEND_CPU); + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = llama_backend_offload; +#else + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? 
GGML_BACKEND_CPU : llama_backend_offload; +#endif // _WIN32 + + backend_output = llama_backend_offload_split; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + // model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, backend_output); // same as token_embed + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + // if (backend_output == GGML_BACKEND_GPU_SPLIT) { + // vram_weights += ggml_nbytes(model.output); + // } + } + const uint32_t n_ff = hparams.n_ff; + const int i_gpu_start = n_layer - n_gpu_layers; + model.layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, backend_split); + layer.bq = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, backend_split); + + layer.wk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.bk = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, backend_split); + + layer.wv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, backend_split); + layer.bv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, backend_split); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + + layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.ffn_down_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + + layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.ffn_up_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up); + } + } + } break; case LLM_ARCH_BAICHUAN: { model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); @@ -4928,6 +5083,126 @@ struct llm_build_context { return gf; } + struct ggml_cgraph * build_opt() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + struct ggml_tensor * cur; + struct ggml_tensor * pos; + struct ggml_tensor * 
inpL; + + inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb); + cb(inpL, "inp_embd", -1); + + struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + cb(inp_pos, "inp_pos", -1); + + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + cb(KQ_scale, "KQ_scale", -1); + + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + cb(KQ_mask, "KQ_mask", -1); + + pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos); + cb(pos, "pos_embd", -1); + + inpL = ggml_add(ctx0, inpL, pos); + cb(inpL, "inpL", -1); + + for (int il = 0; il < n_layer; ++il) { + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, model.layers[il].attn_norm_b, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + std::tie(k_cpy, v_cpy) = llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); + + cur = llm_build_kqv(ctx0, hparams, kv_self, + model.layers[il].wo, model.layers[il].bo, + Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); + } + // add input residual + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + // feed-forward network + { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, model.layers[il].ffn_norm_b, + LLM_NORM, cb, il); + + if(llama_use_sparse_inference(&model)) { + llm_build_cb_short cbs = [&](ggml_tensor * cur, const char * name) { + std::string name_str = std::string(name) + "-" + std::to_string(il); + ggml_set_name(cur, name_str.c_str()); + }; + // We only offload the ffn input to GPU if all neurons are offloaded + if (model.layers[il].gpu_offload_ratio >= 1.) 
{ + cb(cur, "ffn_norm", il); + } else { + cbs(cur, "ffn_norm"); + } + cur = llm_build_ffn_sparse(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down_t, model.layers[il].ffn_down_b, + model.layers[il].mlp_pre_w1, + model.layers[il].mlp_pre_w2, + ffn_inp, + model.layers[il].gpu_idx, + model.layers[il].gpu_bucket, model.layers[il].ffn_gate_gpu, model.layers[il].ffn_down_gpu, model.layers[il].ffn_up_gpu, + LLM_FFN_RELU, LLM_FFN_SEQ, model.layers[il].gpu_offload_ratio, cbs); + } else { + cb(cur, "ffn_norm", il); + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, + NULL, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, + LLM_FFN_RELU, LLM_FFN_SEQ, cb, il); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "l_out", il); + // input for next layer + inpL = cur; + } + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, model.output_norm_b, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; + } + struct ggml_cgraph * build_baichuan() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6208,6 +6483,9 @@ static struct ggml_cgraph * llama_build_graph( for (int i = 0; i < n_tokens; ++i) { data[i] = batch.pos[i]; + if(model.arch == LLM_ARCH_OPT) { + data[i] += 2; + } } } @@ -6440,6 +6718,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_stablelm(); } break; + case LLM_ARCH_OPT: + { + result = llm.build_opt(); + } break; default: GGML_ASSERT(false); } diff --git a/powerinfer-py/powerinfer/export_split.py b/powerinfer-py/powerinfer/export_split.py index 9a773b26..7f230d8c 100644 --- a/powerinfer-py/powerinfer/export_split.py +++ b/powerinfer-py/powerinfer/export_split.py @@ -1,11 +1,14 @@ import argparse import pickle -import gguf +import sys from gguf.constants import GGMLQuantizationType from gguf.gguf_writer import GGUFWriter import torch from pathlib import Path import os +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +import gguf import struct import numpy as np import re
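
Note on the two hard-coded numbers in this patch, with a minimal illustrative sketch (not part of the patch itself): Hugging Face's OPT uses a learned positional embedding with a fixed offset of 2, so decoder.embed_positions stores max_position_embeddings + 2 rows (2048 + 2 = 2050 for the standard OPT checkpoints). That is the relationship behind add_context_length(2050) in the converter and the data[i] += 2 shift for LLM_ARCH_OPT in llama_build_graph. The Python below only restates that arithmetic; the helper name is made up for illustration.

    # Sketch of the OPT position-offset convention assumed above.
    OPT_POS_OFFSET = 2               # mirrors `data[i] += 2` for LLM_ARCH_OPT
    MAX_POSITION_EMBEDDINGS = 2048   # from OPT's config.json

    def embed_positions_rows(max_positions: int, offset: int = OPT_POS_OFFSET) -> int:
        """Number of rows stored in decoder.embed_positions."""
        return max_positions + offset

    assert embed_positions_rows(MAX_POSITION_EMBEDDINGS) == 2050  # matches add_context_length(2050)
    assert 0 + OPT_POS_OFFSET == 2                                # embedding row used for the first token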