From c0cdefb80801acb793bb236cfc4fc49f3dfabbd0 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Fri, 13 Feb 2026 13:19:47 +0300 Subject: [PATCH 1/2] [quantization] Quantization of Llama This PR quantizes the full `LLama` model and converts it to circle format. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../quantize_full_qmodel_with_gptq.py | 574 ++++++++++++++++++ tico/quantization/wrapq/quantizer.py | 39 ++ .../wrappers/llama/quant_decoder_layer.py | 6 +- tico/quantization/wrapq/wrappers/registry.py | 1 + 4 files changed, 617 insertions(+), 3 deletions(-) create mode 100644 tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py new file mode 100644 index 00000000..cb2d3d65 --- /dev/null +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -0,0 +1,574 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ============================================================================= +# PTQ + GPTQ HYBRID QUANTIZATION PIPELINE +# ----------------------------------------------------------------------------- +# This script shows how to: +# 1. Load a pretrained FP Llama-3 model. +# 2. Run GPTQ to quantize weights only (optional). +# 3. Wrap every Transformer layer with a PTQWrapper to quantize activations. +# 4. Calibrate activations observers in a single pass over a text corpus. +# 5. Inject GPTQ’s per-tensor weight scales / zero-points into the PTQ graph. +# 6. Freeze all Q-params and compute Wikitext-2 perplexity. +# 7. 
Save model/layers (optional) +# ============================================================================= + +import argparse +import pathlib +import random + +import types + +from typing import Any, List, Optional, Tuple, Union + +import torch +import tqdm +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import KwargsForCausalLM, LlamaForCausalLM +from transformers.processing_utils import Unpack + +import tico + +from tico.quantization import convert, prepare +from tico.quantization.config.gptq import GPTQConfig +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.dtypes import DType +from tico.quantization.wrapq.observers.affine_base import AffineObserverBase +from tico.quantization.wrapq.qscheme import QScheme +from tico.quantization.wrapq.utils.metrics import perplexity +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase + +from tico.utils.utils import SuppressWarning + +DTYPE_MAP = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, +} + +# Hardcoded dataset settings +DATASET_NAME = "wikitext" +DATASET_CONFIG = "wikitext-2-raw-v1" +TRAIN_SPLIT = "train" +TEST_SPLIT = "test" + + +# ------------------------------------------------------------------------- +# Helper — copy GPTQ (scale, zp) into PTQ observers +# ------------------------------------------------------------------------- +def inject_gptq_qparams( + root: torch.nn.Module, + gptq_quantizers: dict[str, Any], # {fp_name: quantizer} + weight_obs_name: str = "weight", +): + """ + For every `QuantModuleBase` whose `fp_name` matches a GPTQ key, + locate the observer called `weight_obs_name` and overwrite its + (scale, zero-point), then lock them against further updates. 
+ """ + for m in root.modules(): + if not isinstance(m, QuantModuleBase): + continue + if m.fp_name is None: + continue + quantizer = gptq_quantizers.get(m.fp_name) + if quantizer is None: + continue + obs = m.get_observer(weight_obs_name) + if obs is None: + continue + assert isinstance(obs, AffineObserverBase) + # GPTQ quantizer attributes + obs.load_qparams(quantizer.scale, quantizer.zero, lock=True) + + +# ------------------------------------------------------------------------- +# Save model/layers in circle format +# ------------------------------------------------------------------------- +def save_circles_to(q_m, calib_inputs, save_circle_to_folder): + q_m.eval() + q_m.cpu() + save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle") + pathlib.Path() + print(f"saving input embedding to {save_path.resolve()}") + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + cm = tico.convert( + q_m.model.embed_tokens, + (calib_inputs[0],), + strict=False, + ) + cm.save(save_path) + + save_path = pathlib.Path(save_circle_to_folder, "lm_head.q.circle") + print(f"saving lm_head to {save_path.resolve()}") + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size + example_hidden = torch.randn(B, S, D) + cm = tico.convert( + q_m.lm_head, + (example_hidden,), + strict=False, + ) + cm.save(save_path) + + print("saving layers") + for i in range(len(q_m.model.layers)): + save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle") + print(f"saving model layer_{i} to {save_path.resolve()}") + B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size + example_hidden = torch.randn(B, S, D) + + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + cm = tico.convert( + q_m.model.layers[i], + (example_hidden,), + strict=False, + ) + cm.save(save_path) + + save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle") + print(f"saving model.model to {save_path.resolve()}") + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False) + + cm.save(save_path) + + save_path = pathlib.Path(save_circle_to_folder, "model.q.circle") + print(f"saving the whole model to {save_path.resolve()}") + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + cm = tico.convert(q_m, (calib_inputs[0],), strict=False) + + cm.save(save_path) + + +def quantize_using_PTQ(q_m, calib_inputs, args): + print("Wrapping layers with PTQWrapper …") + + w_cfg = { + "mlp": { + "gate_proj": { + "weight": { + "dtype": DType.uint(args.gptq_weight_bits), + }, + }, + "up_proj": { + "weight": { + "dtype": DType.uint(args.gptq_weight_bits), + }, + }, + "down_proj": { + "weight": { + "dtype": DType.uint(args.gptq_weight_bits), + }, + }, + }, + "self_attn": { + "q_proj": { + "weight": { + "dtype": DType.uint(args.gptq_weight_bits), + }, + }, + "k_proj": { + "weight": { + "dtype": DType.uint(args.gptq_weight_bits), + }, + }, + "v_proj": { + "weight": { + "dtype": DType.uint(args.gptq_weight_bits), + }, + }, + "o_proj": { + "weight": { + "dtype": DType.uint(args.gptq_weight_bits), + }, + }, + }, + "input_layernorm": { + "dtype": DType.int(16), + "weight": {"dtype": DType.int(16)}, + }, + "post_attention_layernorm": { + "dtype": DType.int(16), + "weight": {"dtype": DType.int(16)}, + }, + } + + cfg = PTQConfig( + default_dtype=DType.int(16), + default_qscheme=QScheme.PER_TENSOR_SYMM, + overrides={ + 
"model.embeddings": { + "weight": { + "dtype": ( + DType.uint(args.embedding_weight_bits) + if args.embedding_weight_bits < 16 + else DType.int(args.embedding_weight_bits) + ), + }, + }, + "lm_head": { + "weight": { + "dtype": ( + DType.uint(args.lm_head_weight_bits) + if args.lm_head_weight_bits < 16 + else DType.int(args.lm_head_weight_bits) + ), + }, + }, + "model.norm": { + "weight": {"dtype": DType.int(16)}, + }, + }, + ) + for i in range(len(q_m.model.layers)): + child_scope = f"layer{i}" + cfg.overrides[child_scope] = w_cfg # type: ignore[index] + + qcfg = cfg + prepare(q_m, qcfg) + + # ------------------------------------------------------------------------- + # Single-pass activation calibration + # ------------------------------------------------------------------------- + print("Calibrating PTQ obeservers…") + + # Overwrite weight observers with GPTQ statistics + if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict): + inject_gptq_qparams(q_m, q_m.quantizers) + else: + print( + "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection." + ) + + device = torch.device(args.device) + with torch.no_grad(): + for inp in tqdm.tqdm(calib_inputs): + q_m(inp.to(device)) + + # Freeze all Q-params (scale, zero-point) + q_m = convert(q_m) + + return q_m + + +def fix_inputs(model, tokenizer, input_ids): + if tokenizer.pad_token_id is not None: + pads = torch.full( + ( + input_ids.shape[0], + model.config.max_position_embeddings - input_ids.shape[1], + ), + fill_value=tokenizer.pad_token_id, + device=input_ids.device, + ) + elif tokenizer.eos_token_id is not None: + pads = torch.full( + ( + input_ids.shape[0], + model.config.max_position_embeddings - input_ids.shape[1], + ), + fill_value=tokenizer.eos_token_id, + device=input_ids.device, + ) + else: + raise RuntimeError( + "failed to pad sequence - tokenizer doesn't have pad_token_id/eos_token_id" + ) + + return torch.cat((input_ids, pads), dim=1) + + +class LLamaWithFixedInput(LlamaForCausalLM): + + def __init__(self, parent: LlamaForCausalLM, tokenizer): + assert parent.config is not None, "config is a must have" + super(LlamaForCausalLM, self).__init__(parent.config) + self.__dict__.update(parent.__dict__) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + # fixed input size, due to position_ids fixed + orig_len = input_ids.shape[-1] + input_ids = fix_inputs(self, self.tokenizer, input_ids) + if labels is not None: + labels = fix_inputs(self, self.tokenizer, labels) + res = super().forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + inputs_embeds, + labels, + use_cache, + output_attentions, + output_hidden_states, + return_dict, + cache_position, + logits_to_keep, + **kwargs, + ) + # we need to trim to the original size + res.logits = res.logits[..., :orig_len, :] + return res + + self.forward = types.MethodType(forward, self) + self.tokenizer = tokenizer + 
+ +def evaluate(q_m, tokenizer, dataset_test, args): + # ------------------------------------------------------------------------- + # Evaluate perplexity on Wikitext-2 + # ------------------------------------------------------------------------- + print("\nCalculating perplexities …") + enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") + ppl_uint8 = perplexity( + q_m, enc, args.device, stride=q_m.config.max_position_embeddings + ) + + print("\n┌── Wikitext-2 test perplexity ─────────────") + print(f"│ int16 : {ppl_uint8:8.2f}") + print("└───────────────────────────────────────────") + + +def main(): + parser = argparse.ArgumentParser( + description="GPTQ+PTQ pipeline (weight-only + activation)" + ) + parser.add_argument( + "--model", type=str, required=True, help="HF repo name or local path." + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run on (cuda|cpu|mps).", + ) + parser.add_argument( + "--dtype", + choices=list(DTYPE_MAP.keys()), + default="float32", + help="Model dtype for load.", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Enable only if you trust the model repo code.", + ) + parser.add_argument( + "--hf-token", + type=str, + default=None, + help="Optional HF token for gated/private repos.", + ) + parser.add_argument( + "--no-tqdm", action="store_true", help="Disable tqdm progress bars." + ) + parser.add_argument( + "--no_GPTQ", + action="store_true", + default=False, + help="Don't use GPTQ", + ) + parser.add_argument( + "--no_PTQ", + action="store_true", + default=False, + help="Leave model float", + ) + parser.add_argument( + "--save_circle_to_folder", + type=str, + default=None, + help="Save embedding/lm_head/all_layers/model.model/the_whole_model to the folder specified", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="cache_dir for using model/datasets loading", + ) + parser.add_argument( + "--nsamples_for_qcalibration", + type=int, + default="128", # almost standard + help="number of samples to be used in GPTQ/PTQ calibration", + ) + parser.add_argument( + "--gptq_weight_bits", + type=int, + default=4, + help="Number of bits to be used in GPTQ quantizer for weight quantization", + ) + parser.add_argument( + "--gptq_mse", + action="store_true", + default=False, + help="Whether to use mse in gptq", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=None, + help="constraint for max_position_embeddings", + ) + parser.add_argument( + "--embedding_weight_bits", + type=int, + default=8, + help="Number of bits to be used to quantize input Embedding", + ) + parser.add_argument( + "--lm_head_weight_bits", + type=int, + default=4, + help="Number of bits to be used to quantize lm_head", + ) + parser.add_argument( + "--eval_tasks", + type=str, + default=None, + help="tasks to be evaluated using lm_eval, e.g. `winogrande,arc_easy,arc_challenge,openbookqa,mmlu_pro,ifeval,bbh`", + ) + args = parser.parse_args() + print(args) + + # Basic setup + torch.manual_seed(args.seed) + device = torch.device(args.device) + dtype = DTYPE_MAP[args.dtype] + + print("=== Config ===") + print(f"Model : {args.model}") + print(f"Device : {device.type}") + print(f"DType : {args.dtype}") + print() + + # ------------------------------------------------------------------------- + # 2. 
Load the FP backbone and tokenizer + # ------------------------------------------------------------------------- + print("Loading FP model …") + tokenizer = AutoTokenizer.from_pretrained( + args.model, + trust_remote_code=args.trust_remote_code, + token=args.hf_token, + cache_dir=args.cache_dir, + ) + model = ( + AutoModelForCausalLM.from_pretrained( + args.model, + torch_dtype=dtype, + trust_remote_code=args.trust_remote_code, + token=args.hf_token, + cache_dir=args.cache_dir, + ) + .to(device) + .eval() + ) + + model.config.use_cache = False # TODO use args for it + if args.max_seq_len is not None: + model.config.max_position_embeddings = min( + model.config.max_position_embeddings, args.max_seq_len + ) + + dataset_test = load_dataset( + DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT, cache_dir=args.cache_dir + ) + + print("\nCalculating original perplexities …") + enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") + ppl_fp32 = perplexity( + model, enc, device, stride=model.config.max_position_embeddings + ) + + print("\n┌── Wikitext-2 test perplexity ─────────────") + print(f"│ FP32 : {ppl_fp32:8.2f}") + print("└───────────────────────────────────────────") + + # ------------------------------------------------------------------------- + # Prepare calibration dataset + # ------------------------------------------------------------------------- + dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT) + calib_txt = " ".join(dataset_train["text"]) + train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device) + calib_inputs = [] + nsamples = args.nsamples_for_qcalibration + seqlen = model.config.max_position_embeddings + random.seed(args.seed) + for _ in range(nsamples): + i = random.randint(0, train_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = train_ids[:, i:j] + calib_inputs.append(inp.cpu()) + + # ------------------------------------------------------------------------- + # Run GPTQ (weight-only) pass + # ------------------------------------------------------------------------- + if not args.no_GPTQ: + if not args.no_GPTQ: + print("Applying GPTQ …") + + gptq_config = GPTQConfig( + weight_bits=args.gptq_weight_bits, perchannel=True, mse=args.gptq_mse + ) + q_m = prepare(model, gptq_config, inplace=True) + with torch.no_grad(): + for inp in calib_inputs: + q_m(inp.to(args.device)) + + q_m = convert(q_m, inplace=True) # materialize INT-weight tensors + else: + q_m = model + + # ------------------------------------------------------------------------- + # Wrap every layer with PTQWrapper + # ------------------------------------------------------------------------- + if not args.no_PTQ: + q_m = quantize_using_PTQ(q_m, calib_inputs, args) + + # after PTQ quantizer only fixed-length input sequences are valid + evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args) + + if args.save_circle_to_folder is not None: + save_circles_to(q_m, calib_inputs, args.save_circle_to_folder) + + +if __name__ == "__main__": + main() diff --git a/tico/quantization/wrapq/quantizer.py b/tico/quantization/wrapq/quantizer.py index f233fb42..901514aa 100644 --- a/tico/quantization/wrapq/quantizer.py +++ b/tico/quantization/wrapq/quantizer.py @@ -84,6 +84,45 @@ def _wrap_supported( # Case A: HuggingFace-style transformers: model.model.layers lm = getattr(root, "model", None) + + embeddings = ( + getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None + ) + if isinstance(embeddings, nn.Module): + child_scope = 
"model.embeddings" + child_cfg = qcfg.child(child_scope) + wrapped = self._try_wrap( + embeddings, + child_cfg, + fp_name=child_scope, + raise_on_fail=self.strict_wrap, + ) + lm.embed_tokens = wrapped # type: ignore[union-attr] + + model_norm = getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None + if isinstance(model_norm, nn.Module): + child_scope = "model.norm" + child_cfg = qcfg.child(child_scope) + wrapped = self._try_wrap( + model_norm, + child_cfg, + fp_name=child_scope, + raise_on_fail=self.strict_wrap, + ) + lm.norm = wrapped # type: ignore[union-attr] + + lm_head = getattr(root, "lm_head", None) if isinstance(lm, nn.Module) else None + if isinstance(lm_head, nn.Module): + child_scope = "lm_head" + child_cfg = qcfg.child(child_scope) + wrapped = self._try_wrap( + lm_head, + child_cfg, + fp_name=child_scope, + raise_on_fail=self.strict_wrap, + ) + root.lm_head = wrapped + layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None if isinstance(layers, nn.ModuleList): new_list = nn.ModuleList() diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py index b760777b..93a5ac57 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py @@ -185,9 +185,9 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - if attention_mask is None or attention_mask.dtype == torch.bool: - L = hidden_states.size(1) - attention_mask = self._slice_causal(L, hidden_states.device) + # to prevent introduction of attention_mask as a parameter let's use preset attention_mask + L = hidden_states.size(1) + attention_mask = self._slice_causal(L, hidden_states.device) position_embeddings = ( self.rope_cos_template.to( diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index 6a0c2b83..04d52333 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -24,6 +24,7 @@ _CORE_MODULES = ( "tico.quantization.wrapq.wrappers.quant_elementwise", ## nn ## + "tico.quantization.wrapq.wrappers.nn.quant_embedding", "tico.quantization.wrapq.wrappers.nn.quant_layernorm", "tico.quantization.wrapq.wrappers.nn.quant_linear", "tico.quantization.wrapq.wrappers.nn.quant_conv3d", From a0b1d47349fa1710f0b80f91774aa856de5e83b1 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Mon, 16 Feb 2026 15:13:26 +0300 Subject: [PATCH 2/2] [DRAFT] Improvements in disk space This PR fixes population of static `causal_masks`\`position_embeddings` through the layers to save disk space. 
TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../quantize_full_qmodel_with_gptq.py | 95 +++++--- .../wrapq/examples/quantize_with_gptq.py | 2 +- tico/quantization/wrapq/quantizer.py | 12 +- tico/quantization/wrapq/utils/metrics.py | 14 +- .../wrapq/wrappers/llama/quant_attn.py | 15 +- .../wrappers/llama/quant_decoder_layer.py | 45 +++- .../wrapq/wrappers/llama/quant_model.py | 216 ++++++++++++++++++ .../llama/quant_model_for_causal_lm.py | 169 ++++++++++++++ tico/quantization/wrapq/wrappers/registry.py | 2 + 9 files changed, 507 insertions(+), 63 deletions(-) create mode 100644 tico/quantization/wrapq/wrappers/llama/quant_model.py create mode 100644 tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index cb2d3d65..1efbc877 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -104,13 +104,30 @@ def inject_gptq_qparams( def save_circles_to(q_m, calib_inputs, save_circle_to_folder): q_m.eval() q_m.cpu() + + save_path = pathlib.Path(save_circle_to_folder, "model.q.circle") + print(f"saving the whole model to {save_path.resolve()}") + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False) + + cm.save(save_path) + + save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle") + print(f"saving model.model to {save_path.resolve()}") + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + cm = tico.convert(q_m.wrapped.model, (calib_inputs[0],), strict=False) + + cm.save(save_path) + save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle") pathlib.Path() print(f"saving input embedding to {save_path.resolve()}") with torch.no_grad(): with SuppressWarning(UserWarning, ".*"): cm = tico.convert( - q_m.model.embed_tokens, + q_m.wrapped.model.wrapped.embed_tokens, (calib_inputs[0],), strict=False, ) @@ -120,47 +137,42 @@ def save_circles_to(q_m, calib_inputs, save_circle_to_folder): print(f"saving lm_head to {save_path.resolve()}") with torch.no_grad(): with SuppressWarning(UserWarning, ".*"): - B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size + B, S, D = ( + 1, + q_m.wrapped.config.max_position_embeddings, + q_m.wrapped.config.hidden_size, + ) example_hidden = torch.randn(B, S, D) cm = tico.convert( - q_m.lm_head, + q_m.wrapped.lm_head, (example_hidden,), strict=False, ) cm.save(save_path) print("saving layers") - for i in range(len(q_m.model.layers)): + for i in range(len(q_m.wrapped.model.wrapped.layers)): save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle") print(f"saving model layer_{i} to {save_path.resolve()}") - B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size + B, S, D = ( + 1, + q_m.wrapped.config.max_position_embeddings, + q_m.wrapped.config.hidden_size, + ) example_hidden = torch.randn(B, S, D) + cur_layer = q_m.wrapped.model.wrapped.layers[i].wrapped + if hasattr(cur_layer, "copy_quantizers"): + cur_layer.copy_quantizers(q_m.wrapped.model.wrapped) with torch.no_grad(): with SuppressWarning(UserWarning, ".*"): cm = tico.convert( - q_m.model.layers[i], + q_m.wrapped.model.wrapped.layers[i], (example_hidden,), strict=False, ) cm.save(save_path) - save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle") - print(f"saving model.model to 
{save_path.resolve()}") - with torch.no_grad(): - with SuppressWarning(UserWarning, ".*"): - cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False) - - cm.save(save_path) - - save_path = pathlib.Path(save_circle_to_folder, "model.q.circle") - print(f"saving the whole model to {save_path.resolve()}") - with torch.no_grad(): - with SuppressWarning(UserWarning, ".*"): - cm = tico.convert(q_m, (calib_inputs[0],), strict=False) - - cm.save(save_path) - def quantize_using_PTQ(q_m, calib_inputs, args): print("Wrapping layers with PTQWrapper …") @@ -219,13 +231,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args): default_dtype=DType.int(16), default_qscheme=QScheme.PER_TENSOR_SYMM, overrides={ - "model.embeddings": { - "weight": { - "dtype": ( - DType.uint(args.embedding_weight_bits) - if args.embedding_weight_bits < 16 - else DType.int(args.embedding_weight_bits) - ), + "model": { + "embed_tokens": { + "weight": { + "dtype": ( + DType.uint(args.embedding_weight_bits) + if args.embedding_weight_bits < 16 + else DType.int(args.embedding_weight_bits) + ), + }, + }, + "layers": {}, + "norm": { + "weight": {"dtype": DType.int(16)}, }, }, "lm_head": { @@ -237,17 +255,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args): ), }, }, - "model.norm": { - "weight": {"dtype": DType.int(16)}, - }, }, ) for i in range(len(q_m.model.layers)): - child_scope = f"layer{i}" - cfg.overrides[child_scope] = w_cfg # type: ignore[index] + child_scope = f"{i}" + cfg.overrides["model"]["layers"][child_scope] = w_cfg # type: ignore[index] qcfg = cfg - prepare(q_m, qcfg) + q_m = prepare(q_m, qcfg) # ------------------------------------------------------------------------- # Single-pass activation calibration @@ -257,6 +272,12 @@ def quantize_using_PTQ(q_m, calib_inputs, args): # Overwrite weight observers with GPTQ statistics if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict): inject_gptq_qparams(q_m, q_m.quantizers) + elif ( + hasattr(q_m, "wrapped") + and hasattr(q_m.wrapped, "quantizers") + and isinstance(q_m.wrapped.quantizers, dict) + ): + inject_gptq_qparams(q_m.wrapped, q_m.wrapped.quantizers) else: print( "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection." 
@@ -358,7 +379,7 @@ def evaluate(q_m, tokenizer, dataset_test, args): print("\nCalculating perplexities …") enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") ppl_uint8 = perplexity( - q_m, enc, args.device, stride=q_m.config.max_position_embeddings + q_m, enc, args.device, stride=q_m.wrapped.config.max_position_embeddings ) print("\n┌── Wikitext-2 test perplexity ─────────────") @@ -564,7 +585,7 @@ def main(): q_m = quantize_using_PTQ(q_m, calib_inputs, args) # after PTQ quantizer only fixed-length input sequences are valid - evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args) + evaluate(q_m, tokenizer, dataset_test, args) if args.save_circle_to_folder is not None: save_circles_to(q_m, calib_inputs, args.save_circle_to_folder) diff --git a/tico/quantization/wrapq/examples/quantize_with_gptq.py b/tico/quantization/wrapq/examples/quantize_with_gptq.py index f86b4d0a..6e58b88e 100644 --- a/tico/quantization/wrapq/examples/quantize_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_with_gptq.py @@ -42,7 +42,6 @@ from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase - # Token-budget presets for activation calibration TOKENS: dict[str, int] = { # Smoke test (<1 min turnaround on CPU/GPU) @@ -65,6 +64,7 @@ TRAIN_SPLIT = "train" TEST_SPLIT = "test" + # ------------------------------------------------------------------------- # 1. Helper — copy GPTQ (scale, zp) into PTQ observers # ------------------------------------------------------------------------- diff --git a/tico/quantization/wrapq/quantizer.py b/tico/quantization/wrapq/quantizer.py index 901514aa..129ab7fd 100644 --- a/tico/quantization/wrapq/quantizer.py +++ b/tico/quantization/wrapq/quantizer.py @@ -81,12 +81,18 @@ def _wrap_supported( Recursively attempt to wrap boundaries. Strictness is applied at every boundary. """ assert not isinstance(root, QuantModuleBase), "The module is already wrapped." 
+ try: + return PTQWrapper(root, qcfg=qcfg, fp_name="model") + except NotImplementedError as e: + print("no special wrapper for model, wrappig using general case") # Case A: HuggingFace-style transformers: model.model.layers lm = getattr(root, "model", None) embeddings = ( - getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None + getattr(lm, "embed_tokens", None) + if isinstance(lm.embed_tokens, nn.Module) + else None ) if isinstance(embeddings, nn.Module): child_scope = "model.embeddings" @@ -99,7 +105,9 @@ def _wrap_supported( ) lm.embed_tokens = wrapped # type: ignore[union-attr] - model_norm = getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None + model_norm = ( + getattr(lm, "norm", None) if isinstance(lm.norm, nn.Module) else None + ) if isinstance(model_norm, nn.Module): child_scope = "model.norm" child_cfg = qcfg.child(child_scope) diff --git a/tico/quantization/wrapq/utils/metrics.py b/tico/quantization/wrapq/utils/metrics.py index acd36ea3..ef1ad805 100644 --- a/tico/quantization/wrapq/utils/metrics.py +++ b/tico/quantization/wrapq/utils/metrics.py @@ -89,11 +89,17 @@ def perplexity( device = _resolve_device(device, model) input_ids_full = input_ids_full.to(device) + if max_length is None: - assert hasattr(model, "config") - model_config = model.config - if hasattr(model.config, "text_config"): - model_config = model.config.text_config + if hasattr(model, "config"): + assert hasattr(model, "config") + model_config = model.config + else: + assert hasattr(model.wrapped, "config") + model_config = model.wrapped.config + + if hasattr(model_config, "text_config"): + model_config = model_config.text_config assert hasattr(model_config, "max_position_embeddings") assert isinstance(model_config.max_position_embeddings, int) max_length = model_config.max_position_embeddings diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn.py b/tico/quantization/wrapq/wrappers/llama/quant_attn.py index babdeed2..fd54457e 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_attn.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_attn.py @@ -86,9 +86,7 @@ def __init__( ) # Constant scale (1/√d) - scale_t = torch.tensor( - float(getattr(fp_attn, "scaling", self.head_dim**-0.5)) - ) + scale_t = torch.tensor(float(getattr(fp_attn, "scaling", self.head_dim**-0.5))) # merge scale_t to k_proj, (otherwise merge it to q_proj) with torch.no_grad(): lin = self.k_proj.wrapped.module @@ -161,8 +159,9 @@ def _concat_kv( return k, v def _apply_rope(self, q, k, cos, sin, unsqueeze_dim: int = 1): - cos_u = cos.unsqueeze(unsqueeze_dim) - sin_u = sin.unsqueeze(unsqueeze_dim) + cos_u, sin_u = cos, sin + # cos_u = cos.unsqueeze(unsqueeze_dim) + # sin_u = sin.unsqueeze(unsqueeze_dim) q_half = self._rot( q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat @@ -201,8 +200,8 @@ def forward( # Rope tables cos, sin = position_embeddings - cos = self._fq(cos, self.obs_cos) - sin = self._fq(sin, self.obs_sin) + # cos = self._fq(cos, self.obs_cos) + # sin = self._fq(sin, self.obs_sin) q_rot, k_rot = self._apply_rope(q, k, cos, sin, unsqueeze_dim=1) # --- build/update KV for attention & present_key_value ------------- @@ -228,7 +227,7 @@ def forward( attention_mask = self.causal_mask_template[..., :q_len, :k_len].to( hidden_states.device ) - attention_mask = self._fq(attention_mask, self.obs_causal_mask) + attention_mask = self._fq(attention_mask, self.obs_causal_mask) attn_weights_parts = [] attn_out_parts = [] diff --git 
a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py index 93a5ac57..19e8f5c5 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py @@ -108,6 +108,9 @@ def __init__( qcfg=post_attention_layernorm, fp_name=f"{fp_name}.post_attention_layernorm", ) + self.obs_causal_mask = self._make_obs("causal_mask") + self.obs_cos = self._make_obs("cos") + self.obs_sin = self._make_obs("sin") # Static causal mask template --------------------------------------- assert hasattr(fp_layer.self_attn, "config") and hasattr( @@ -166,6 +169,21 @@ def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: assert isinstance(self.causal_mask_template, torch.Tensor) return self.causal_mask_template[..., :seq_len, :seq_len].to(device) + def get_attention_mask_for(self, x): + L = x.size(1) + attention_mask = self._slice_causal(L, x.device) + return attention_mask + + def get_position_embeddings_for(self, hidden_states): + return ( + self.rope_cos_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ), + self.rope_sin_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ), + ) + def forward( self, hidden_states: torch.Tensor, @@ -185,18 +203,17 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - # to prevent introduction of attention_mask as a parameter let's use preset attention_mask - L = hidden_states.size(1) - attention_mask = self._slice_causal(L, hidden_states.device) + if attention_mask is None or attention_mask.dtype == torch.bool: + attention_mask = self.get_attention_mask_for(hidden_states) + attention_mask = self._fq(attention_mask, self.obs_causal_mask) - position_embeddings = ( - self.rope_cos_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), - self.rope_sin_template.to( - dtype=hidden_states.dtype, device=hidden_states.device - ), - ) + if position_embeddings is None: + position_embeddings = self.get_position_embeddings_for(hidden_states) + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos.unsqueeze(1), self.obs_cos), + self._fq(sin.unsqueeze(1), self.obs_sin), + ) attn_out = self.self_attn( hidden_states=hidden_states, @@ -242,6 +259,12 @@ def forward( # No local observers; just recurse into children def _all_observers(self): + yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin) yield from self.self_attn._all_observers() yield from self.mlp._all_observers() yield self.obs_mlp_residual_out + + def copy_quantizers(self, model): + self.obs_causal_mask = model.obs_causal_mask + self.obs_cos = model.obs_cos + self.obs_sin = model.obs_sin diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model.py b/tico/quantization/wrapq/wrappers/llama/quant_model.py new file mode 100644 index 00000000..fc31733c --- /dev/null +++ b/tico/quantization/wrapq/wrappers/llama/quant_model.py @@ -0,0 +1,216 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.processing_utils import Unpack + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase +from tico.quantization.wrapq.wrappers.registry import try_register + + +def fix_inputs(config, pad_token_id, input_ids): + pads = torch.full( + ( + input_ids.shape[0], + config.max_position_embeddings - input_ids.shape[1], + ), + fill_value=pad_token_id, + device=input_ids.device, + ) + + return torch.cat((input_ids, pads), dim=1) + + +@try_register("transformers.models.llama.modeling_llama.LlamaModel") +class QuantLlamaModel(QuantModuleBase): + def __init__( + self, + model_fp: nn.Module, + *, + qcfg: Optional[PTQConfig] = None, + fp_name: Optional[str] = None, + ): + super().__init__(qcfg, fp_name=fp_name) + + # ----- child configs (hierarchical override) ------------------- + embed_cfg = qcfg.child("embed_tokens") if qcfg else None + norm_cfg = qcfg.child("norm") if qcfg else None + layers_cfg = qcfg.child("layers") if qcfg else None + + # ----- wrap children ------------------------------- + assert hasattr(model_fp, "embed_tokens") and isinstance( + model_fp.embed_tokens, torch.nn.Module + ) + assert hasattr(model_fp, "norm") and isinstance(model_fp.norm, torch.nn.Module) + assert hasattr(model_fp, "layers") and isinstance( + model_fp.layers, torch.nn.ModuleList + ) + + self.embed_tokens = PTQWrapper( + model_fp.embed_tokens, embed_cfg, fp_name=f"{fp_name}.embed_tokens" + ) + + self.norm = PTQWrapper(model_fp.norm, norm_cfg, fp_name=f"{fp_name}.norm") + + new_list = nn.ModuleList() + for idx, layer in enumerate(model_fp.layers): + child_scope = f"{idx}" + child_cfg = layers_cfg.child(child_scope) + new_list.append( + PTQWrapper( + layer, + child_cfg, + fp_name=child_scope, + ) + ) + self.obs_causal_mask = self._make_obs("causal_mask") + self.obs_cos = self._make_obs("cos") + self.obs_sin = self._make_obs("sin") + + self.layers = new_list # type: ignore[union-attr] + self.config = model_fp.config + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + + orig_len = input_ids.shape[-1] + pad_id = ( + self.config.pad_token_id + if hasattr(self.config, "pad_token_id") + else self.config.eos_token_id + ) + + input_ids = 
fix_inputs(self.config, pad_id, input_ids) + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + hidden_states = inputs_embeds + # create position_embeddings and causal_mask to be shared across the decoder layers + causal_mask = self.layers[0].wrapped.get_attention_mask_for(hidden_states) + causal_mask = self._fq(causal_mask, self.obs_causal_mask) + position_embeddings = self.layers[0].wrapped.get_position_embeddings_for( + hidden_states + ) + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos.unsqueeze(1), self.obs_cos), + self._fq(sin.unsqueeze(1), self.obs_sin), + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states[..., :orig_len, :] + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + def _all_observers(self): + # recurse into children that are QuantModuleBase + yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin) + + for m in (self.embed_tokens, self.norm): + yield from m._all_observers() + for m in self.layers: + yield from m._all_observers() diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py new file mode 100644 index 00000000..c6742936 --- /dev/null +++ b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py @@ -0,0 +1,169 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from transformers.cache_utils import Cache +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import KwargsForCausalLM +from transformers.processing_utils import Unpack + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase +from tico.quantization.wrapq.wrappers.registry import try_register + + +def fix_inputs(config, pad_token_id, input_ids): + pads = torch.full( + ( + input_ids.shape[0], + config.max_position_embeddings - input_ids.shape[1], + ), + fill_value=pad_token_id, + device=input_ids.device, + ) + + return torch.cat((input_ids, pads), dim=1) + + +@try_register("transformers.models.llama.modeling_llama.LlamaForCausalLM") +class QuantLlamaForCausalLM(QuantModuleBase): + def __init__( + self, + model_fp: nn.Module, + *, + qcfg: Optional[PTQConfig] = None, + fp_name: Optional[str] = None, + ): + super().__init__(qcfg, fp_name=fp_name) + self.__dict__.update(model_fp.__dict__) # for quantizers at least + + # ----- child configs (hierarchical override) ------------------- + model_cfg = qcfg.child("model") if qcfg else None + lm_head_cfg = qcfg.child("lm_head") if qcfg else None + + ## ----- wrap model/lm_head ------------------------------- + assert hasattr(model_fp, "model") and isinstance( + model_fp.model, torch.nn.Module + ) + assert hasattr(model_fp, "lm_head") and isinstance( + model_fp.lm_head, torch.nn.Module + ) + + self.model = PTQWrapper( + model_fp.model, qcfg=model_cfg, fp_name=f"{fp_name}.model" + ) + + self.lm_head = PTQWrapper( + model_fp.lm_head, qcfg=lm_head_cfg, fp_name=f"{fp_name}.lm_head" + ) + self.config = model_fp.config + self.loss_function = model_fp.loss_function + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithPast]: + + orig_len = input_ids.shape[-1] + pad_id = ( + self.config.pad_token_id + if hasattr(self.config, "pad_token_id") + else self.config.eos_token_id + ) + + # input_ids = fix_inputs(self.config, pad_id, input_ids) + # if labels is not None: + # labels = fix_inputs(self.config, pad_id, labels) + + output_attentions = ( + output_attentions + if 
output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + # logits = logits[..., :orig_len, :] + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _all_observers(self): + # recurse into children that are QuantModuleBase + for m in (self.model, self.lm_head): + yield from m._all_observers() diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index 04d52333..4fe9d692 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -37,6 +37,8 @@ "tico.quantization.wrapq.wrappers.llama.quant_attn", "tico.quantization.wrapq.wrappers.llama.quant_decoder_layer", "tico.quantization.wrapq.wrappers.llama.quant_mlp", + "tico.quantization.wrapq.wrappers.llama.quant_model", + "tico.quantization.wrapq.wrappers.llama.quant_model_for_causal_lm", ## fairseq ## "tico.quantization.wrapq.wrappers.fairseq.quant_decoder_layer", "tico.quantization.wrapq.wrappers.fairseq.quant_encoder",
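
Reviewer note: a condensed sketch of the end-to-end flow these two patches enable, taken from quantize_full_qmodel_with_gptq.py. Here `model`, `device` and `calib_inputs` are assumed to be prepared exactly as in that script; the GPTQ scale injection, the per-layer override dict and the output path are simplified or illustrative.

    import torch
    import tico
    from tico.quantization import convert, prepare
    from tico.quantization.config.gptq import GPTQConfig
    from tico.quantization.config.ptq import PTQConfig
    from tico.quantization.wrapq.dtypes import DType
    from tico.quantization.wrapq.qscheme import QScheme

    # 1. GPTQ weight-only pass over the calibration windows
    q_m = prepare(model, GPTQConfig(weight_bits=4, perchannel=True), inplace=True)
    with torch.no_grad():
        for inp in calib_inputs:
            q_m(inp.to(device))
    q_m = convert(q_m, inplace=True)

    # 2. PTQ activation pass: prepare() now returns a PTQWrapper whose
    #    .wrapped attribute is the registered QuantLlamaForCausalLM
    cfg = PTQConfig(
        default_dtype=DType.int(16),
        default_qscheme=QScheme.PER_TENSOR_SYMM,
        overrides={},  # the example script fills per-layer overrides here
    )
    q_m = prepare(q_m, cfg)
    with torch.no_grad():
        for inp in calib_inputs:
            q_m(inp.to(device))
    q_m = convert(q_m)  # freeze scales / zero-points

    # 3. Export to Circle (per-layer export in save_circles_to works the same way)
    q_m.eval().cpu()
    cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)
    cm.save("model.q.circle")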