diff --git a/test/quantization/wrapq/wrappers/nn/test_quant_gelutanh.py b/test/quantization/wrapq/wrappers/nn/test_quant_gelutanh.py
new file mode 100644
index 00000000..3969b610
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/nn/test_quant_gelutanh.py
@@ -0,0 +1,146 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+import unittest
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.nn.quant_gelutanh import QuantGELUTanh
+
+trans_spec = importlib.util.find_spec("transformers")
+skip_msg = "transformers not installed — skipping TestQuantGELUTanh tests"
+
+
+@unittest.skipUnless(trans_spec, skip_msg)
+class TestQuantGELUTanh(unittest.TestCase):
+    gelutanh: type
+
+    @classmethod
+    def setUpClass(cls):
+        import transformers
+
+        cls.gelutanh = transformers.activations.GELUTanh
+
+    def setUp(self):
+        torch.manual_seed(0)
+        self.x = torch.randn(128, 4) * 3  # wider than N(0,1) for better tanh coverage
+        self.fp_gelu_tanh = self.gelutanh()
+        self.qgelu_tanh = QuantGELUTanh(self.fp_gelu_tanh)  # default uint8
+
+    def test_mode_transitions(self):
+        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
+        self.assertIs(self.qgelu_tanh._mode, Mode.NO_QUANT)
+        self.qgelu_tanh.enable_calibration()
+        self.assertIs(self.qgelu_tanh._mode, Mode.CALIB)
+        _ = self.qgelu_tanh(self.x)  # collect stats
+        self.qgelu_tanh.freeze_qparams()
+        self.assertIs(self.qgelu_tanh._mode, Mode.QUANT)
+
+    def test_quantised_output(self):
+        """
+        Test that quantized output is acceptably close to the FP32 reference.
+        After calibration and freeze, quantized output should:
+        - Differ from the FP reference (quantization actually applied)
+        - Stay within reasonable error bounds
+        """
+        self.qgelu_tanh.enable_calibration()
+        _ = self.qgelu_tanh(self.x)
+        self.qgelu_tanh.freeze_qparams()
+
+        with torch.no_grad():
+            q_out = self.qgelu_tanh(self.x)
+            fp_out = self.gelutanh()(self.x)
+
+        diff = (q_out - fp_out).abs().mean().item()
+        self.assertGreater(diff, 0.0)  # not identical (quantization applied)
+        self.assertLess(diff, 0.3)  # acceptably close (same tolerance as SiLU)
+
+    def test_dtype_override(self):
+        """
+        PTQConfig overrides should propagate to observers created by QuantGELUTanh.
+        Test that different dtypes can be applied to intermediate activations.
+        """
+        cfg = PTQConfig(
+            default_dtype=DType.uint(8),
+            overrides={
+                "tanh": {"dtype": DType.uint(4)},
+                "mul": {"dtype": DType.uint(4)},
+            },
+        )
+        qgelu_custom = QuantGELUTanh(self.fp_gelu_tanh, qcfg=cfg)
+
+        # Check that overrides were applied
+        self.assertEqual(qgelu_custom.obs_tanh.dtype, DType.uint(4))
+        self.assertEqual(qgelu_custom.obs_mul.dtype, DType.uint(4))
+
+    def test_activation_stats_collected(self):
+        """
+        Test that activation statistics are properly collected during calibration.
+ All three observers (act_in, tanh, mul) should collect statistics. + """ + self.qgelu_tanh.enable_calibration() + + # Run forward pass to collect stats + _ = self.qgelu_tanh(self.x) + + # Check that activation observers have collected stats + self.assertTrue( + self.qgelu_tanh.obs_act_in.has_qparams + or self.qgelu_tanh.obs_act_in.min_val.numel() > 0 + ) + self.assertTrue( + self.qgelu_tanh.obs_tanh.has_qparams + or self.qgelu_tanh.obs_tanh.min_val.numel() > 0 + ) + self.assertTrue( + self.qgelu_tanh.obs_mul.has_qparams + or self.qgelu_tanh.obs_mul.min_val.numel() > 0 + ) + + # Freeze and check qparams exist + self.qgelu_tanh.freeze_qparams() + self.assertTrue(self.qgelu_tanh.obs_act_in.has_qparams) + self.assertTrue(self.qgelu_tanh.obs_tanh.has_qparams) + self.assertTrue(self.qgelu_tanh.obs_mul.has_qparams) + + def test_no_quant_matches_reference(self): + """ + In NO_QUANT mode, output should match FP32 reference exactly + (up to numerical tolerances). + """ + # Create fresh wrapper that stays in NO_QUANT mode + qgelu = QuantGELUTanh(self.fp_gelu_tanh) + + with torch.no_grad(): + q_out = qgelu(self.x) + fp_out = self.gelutanh()(self.x) + + self.assertIs(qgelu._mode, Mode.NO_QUANT) + self.assertTrue(torch.allclose(q_out, fp_out, atol=1e-6, rtol=1e-6)) + + def test_registration_in_registry(self): + """ + Test that GELUTanh is properly registered in the wrapper registry. + """ + from tico.quantization.wrapq.wrappers.nn.quant_gelutanh import QuantGELUTanh + from tico.quantization.wrapq.wrappers.registry import lookup + + # Verify GELUTanh maps to QuantGELUTanh + wrapper_cls = lookup(self.gelutanh) + self.assertIs(wrapper_cls, QuantGELUTanh) diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_mlp.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_mlp.py new file mode 100644 index 00000000..8b58ae7f --- /dev/null +++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_mlp.py @@ -0,0 +1,87 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pathlib
+import tempfile
+import unittest
+import warnings
+
+import tico
+
+import torch
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp import (
+    QuantQwen3VLVisionMLP,
+)
+from transformers.activations import GELUTanh
+
+
+class DummyMLP(torch.nn.Module):
+    """Tiny stand-in for HF Qwen3VLVisionMLP (hidden=4, inter=8)."""
+
+    def __init__(self):
+        super().__init__()
+        self.linear_fc1 = torch.nn.Linear(4, 8)
+        self.linear_fc2 = torch.nn.Linear(8, 4)
+        self.act_fn = GELUTanh()
+
+    def forward(self, x):
+        return self.linear_fc2(self.act_fn(self.linear_fc1(x)))
+
+
+class TestQuantQwenVisionMLP(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(0)
+        self.fp32 = DummyMLP()
+        self.quant = QuantQwen3VLVisionMLP(self.fp32)
+        self.x = torch.randn(32, 4)
+
+    def test_mode_and_forward(self):
+        # calibration
+        self.quant.enable_calibration()
+        _ = self.quant(self.x)
+        self.quant.freeze_qparams()
+        self.assertIs(self.quant._mode, Mode.QUANT)
+
+        # forward diff
+        with torch.no_grad():
+            q = self.quant(self.x)
+            f = self.fp32(self.x)
+        diff = (q - f).abs().mean().item()
+        self.assertLess(diff, 0.7)  # loose bound
+        self.assertGreater(diff, 0.0)
+
+
+class TestSubgraphExport(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(0)
+        self.mlp_int8 = QuantQwen3VLVisionMLP(DummyMLP()).eval()
+        self.x = torch.randn(16, 4)
+
+    def test_calib_quant_export(self):
+        # calib
+        self.mlp_int8.enable_calibration()
+        _ = self.mlp_int8(self.x)
+        self.mlp_int8.freeze_qparams()
+
+        self.assertIs(self.mlp_int8._mode, Mode.QUANT)
+
+        # export
+        with tempfile.TemporaryDirectory() as td:
+            path = pathlib.Path(td) / "mlp.circle"
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning)
+                exported = tico.convert(self.mlp_int8, (self.x[:1],))
+            exported.save(path)
+            self.assertTrue(path.exists())
diff --git a/tico/quantization/evaluation/script/mini_vqa_eval.py b/tico/quantization/evaluation/script/mini_vqa_eval.py
index e015658b..4c4b7d9e 100644
--- a/tico/quantization/evaluation/script/mini_vqa_eval.py
+++ b/tico/quantization/evaluation/script/mini_vqa_eval.py
@@ -231,6 +231,8 @@
        default="bfloat16",
        choices=["float16", "bfloat16", "float32"],
    )
+    ap.add_argument("--cache_dir", type=str, default=None)
+
    args = ap.parse_args()

    # Reproducibility
@@ -270,11 +272,12 @@
    torch_dtype = dtype_map[args.dtype]

    # Load model and processor
-    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True)
+    processor = AutoProcessor.from_pretrained(args.model_id, trust_remote_code=True, cache_dir=args.cache_dir)
    model = AutoModelForVision2Seq.from_pretrained(
        args.model_id,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
+        cache_dir=args.cache_dir,
    ).to(args.device)
    model.eval()

diff --git a/tico/quantization/wrapq/examples/quantize_qwen_vision_mlp.py b/tico/quantization/wrapq/examples/quantize_qwen_vision_mlp.py
new file mode 100644
index 00000000..80c4b665
--- /dev/null
+++ b/tico/quantization/wrapq/examples/quantize_qwen_vision_mlp.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+
+import torch
+from transformers import AutoModelForVision2Seq
+
+from tico.quantization import convert, prepare
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp import (
+    QuantQwen3VLVisionMLP,
+)
+from tico.utils.utils import SuppressWarning
+
+# -------------------------------------------------------------------------
+# 0. Load a Qwen3-VL model (vision tower)
+# -------------------------------------------------------------------------
+name = "Qwen/Qwen3-VL-2B-Instruct"
+model = AutoModelForVision2Seq.from_pretrained(
+    name,
+    device_map="cpu",
+    trust_remote_code=True,
+    cache_dir="/mnt/storage/transformers_cache",
+)
+model.eval()
+
+# -------------------------------------------------------------------------
+# 1. Replace vision block 0's MLP with QuantQwen3VLVisionMLP
+# -------------------------------------------------------------------------
+orig_mlp = model.model.visual.blocks[0].mlp
+mlp_q = prepare(orig_mlp, PTQConfig())
+mlp_q.eval()
+assert isinstance(mlp_q.wrapped, QuantQwen3VLVisionMLP)
+
+inp_shape = (orig_mlp.intermediate_size, orig_mlp.hidden_size)
+# -------------------------------------------------------------------------
+# 2. Calibration
+# -------------------------------------------------------------------------
+examples = [
+    torch.randn(inp_shape),
+    torch.randn(inp_shape),
+    torch.randn(inp_shape),
+]
+
+with torch.no_grad():
+    for example in examples:
+        _ = mlp_q(example)
+
+convert(mlp_q)
+assert mlp_q._mode is Mode.QUANT, "Quantization mode should be active now."
+
+# -------------------------------------------------------------------------
+# 3. Quick diff check (INT-sim vs FP32)
+# -------------------------------------------------------------------------
+hidden = examples[0]
+
+with torch.no_grad():
+    int8_out = mlp_q(hidden)
+    fp_out = orig_mlp(hidden)
+
+print("┌───────────── Quantization Error Summary ─────────────")
+print(f"│ Mean |diff|: {(int8_out - fp_out).abs().mean().item():.6f}")
+print(f"│ PEIR : {compute_peir(fp_out, int8_out) * 100:.6f} %")
+print("└──────────────────────────────────────────────────────")
+print(plot_two_outputs(fp_out, int8_out))
+
+# -------------------------------------------------------------------------
+# 4. Export the quantized block
+# -------------------------------------------------------------------------
+import tico
+
+save_path = pathlib.Path("qwen3vl_vision_mlp.q.circle")
+
+example = torch.randn(inp_shape)
+
+with SuppressWarning(UserWarning, ".*"):
+    cm = tico.convert(mlp_q, (example,))
+cm.save(save_path)
+
+print(f"Quantized Circle model saved to {save_path.resolve()}")
diff --git a/tico/quantization/wrapq/wrappers/nn/quant_gelutanh.py b/tico/quantization/wrapq/wrappers/nn/quant_gelutanh.py
new file mode 100644
index 00000000..0988e1c2
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/nn/quant_gelutanh.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register("transformers.activations.GELUTanh")
+class QuantGELUTanh(QuantModuleBase):
+    """
+    QuantGELUTanh — drop-in quantized implementation of the tanh-approximated GELU activation.
+
+    The input, the tanh output, and the final product are fake-quantized:
+        t = tanh(sqrt(2/π) * (x + 0.044715 * x^3))    (tanh)
+        y = x * 0.5 * (1 + t)                         (mul)
+
+    GELUTanh formula:
+        GELUTanh(x) = x * 0.5 * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+    """
+
+    def __init__(
+        self,
+        fp: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+        self.obs_act_in = self._make_obs("act_in")
+        self.obs_tanh = self._make_obs("tanh")
+        self.obs_mul = self._make_obs("mul")
+        self.module = fp
+
+    def forward(self, x: torch.Tensor):
+        # Quantize input
+        x_q = self._fq(x, self.obs_act_in)
+
+        # GELUTanh computation: x * 0.5 * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
+        x3 = x_q * x_q * x_q
+        inner = x_q + 0.044715 * x3
+        t = torch.tanh(math.sqrt(2.0 / math.pi) * inner)
+
+        # Quantize tanh output
+        t = self._fq(t, self.obs_tanh)
+
+        y = x_q * 0.5 * (1 + t)
+
+        # Quantize final output
+        y = self._fq(y, self.obs_mul)
+
+        return y
+
+    def _all_observers(self):
+        return (self.obs_act_in, self.obs_tanh, self.obs_mul)
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py
new file mode 100644
index 00000000..19d619e0
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register(
+    "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionMLP",
+)
+class QuantQwen3VLVisionMLP(QuantModuleBase):
+    def __init__(
+        self,
+        mlp_fp: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+        linear_fc1_cfg = qcfg.child("linear_fc1") if qcfg else None
+        linear_fc2_cfg = qcfg.child("linear_fc2") if qcfg else None
+        act_cfg = qcfg.child("act_fn") if qcfg else None
+
+        # ----- wrap the two Linear layers -----------------------------
+        assert hasattr(mlp_fp, "linear_fc1") and isinstance(
+            mlp_fp.linear_fc1, torch.nn.Module
+        )
+        assert hasattr(mlp_fp, "linear_fc2") and isinstance(
+            mlp_fp.linear_fc2, torch.nn.Module
+        )
+
+        self.linear_fc1 = PTQWrapper(
+            mlp_fp.linear_fc1, qcfg=linear_fc1_cfg, fp_name=f"{fp_name}.linear_fc1"
+        )
+        self.linear_fc2 = PTQWrapper(
+            mlp_fp.linear_fc2, qcfg=linear_fc2_cfg, fp_name=f"{fp_name}.linear_fc2"
+        )
+
+        # ----- activation ---------------------------------------------
+        assert hasattr(mlp_fp, "act_fn") and isinstance(mlp_fp.act_fn, torch.nn.Module)
+        self.act_fn = PTQWrapper(
+            mlp_fp.act_fn, qcfg=act_cfg, fp_name=f"{fp_name}.act_fn"
+        )
+
+        # ----- local observers ----------------------------------------
+        self.obs_act_in = self._make_obs("act_in")
+        self.obs_act_out = self._make_obs("act_out")
+
+    def forward(self, hidden_state):
+
+        # self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+        # 1) quantize input once
+        x_q = self._fq(hidden_state, self.obs_act_in)
+
+        # 2) linear_fc1
+        fc1 = self.linear_fc1(x_q)
+
+        # 3) activation on the fc1 output
+        a = self.act_fn(fc1)
+
+        # 4) linear_fc2
+        h = self._fq(self.linear_fc2(a), self.obs_act_out)
+
+        return h
+
+    def _all_observers(self) -> Iterable:
+        yield self.obs_act_in
+        yield self.obs_act_out
+        # recurse into children that are QuantModuleBase
+        for m in (self.linear_fc1, self.linear_fc2, self.act_fn):
+            yield from m._all_observers()
diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py
index e8c15c23..52b6eefa 100644
--- a/tico/quantization/wrapq/wrappers/registry.py
+++ b/tico/quantization/wrapq/wrappers/registry.py
@@ -26,6 +26,7 @@
    ## nn ##
    "tico.quantization.wrapq.wrappers.nn.quant_layernorm",
    "tico.quantization.wrapq.wrappers.nn.quant_linear",
+    "tico.quantization.wrapq.wrappers.nn.quant_gelutanh",
    "tico.quantization.wrapq.wrappers.nn.quant_conv3d",
    # This includes not only `nn.SiLU` but also `SiLUActivation` from transformers
    # as they are same operation.
@@ -43,6 +44,7 @@ "tico.quantization.wrapq.wrappers.fairseq.quant_mha", ## qwen_vl ## "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_attn", + "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp", # add future core wrappers here )