diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_attn.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_attn.py
new file mode 100644
index 00000000..6bd21b78
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_attn.py
@@ -0,0 +1,135 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+The tests run only if *transformers* is available (they depend on the genuine
+`transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionAttention`).
+"""
+
+import importlib.util
+import unittest
+
+import torch
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_attn import (
+    QuantQwen3VLVisionAttention,
+)
+
+
+trans_spec = importlib.util.find_spec("transformers")
+skip_msg = "transformers not installed — skipping Qwen3VLVisionAttention tests"
+
+
+@unittest.skipUnless(trans_spec, skip_msg)
+class TestQuantQwen3VLAttention(unittest.TestCase):
+    fp_attn: torch.nn.Module
+    head_dim: int
+    hidden_size: int
+
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(0)
+
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+            Qwen3VLVisionConfig,
+        )
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLVisionAttention,
+        )
+
+        cls.hidden_size = 16
+        cfg = Qwen3VLVisionConfig(hidden_size=cls.hidden_size, num_heads=2)
+
+        # Ensure eager attention implementation so outputs are deterministic
+        # and do not require GPU flash attention kernels.
+        # Some versions use `_attn_implementation`, others expose `attn_implementation`.
+        if not hasattr(cfg, "_attn_implementation"):
+            setattr(cfg, "_attn_implementation", "eager")
+        else:
+            cfg._attn_implementation = "eager"
+
+        cls.fp_attn = Qwen3VLVisionAttention(cfg)
+        cls.head_dim = cls.fp_attn.head_dim
+
+    # Dummy RoPE tables whose last dim matches head_dim
+    def _rand_rope(self, S):
+        h = self.head_dim
+        emb = torch.randn(S, h)
+        return emb.cos(), emb.sin()
+
+    def test_mode_transitions(self):
+        qattn = QuantQwen3VLVisionAttention(self.fp_attn)
+        self.assertIs(qattn._mode, Mode.NO_QUANT)
+
+        qattn.enable_calibration()
+        self.assertIs(qattn._mode, Mode.CALIB)
+        seq_len = 12
+        x = torch.randn(seq_len, self.hidden_size)
+        pos = self._rand_rope(seq_len)
+        _ = qattn(x, cu_seqlens=None, rotary_pos_emb=None, position_embeddings=pos)
+
+        qattn.freeze_qparams()
+        self.assertIs(qattn._mode, Mode.QUANT)
+
+    def test_forward_diff(self):
+        seq_len = 12
+        cu_seqlens = torch.tensor([0, seq_len])
+        qattn = QuantQwen3VLVisionAttention(self.fp_attn)
+        qattn.enable_calibration()
+        for _ in range(4):
+            inp = torch.randn(seq_len, self.hidden_size)
+            pos = self._rand_rope(seq_len)
+            _ = qattn(inp, cu_seqlens=cu_seqlens, position_embeddings=pos)
+        qattn.freeze_qparams()
+
+        # Compare quantized and FP outputs on the *same* fresh input.
+        x = torch.randn(seq_len, self.hidden_size)
+        pos = self._rand_rope(seq_len)
+        with torch.no_grad():
+            q_out = qattn(x, cu_seqlens=cu_seqlens, position_embeddings=pos)
+            fp_out = self.fp_attn(x, cu_seqlens=cu_seqlens, position_embeddings=pos)
+
+        diff = (fp_out - q_out).abs().mean().item()
+        self.assertGreater(diff, 0.0)
+        self.assertLess(diff, 0.4)
+        self.assertEqual(fp_out.shape, q_out.shape)
+
+    def test_per_projection_override(self):
+        cfg = PTQConfig(
+            default_dtype=DType.uint(8),
+            overrides={
+                "qkv": {
+                    "act_in": {"dtype": DType.uint(4)},
+                    "act_out": {"dtype": DType.uint(4)},
+                },
+                "proj": {
+                    "act_in": {"dtype": DType.int(16)},
+                    "act_out": {"dtype": DType.int(16)},
+                },
+            },
+        )
+        qattn = QuantQwen3VLVisionAttention(self.fp_attn, qcfg=cfg)
+
+        q_lin = qattn.proj.wrapped
+        self.assertIsInstance(q_lin, QuantLinear)
+        self.assertEqual(q_lin.obs_act_in.dtype, DType.int(16))
+        self.assertEqual(q_lin.obs_act_out.dtype, DType.int(16))
+
+        q_lin = qattn.qkv.wrapped
+        self.assertIsInstance(q_lin, QuantLinear)
+        self.assertEqual(q_lin.obs_act_in.dtype, DType.uint(4))
+        self.assertEqual(q_lin.obs_act_out.dtype, DType.uint(4))
diff --git a/tico/quantization/evaluation/script/mini_vqa_eval.py b/tico/quantization/evaluation/script/mini_vqa_eval.py
index e015658b..a356e990 100644
--- a/tico/quantization/evaluation/script/mini_vqa_eval.py
+++ b/tico/quantization/evaluation/script/mini_vqa_eval.py
@@ -21,7 +21,7 @@
 
 import torch
 from datasets import load_dataset
-from transformers import AutoModelForVision2Seq, AutoProcessor
+from transformers import AutoModelForImageTextToText, AutoProcessor
 
 
 # ============================================================
@@ -271,10 +271,11 @@
 
     # Load model and processor
     processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
-    model = AutoModelForVision2Seq.from_pretrained(
+    model = AutoModelForImageTextToText.from_pretrained(
         args.model_id,
         torch_dtype=torch_dtype,
         trust_remote_code=True,
     ).to(args.device)
     model.eval()
diff --git a/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_attn.py b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_attn.py
new file mode 100644
index 00000000..99443f2e
--- /dev/null
+++ b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_attn.py
@@ -0,0 +1,162 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+
+import torch
+from transformers import AutoModelForImageTextToText, AutoProcessor  # needs a recent transformers
+
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_attn import (
+    QuantQwen3VLVisionAttention,
+)
+from tico.utils.utils import SuppressWarning
+
+
+def get_position_embeddings(model, grid_thw: torch.Tensor):
+    pos_embeds = model.fast_pos_embed_interpolate(grid_thw)
+
+    rotary_pos_emb = model.rot_pos_emb(grid_thw)
+
+    seq_len, _ = pos_embeds.size()
+    rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+    position_embeddings = (emb.cos(), emb.sin())
+    return position_embeddings
+
+
+def get_cu_seqlens(grid_thw: torch.Tensor):
+    cu_seqlens = torch.repeat_interleave(
+        grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+    ).cumsum(
+        dim=0,
+        # Select dtype based on the following factors:
+        #  - FA2 requires that cu_seqlens_q must have dtype int32
+        #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
+        # See https://github.com/huggingface/transformers/pull/34852 for more information
+        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+    )
+    cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+    return cu_seqlens
+
+
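+# NOTE: The helpers above mirror how the Qwen3-VL vision tower derives its RoPE
+# (cos, sin) tables and cu_seqlens from `image_grid_thw`, so that a single
+# attention block can be calibrated and exported outside the full model.
+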
+# -------------------------------------------------------------------------
+# 0. Load a Qwen3-VL model + processor
+# -------------------------------------------------------------------------
+name = "Qwen/Qwen3-VL-4B-Instruct"
+model = AutoModelForImageTextToText.from_pretrained(
+    name,
+    device_map="cpu",
+    trust_remote_code=True,
+    dtype=torch.float32,
+)
+model.eval()
+
+processor = AutoProcessor.from_pretrained(name, trust_remote_code=True)
+# 1) Build chat-style multimodal messages (image token + text)
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {
+                "type": "text",
+                "text": (
+                    "Describe the picture\n"
+                    "Return ONLY the final answer with no extra words."
+                ),
+            },
+        ],
+    }
+]
+
+# 2) Render prompt that includes image tokens
+prompt = processor.apply_chat_template(
+    messages,
+    tokenize=False,
+    add_generation_prompt=True,
+)
+
+
+# -------------------------------------------------------------------------
+# 1. Replace layer-0’s attn with QuantQwen3VLVisionAttention
+# -------------------------------------------------------------------------
+orig_attn = model.model.visual.blocks[0].attn
+attn_q = prepare(orig_attn, PTQConfig())
+attn_q.eval()
+assert isinstance(attn_q.wrapped, QuantQwen3VLVisionAttention)
+
+# -------------------------------------------------------------------------
+# 2. Calibration
+# -------------------------------------------------------------------------
+image_size = (3, 512, 640)
+examples = [
+    torch.randint(0, 255, image_size),
+    torch.randint(0, 255, image_size),
+    torch.randint(0, 255, image_size),
+]
+
+attn_inputs = []
+with torch.no_grad():
+    for example in examples:
+        inputs = processor(
+            text=prompt,
+            images=example,
+            return_tensors="pt",
+        )
+        grid_thw = inputs["image_grid_thw"]
+        pixel_values = inputs["pixel_values"]
+
+        hidden_states = model.model.visual.patch_embed(pixel_values)
+        position_embeddings = get_position_embeddings(model.model.visual, grid_thw)
+        cu_seqlens = get_cu_seqlens(grid_thw)
+
+        _ = attn_q(hidden_states, cu_seqlens, None, position_embeddings)
+        attn_inputs.append((hidden_states, cu_seqlens, None, position_embeddings))
+
+convert(attn_q)
+assert attn_q._mode is Mode.QUANT, "Quantization mode should be active now."
+
+# -------------------------------------------------------------------------
+# 3. Quick diff check (INT-sim vs FP32)
+# -------------------------------------------------------------------------
+attn_input = attn_inputs[0]
+
+with torch.no_grad():
+    int8_out = attn_q(*attn_input)
+    fp_out = orig_attn(*attn_input)
+
+print("┌───────────── Quantization Error Summary ─────────────")
+print(f"│ Mean |diff|: {(int8_out - fp_out).abs().mean().item():.6f}")
+print(f"│ PEIR : {compute_peir(fp_out, int8_out) * 100:.6f} %")
+print("└──────────────────────────────────────────────────────")
+print(plot_two_outputs(fp_out, int8_out))
+
+# -------------------------------------------------------------------------
+# 4. Export the quantized block
+# -------------------------------------------------------------------------
+import tico
+
+save_path = pathlib.Path("qwen3vl_vision_attn.q.circle")
+
+with SuppressWarning(UserWarning, ".*"):
+    cm = tico.convert(attn_q, attn_input)
+cm.save(save_path)
+
+print(f"Quantized Circle model saved to {save_path.resolve()}")
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_attn.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_attn.py
new file mode 100644
index 00000000..3a1e0191
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_attn.py
@@ -0,0 +1,205 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Iterable, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register(
+    "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionAttention",
+)
+class QuantQwen3VLVisionAttention(QuantModuleBase):
+    def __init__(
+        self,
+        attn_fp: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+        cfg = attn_fp.config
+        self.config = cfg
+        self.num_heads = attn_fp.num_heads
+        self.head_dim = attn_fp.head_dim
+
+        # ---- Wrap the fused qkv and output projections via PTQWrapper ------
+        qkv_cfg = qcfg.child("qkv") if qcfg else None
+        proj_cfg = qcfg.child("proj") if qcfg else None
+
+        assert hasattr(attn_fp, "qkv") and isinstance(attn_fp.qkv, torch.nn.Module)
+        assert hasattr(attn_fp, "proj") and isinstance(attn_fp.proj, torch.nn.Module)
+
+        # qkv is deep-copied because the attention scale is folded into its
+        # weights below; the caller's FP module must stay untouched.
+        self.qkv = PTQWrapper(
+            copy.deepcopy(attn_fp.qkv), qcfg=qkv_cfg, fp_name=f"{fp_name}.qkv"
+        )
+        self.proj = PTQWrapper(attn_fp.proj, qcfg=proj_cfg, fp_name=f"{fp_name}.proj")
+
+        # Fold the constant attention scale (1/√d) into the K rows of the
+        # fused qkv projection (weight and bias).
+        scale_t = torch.tensor(
+            float(getattr(attn_fp, "scaling", self.head_dim**-0.5))
+        )
+        with torch.no_grad():
+            lin = self.qkv.wrapped.module
+            k_offset = lin.weight.shape[0] // 3
+            k_size = lin.weight.shape[0] // 3
+            lin.weight[k_offset : k_offset + k_size, :].mul_(scale_t)
+            if lin.bias is not None:
+                lin.bias[k_offset : k_offset + k_size].mul_(scale_t)
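+
+        # With the K rows scaled in advance, `q @ k.T` computed in forward()
+        # already equals `scaling * (q @ k_orig.T)`, so no separate scale
+        # multiply is needed in the attention math.
+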
+        mk = self._make_obs
+        self.obs_hidden = mk("hidden")
+        self.obs_scaling = mk("scaling")
+        self.obs_mul_logits_scale = mk("mul_logits_scale")
+        self.obs_cos = mk("cos")
+        self.obs_sin = mk("sin")
+
+        # rotate_half sub-steps (q)
+        self.obs_q_x1 = mk("q_x1")
+        self.obs_q_x2 = mk("q_x2")
+        self.obs_q_neg = mk("q_neg")
+        self.obs_q_cat = mk("q_cat")
+
+        # rotate_half sub-steps (k)
+        self.obs_k_x1 = mk("k_x1")
+        self.obs_k_x2 = mk("k_x2")
+        self.obs_k_neg = mk("k_neg")
+        self.obs_k_cat = mk("k_cat")
+
+        # RoPE combine
+        self.obs_q_cos = mk("q_cos")
+        self.obs_q_sin = mk("q_sin")
+        self.obs_q_rot = mk("q_rot")
+        self.obs_k_cos = mk("k_cos")
+        self.obs_k_sin = mk("k_sin")
+        self.obs_k_rot = mk("k_rot")
+
+        # Attention math (logits, softmax, output)
+        self.obs_logits = mk("logits")
+        self.obs_softmax = mk("softmax")
+        self.obs_attn_out = mk("attn_out")
+
+    def _rot(self, t, o_x1, o_x2, o_neg, o_cat):
+        x1, x2 = torch.chunk(t, 2, dim=-1)
+        x1 = self._fq(x1, o_x1)
+        x2 = self._fq(x2, o_x2)
+        x2n = self._fq(-x2, o_neg)
+        return self._fq(torch.cat((x2n, x1), -1), o_cat)
+
+    def _apply_rope(
+        self, q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
+        q_rot = self._rot(
+            q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
+        )
+        q_cos = self._fq(q * cos, self.obs_q_cos)
+        q_sin = self._fq(q_rot * sin, self.obs_q_sin)
+        q_embed = self._fq(q_cos + q_sin, self.obs_q_rot)
+
+        k_rot = self._rot(
+            k, self.obs_k_x1, self.obs_k_x2, self.obs_k_neg, self.obs_k_cat
+        )
+        k_cos = self._fq(k * cos, self.obs_k_cos)
+        k_sin = self._fq(k_rot * sin, self.obs_k_sin)
+        k_embed = self._fq(k_cos + k_sin, self.obs_k_rot)
+
+        return q_embed, k_embed
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        rotary_pos_emb: torch.Tensor | None = None,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+        **kwargs,
+    ):
+        hidden_states = self._fq(hidden_states, self.obs_hidden)
+
+        seq_length = hidden_states.shape[0]
+        query_states, key_states, value_states = (
+            self.qkv(hidden_states)
+            .reshape(seq_length, 3, self.num_heads, -1)
+            .permute(1, 0, 2, 3)
+            .unbind(0)
+        )
+        # The rotary_pos_emb fallback is not supported; precomputed (cos, sin)
+        # tables must be passed via `position_embeddings`.
+        assert position_embeddings is not None
+        cos, sin = position_embeddings
+        cos = self._fq(cos, self.obs_cos)
+        sin = self._fq(sin, self.obs_sin)
+        query_states, key_states = self._apply_rope(query_states, key_states, cos, sin)
+        query_states = query_states.transpose(0, 1).unsqueeze(0)
+        key_states = key_states.transpose(0, 1).unsqueeze(0)
+        value_states = value_states.transpose(0, 1).unsqueeze(0)
+
+        # NOTE: cu_seqlens is accepted for API compatibility but is not applied
+        # here; full attention over the whole packed sequence is computed, which
+        # matches the single-image calibration flow. Other implementations
+        # process each cu_seqlens chunk separately:
+        # lengths = cu_seqlens[1:] - cu_seqlens[:-1]
+        # splits = [
+        #     torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
+        # ]
+        rep = query_states.size(-3) // key_states.size(-3)
+        assert rep == 1  # GQA is not supported yet
+
+        attn_outputs = []
+        for n_head in range(self.num_heads):
+            k_i = key_states[:, n_head : n_head + 1, :, :]
+            v_i = value_states[:, n_head : n_head + 1, :, :]
+            q_i = query_states[:, n_head : n_head + 1, :, :]
+            logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)
+            # softmax
+            attn_i = torch.softmax(logits_i, -1, dtype=torch.float32).to(q_i.dtype)
+            attn_i = self._fq(attn_i, self.obs_softmax)
+            out_i = self._fq(attn_i @ v_i, self.obs_attn_out)
+            attn_outputs.append(out_i)
+
+        attn_output = torch.cat(attn_outputs, dim=1)
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(seq_length, -1).contiguous()
+
+        attn_output = self.proj(attn_output)
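+        # `proj` is a PTQWrapper around a QuantLinear, so its own act_in/act_out
+        # observers already cover this tensor; no extra fake-quant is added here.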
+        return attn_output
+
+    def _all_observers(self) -> Iterable:
+        yield from (
+            self.obs_hidden,
+            self.obs_scaling,
+            self.obs_mul_logits_scale,
+            self.obs_cos,
+            self.obs_sin,
+            self.obs_q_x1,
+            self.obs_q_x2,
+            self.obs_q_neg,
+            self.obs_q_cat,
+            self.obs_k_x1,
+            self.obs_k_x2,
+            self.obs_k_neg,
+            self.obs_k_cat,
+            self.obs_q_cos,
+            self.obs_q_sin,
+            self.obs_q_rot,
+            self.obs_k_cos,
+            self.obs_k_sin,
+            self.obs_k_rot,
+            self.obs_logits,
+            self.obs_softmax,
+            self.obs_attn_out,
+        )
+        for m in (self.qkv, self.proj):
+            yield from m._all_observers()
diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py
index 8eb896bd..6e63ab41 100644
--- a/tico/quantization/wrapq/wrappers/registry.py
+++ b/tico/quantization/wrapq/wrappers/registry.py
@@ -43,6 +43,7 @@
     "tico.quantization.wrapq.wrappers.fairseq.quant_mha",
     ## qwen_vl ##
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_attn",
+    "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_attn",
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp",
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed",
     # add future core wrappers here