diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_mlp.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_mlp.py
new file mode 100644
index 00000000..f8df3128
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_mlp.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+import tempfile
+import unittest
+import warnings
+
+import tico
+
+import torch
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp import (
+    QuantQwen3VLVisionMLP,
+)
+
+
+class DummyMLP(torch.nn.Module):
+    """Tiny stand-in for HF Qwen3VLVisionMLP (hidden=4, inter=8)."""
+
+    def __init__(self):
+        super().__init__()
+        self.linear_fc1 = torch.nn.Linear(4, 8)
+        self.linear_fc2 = torch.nn.Linear(8, 4)
+        self.act_fn = torch.nn.SiLU()
+
+    def forward(self, x):
+        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+
+
+class TestQuantQwenVisionMLP(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(0)
+        self.fp32 = DummyMLP()
+        self.quant = QuantQwen3VLVisionMLP(self.fp32)
+        self.x = torch.randn(32, 4)
+
+    def test_mode_and_forward(self):
+        # calibration
+        self.quant.enable_calibration()
+        _ = self.quant(self.x)
+        self.quant.freeze_qparams()
+        self.assertIs(self.quant._mode, Mode.QUANT)
+
+        # forward diff
+        with torch.no_grad():
+            q = self.quant(self.x)
+            f = self.fp32(self.x)
+        diff = (q - f).abs().mean().item()
+        self.assertLess(diff, 0.7)  # loose bound
+        self.assertGreater(diff, 0.0)
+
+
+class TestSubgraphExport(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(0)
+        self.mlp_int8 = QuantQwen3VLVisionMLP(DummyMLP()).eval()
+        self.x = torch.randn(16, 4)
+
+    def test_calib_quant_export(self):
+        # calib
+        self.mlp_int8.enable_calibration()
+        _ = self.mlp_int8(self.x)
+        self.mlp_int8.freeze_qparams()
+
+        self.assertIs(self.mlp_int8._mode, Mode.QUANT)
+
+        # export
+        with tempfile.TemporaryDirectory() as td:
+            path = pathlib.Path(td) / "mlp.circle"
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning)
+                exported = tico.convert(self.mlp_int8, (self.x[:1],))
+                exported.save(path)
+            self.assertTrue(path.exists())
diff --git a/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_mlp.py b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_mlp.py
new file mode 100644
index 00000000..b5acf450
--- /dev/null
+++ b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_mlp.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+
+import torch
+from transformers import AutoModelForImageTextToText  # requires a recent transformers release with Qwen3-VL support
+
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp import (
+    QuantQwen3VLVisionMLP,
+)
+from tico.utils.utils import SuppressWarning
+
+# -------------------------------------------------------------------------
+# 0. Load a Qwen3-VL model (only its vision tower is used below)
+# -------------------------------------------------------------------------
+name = "Qwen/Qwen3-VL-2B-Instruct"
+model = AutoModelForImageTextToText.from_pretrained(
+    name,
+    device_map="cpu",
+    trust_remote_code=True,
+    dtype=torch.float32,
+)
+model.eval()
+
+# -------------------------------------------------------------------------
+# 1. Wrap vision block 0's MLP with QuantQwen3VLVisionMLP (via prepare)
+# -------------------------------------------------------------------------
+orig_mlp = model.model.visual.blocks[0].mlp
+mlp_q = prepare(orig_mlp, PTQConfig())
+mlp_q.eval()
+assert isinstance(mlp_q.wrapped, QuantQwen3VLVisionMLP)
+
+# Dummy input shape: only the last dim (hidden_size) matters; the leading dim
+# just plays the role of a token count.
+inp_shape = (orig_mlp.intermediate_size, orig_mlp.hidden_size)
+# -------------------------------------------------------------------------
+# 2. Calibration
+# -------------------------------------------------------------------------
+examples = [
+    torch.randn(inp_shape),
+    torch.randn(inp_shape),
+    torch.randn(inp_shape),
+]
+
+with torch.no_grad():
+    for example in examples:
+        _ = mlp_q(example)
+
+convert(mlp_q)
+assert mlp_q._mode is Mode.QUANT, "Quantization mode should be active now."
+
+# -------------------------------------------------------------------------
+# 3. Quick diff check (INT-sim vs FP32)
+# -------------------------------------------------------------------------
+hidden = examples[0]
+
+with torch.no_grad():
+    int8_out = mlp_q(hidden)
+    fp_out = orig_mlp(hidden)
+
+print("┌───────────── Quantization Error Summary ─────────────")
+print(f"│ Mean |diff|: {(int8_out - fp_out).abs().mean().item():.6f}")
+print(f"│ PEIR : {compute_peir(fp_out, int8_out) * 100:.6f} %")
+print("└──────────────────────────────────────────────────────")
+print(plot_two_outputs(fp_out, int8_out))
+
+# -------------------------------------------------------------------------
+# 4. Export the quantized block
+# -------------------------------------------------------------------------
+import tico
+
+save_path = pathlib.Path("qwen3vl_vision_mlp.q.circle")
+
+example = torch.randn(inp_shape)
+
+with SuppressWarning(UserWarning, ".*"):
+    cm = tico.convert(mlp_q, (example,))
+cm.save(save_path)
+
+print(f"Quantized Circle model saved to {save_path.resolve()}")
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py
new file mode 100644
index 00000000..ef70099c
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py
@@ -0,0 +1,90 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register(
+    "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionMLP",
+)
+class QuantQwen3VLVisionMLP(QuantModuleBase):
+    def __init__(
+        self,
+        mlp_fp: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+        linear_fc1_cfg = qcfg.child("linear_fc1") if qcfg else None
+        linear_fc2_cfg = qcfg.child("linear_fc2") if qcfg else None
+        act_cfg = qcfg.child("act_fn") if qcfg else None
+
+        # ----- wrap the two Linear layers ------------------------------
+        assert hasattr(mlp_fp, "linear_fc1") and isinstance(
+            mlp_fp.linear_fc1, torch.nn.Module
+        )
+        assert hasattr(mlp_fp, "linear_fc2") and isinstance(
+            mlp_fp.linear_fc2, torch.nn.Module
+        )
+
+        self.linear_fc1 = PTQWrapper(
+            mlp_fp.linear_fc1, qcfg=linear_fc1_cfg, fp_name=f"{fp_name}.linear_fc1"
+        )
+        self.linear_fc2 = PTQWrapper(
+            mlp_fp.linear_fc2, qcfg=linear_fc2_cfg, fp_name=f"{fp_name}.linear_fc2"
+        )
+
+        # ----- activation ---------------------------------------------
+        assert hasattr(mlp_fp, "act_fn") and isinstance(mlp_fp.act_fn, torch.nn.Module)
+        self.act_fn = PTQWrapper(
+            mlp_fp.act_fn, qcfg=act_cfg, fp_name=f"{fp_name}.act_fn"
+        )
+
+        # ----- local observers ----------------------------------------
+        self.obs_act_in = self._make_obs("act_in")
+        self.obs_act_out = self._make_obs("act_out")
+
+    def forward(self, hidden_state):
+        # Reference (FP32): self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+        # 1) quantize the input once
+        x_q = self._fq(hidden_state, self.obs_act_in)
+
+        # 2) linear_fc1
+        fc1 = self.linear_fc1(x_q)
+
+        # 3) activation on the fc1 output
+        a = self.act_fn(fc1)
+
+        # 4) linear_fc2 (+ output fake-quant)
+        h = self._fq(self.linear_fc2(a), self.obs_act_out)
+
+        return h
+
+    def _all_observers(self) -> Iterable:
+        yield self.obs_act_in
+        yield self.obs_act_out
+        # recurse into children that are QuantModuleBase
+        for m in (self.linear_fc1, self.linear_fc2, self.act_fn):
+            yield from m._all_observers()
diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py
index e8c15c23..6a0c2b83 100644
--- a/tico/quantization/wrapq/wrappers/registry.py
+++ b/tico/quantization/wrapq/wrappers/registry.py
@@ -43,6 +43,7 @@
     "tico.quantization.wrapq.wrappers.fairseq.quant_mha",
     ## qwen_vl ##
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_attn",
+    "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp",
     # add future core wrappers here
 )