From 5ecee3d54d8754adb4ac6442ebc0c81eca7a9461 Mon Sep 17 00:00:00 2001
From: "d.savchenkov"
Date: Fri, 13 Feb 2026 15:11:03 +0300
Subject: [PATCH] [quantization] Introduce wrapper for Qwen3VLVisionPatchMerger

This change introduces the QuantQwen3VLVisionPatchMerger wrapper to support
post-training quantization of the Qwen3VLVisionPatchMerger module.

TICO-DCO-1.0-Signed-off-by: d.savchenkov
---
 .../qwen_vl/test_quant_vision_patch_merger.py | 314 ++++++++++++++++++
 .../qwen/quantize_qwen_vision_patch_merger.py | 115 +++++++
 .../qwen_vl/quant_vision_patch_merger.py      | 143 ++++++++
 tico/quantization/wrapq/wrappers/registry.py  |   1 +
 4 files changed, 573 insertions(+)
 create mode 100644 test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py
 create mode 100644 tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_merger.py
 create mode 100644 tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py

diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py
new file mode 100644
index 00000000..79d6b979
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py
@@ -0,0 +1,314 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
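+
+# These unit tests exercise the QuantQwen3VLVisionPatchMerger wrapper:
+# the calibration state machine (NO_QUANT -> CALIB -> QUANT), per-submodule
+# PTQConfig overrides, observer bookkeeping, output shapes for both norm
+# placements, and export to the Circle format.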
+
+import importlib.util
+import pathlib
+import tempfile
+import unittest
+import warnings
+
+import tico
+
+import torch
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.nn.quant_layernorm import QuantLayerNorm
+from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_merger import (
+    QuantQwen3VLVisionPatchMerger,
+)
+
+
+trans_spec = importlib.util.find_spec("transformers")
+skip_msg = "transformers not installed — skipping Qwen3VLVisionPatchMerger tests"
+
+
+@unittest.skipUnless(trans_spec, skip_msg)
+class TestQuantQwen3VLVisionPatchMerger(unittest.TestCase):
+    fp_merger: torch.nn.Module
+    hidden_size: int
+    out_hidden_size: int
+    spatial_merge_size: int
+
+    @classmethod
+    def setUpClass(cls):
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+            Qwen3VLVisionConfig,
+        )
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLVisionPatchMerger,
+        )
+
+        # Use smaller sizes for testing
+        cfg = Qwen3VLVisionConfig(
+            hidden_size=64,
+            spatial_merge_size=2,
+            out_hidden_size=32,
+        )
+
+        cls.fp_merger = Qwen3VLVisionPatchMerger(cfg, use_postshuffle_norm=False)
+        cls.hidden_size = cfg.hidden_size
+        cls.out_hidden_size = cfg.out_hidden_size
+        cls.spatial_merge_size = cfg.spatial_merge_size
+
+    def test_mode_transitions(self):
+        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger)
+        self.assertIs(q_merger._mode, Mode.NO_QUANT)
+
+        q_merger.enable_calibration()
+        self.assertIs(q_merger._mode, Mode.CALIB)
+
+        # Run forward pass during calibration
+        x = torch.randn(32, self.hidden_size)
+        _ = q_merger(x)
+
+        q_merger.freeze_qparams()
+        self.assertIs(q_merger._mode, Mode.QUANT)
+
+    def test_forward_diff(self):
+        """
+        Test that quantized output is acceptably close to the FP32 reference.
+        """
+        torch.manual_seed(42)
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger)
+        q_merger.enable_calibration()
+
+        # Calibrate with multiple inputs
+        for _ in range(4):
+            x = torch.randn(32, self.hidden_size)
+            _ = q_merger(x)
+
+        q_merger.freeze_qparams()
+
+        x = torch.randn(32, self.hidden_size)
+        with torch.no_grad():
+            q_out = q_merger(x)
+            fp_out = self.fp_merger(x)
+
+        self.assertEqual(fp_out.shape, q_out.shape)
+        diff = (fp_out - q_out).abs().mean().item()
+        self.assertGreater(diff, 0.0)  # not identical
+        self.assertLess(diff, 0.7)  # acceptably close
+
+    def test_module_override(self):
+        """
+        PTQConfig overrides should propagate to wrapped submodules.
+        """
+        cfg = PTQConfig(
+            default_dtype=DType.uint(8),
+            overrides={
+                "linear_fc1": {
+                    "weight": {"dtype": DType.uint(4)},
+                    "act_in": {"dtype": DType.uint(4)},
+                    "act_out": {"dtype": DType.uint(4)},
+                },
+                "linear_fc2": {
+                    "weight": {"dtype": DType.uint(4)},
+                },
+                "act_fn": {
+                    "act_in": {"dtype": DType.uint(4)},
+                    "act_out": {"dtype": DType.uint(4)},
+                },
+            },
+        )
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger, qcfg=cfg)
+
+        # Check linear_fc1
+        q_fc1 = q_merger.linear_fc1.wrapped
+        self.assertIsInstance(q_fc1, QuantLinear)
+        self.assertEqual(q_fc1.obs_weight.dtype, DType.uint(4))
+        self.assertEqual(q_fc1.obs_act_in.dtype, DType.uint(4))
+        self.assertEqual(q_fc1.obs_act_out.dtype, DType.uint(4))
+
+        # Check linear_fc2
+        q_fc2 = q_merger.linear_fc2.wrapped
+        self.assertIsInstance(q_fc2, QuantLinear)
+        self.assertEqual(q_fc2.obs_weight.dtype, DType.uint(4))
+
+        # Check act_fn (QuantGELU via QuantElementwise)
+        q_act = q_merger.act_fn.wrapped
+        self.assertEqual(q_act.act_in_obs.dtype, DType.uint(4))
+        self.assertEqual(q_act.act_out_obs.dtype, DType.uint(4))
+
+    def test_activation_stats_collected(self):
+        """
+        Test that activation statistics are properly collected during calibration.
+        """
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger)
+        q_merger.enable_calibration()
+
+        # Run forward pass to collect stats
+        x = torch.randn(32, self.hidden_size)
+        _ = q_merger(x)
+
+        # Check that local observers have collected stats
+        self.assertTrue(q_merger.obs_hidden.min_val.numel() > 0)
+        self.assertTrue(q_merger.obs_post_norm.min_val.numel() > 0)
+        self.assertTrue(q_merger.obs_fc1.min_val.numel() > 0)
+        self.assertTrue(q_merger.obs_act.min_val.numel() > 0)
+        self.assertTrue(q_merger.obs_output.min_val.numel() > 0)
+
+        # Freeze and check qparams exist
+        q_merger.freeze_qparams()
+        self.assertTrue(q_merger.obs_hidden.has_qparams)
+        self.assertTrue(q_merger.obs_post_norm.has_qparams)
+        self.assertTrue(q_merger.obs_fc1.has_qparams)
+        self.assertTrue(q_merger.obs_act.has_qparams)
+        self.assertTrue(q_merger.obs_output.has_qparams)
+
+    def test_observer_count(self):
+        """
+        Test that the wrapper has the correct number of observers:
+        - 5 local observers
+        - observers from norm (QuantLayerNorm)
+        - observers from linear_fc1 (QuantLinear)
+        - observers from act_fn (QuantElementwise)
+        - observers from linear_fc2 (QuantLinear)
+        Total: more than 5.
+        """
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger)
+
+        observers = list(q_merger._all_observers())
+        self.assertGreater(len(observers), 5)
+
+    def test_registration_in_registry(self):
+        """
+        Test that Qwen3VLVisionPatchMerger is properly registered.
+        """
+        from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_merger import (
+            QuantQwen3VLVisionPatchMerger,
+        )
+        from tico.quantization.wrapq.wrappers.registry import lookup
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLVisionPatchMerger,
+        )
+
+        wrapper_cls = lookup(Qwen3VLVisionPatchMerger)
+        self.assertIs(wrapper_cls, QuantQwen3VLVisionPatchMerger)
+
+    def test_output_shape(self):
+        """
+        Test that output shape is correct.
+        Input: (N, hidden_size)
+        Output: (N // spatial_merge_size**2, out_hidden_size) = (N/4, 32)
+        """
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger)
+        q_merger.enable_calibration()
+
+        num_patches = 32
+        x = torch.randn(num_patches, self.hidden_size)
+        _ = q_merger(x)
+
+        q_merger.freeze_qparams()
+
+        with torch.no_grad():
+            q_out = q_merger(x)
+            fp_out = self.fp_merger(x)
+
+        expected_shape = (
+            num_patches // self.spatial_merge_size**2,
+            self.out_hidden_size,
+        )
+        self.assertEqual(q_out.shape, expected_shape)
+        self.assertEqual(fp_out.shape, expected_shape)
+
+    def test_use_postshuffle_norm(self):
+        """
+        Test with the use_postshuffle_norm=True flag.
+        """
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+            Qwen3VLVisionConfig,
+        )
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLVisionPatchMerger,
+        )
+
+        cfg = Qwen3VLVisionConfig(
+            hidden_size=64, spatial_merge_size=2, out_hidden_size=32
+        )
+
+        fp_merger = Qwen3VLVisionPatchMerger(cfg, use_postshuffle_norm=True)
+        q_merger = QuantQwen3VLVisionPatchMerger(fp_merger)
+        self.assertEqual(q_merger.hidden_size, fp_merger.hidden_size)
+
+        num_patches = 32
+
+        q_merger.enable_calibration()
+        x = torch.randn(num_patches, fp_merger.hidden_size)
+        _ = q_merger(x)
+        q_merger.freeze_qparams()
+
+        with torch.no_grad():
+            q_out = q_merger(x)
+            fp_out = fp_merger(x)
+
+        expected_shape = (num_patches, cfg.out_hidden_size)
+        self.assertEqual(fp_out.shape, expected_shape)
+        self.assertEqual(q_out.shape, expected_shape)
+        diff = (fp_out - q_out).abs().mean().item()
+        self.assertLess(diff, 0.7)
+
+    def test_different_batch_sizes(self):
+        """
+        Test that quantization works correctly with different input sizes
+        (numbers of patches).
+        """
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger)
+        q_merger.enable_calibration()
+
+        # Calibrate with one size
+        calibrate_batch = torch.randn(32, self.hidden_size)
+        for _ in range(3):
+            _ = q_merger(calibrate_batch)
+        q_merger.freeze_qparams()
+
+        # Test with different sizes
+        for num_patches in [16, 32, 64]:
+            x = torch.randn(num_patches, self.hidden_size)
+            with torch.no_grad():
+                q_out = q_merger(x)
+                fp_out = self.fp_merger(x)
+
+            expected_shape = (
+                num_patches // self.spatial_merge_size**2,
+                self.out_hidden_size,
+            )
+            self.assertEqual(q_out.shape, expected_shape)
+            self.assertEqual(fp_out.shape, expected_shape)
+            diff = (fp_out - q_out).abs().mean().item()
+            self.assertLess(diff, 0.7)
+
+    def test_subgraph_export(self):
+        """
+        Test that the quantized merger can be exported to Circle format.
+        """
+        q_merger = QuantQwen3VLVisionPatchMerger(self.fp_merger).eval()
+        x = torch.randn(16, self.hidden_size)
+
+        # Calibrate and freeze
+        q_merger.enable_calibration()
+        _ = q_merger(x)
+        q_merger.freeze_qparams()
+
+        self.assertIs(q_merger._mode, Mode.QUANT)
+
+        # Export to Circle
+        with tempfile.TemporaryDirectory() as td:
+            path = pathlib.Path(td) / "patch_merger.circle"
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning)
+                exported = tico.convert(q_merger, (x,))
+                exported.save(path)
+            self.assertTrue(path.exists())
diff --git a/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_merger.py b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_merger.py
new file mode 100644
index 00000000..2d818341
--- /dev/null
+++ b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_merger.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+import sys
+
+import torch
+import torch.nn as nn
+
+import tico
+import tico.quantization
+import tico.quantization.config.ptq
+
+# Check if transformers is available
+trans_spec = importlib.util.find_spec("transformers")
+if trans_spec is None:
+    print(
+        "Error: transformers package not installed. Cannot run the Qwen3VLVisionPatchMerger example."
+    )
+    sys.exit(1)
+
+from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
+from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchMerger
+
+
+def generate_calibration_data(
+    batch_size: int, num_patches: int, hidden_size: int
+) -> list:
+    """Generate calibration data for PTQ"""
+    calibration_data = []
+    for _ in range(batch_size):
+        x = torch.randn(num_patches, hidden_size)
+        calibration_data.append(x)
+    return calibration_data
+
+
+def main():
+    # Create the vision patch merger model
+    cfg = Qwen3VLVisionConfig(
+        hidden_size=1024,
+        spatial_merge_size=2,
+        out_hidden_size=2048,
+    )
+    model = Qwen3VLVisionPatchMerger(cfg, use_postshuffle_norm=False)
+    model.eval()
+
+    # Qwen3VLVisionPatchMerger(
+    #   (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
+    #   (linear_fc1): Linear(in_features=4096, out_features=4096, bias=True)
+    #   (act_fn): GELU(approximate='none')
+    #   (linear_fc2): Linear(in_features=4096, out_features=2048, bias=True)
+    # )
+    assert (
+        model.hidden_size == 4096
+    )  # cfg.hidden_size * (cfg.spatial_merge_size**2) = 1024 * 2**2
+    assert model.linear_fc1.in_features == 4096
+    assert model.linear_fc1.out_features == 4096
+    assert model.linear_fc2.in_features == 4096
+    assert model.linear_fc2.out_features == 2048
+
+    # Generate calibration data
+    # Input shape: (num_patches, hidden_size)
+    # Example: input.shape=(num_patches=32, hidden_size=1024)
+    # num_patches=32 can come from e.g. two 8-frame 32x32 RGB videos after they
+    # are embedded by Qwen3VLVisionPatchEmbed (Conv3d):
+    # (Batch, Channels, Time, Height, Width) = (2, 3, 8, 32, 32)
+    #   --> Qwen3VLVisionPatchEmbed --> (num_patches, hidden_size) = (2*4*4, 1024),
+    # where 2*4*4 means (2 videos) times (4 spatial patches) times (4 temporal patches).
+    # 4 spatial patches come from a 32x32 frame with stride 16: 32/16 * 32/16 = 2*2 = 4.
+    # 4 temporal patches come from 8 frames with stride 2: 8 / 2 = 4.
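+    # Quick sanity check of the arithmetic above. The video dimensions and
+    # strides here are purely illustrative assumptions for this example, not
+    # values read from a real preprocessing config:
+    # 2 videos * (8 // 2) temporal * (32 // 16) * (32 // 16) spatial = 32
+    assert 2 * (8 // 2) * (32 // 16) * (32 // 16) == 32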
+    num_patches = 32
+    hidden_size = cfg.hidden_size
+    calibration_data = generate_calibration_data(
+        batch_size=20, num_patches=num_patches, hidden_size=hidden_size
+    )
+
+    # Configure PTQ
+    ptq_config = tico.quantization.config.ptq.PTQConfig()
+
+    # Prepare the model for quantization
+    prepared_model = tico.quantization.prepare(
+        model, ptq_config, inplace=True  # Transform the model in place
+    )
+
+    # Calibrate the model (collect statistics)
+    with torch.no_grad():
+        for batch in calibration_data:
+            prepared_model(batch)
+
+    # Convert to quantized model
+    quantized_model = tico.quantization.convert(prepared_model, inplace=True)
+
+    # Convert to Circle format
+    # example_inputs shape: (num_patches, hidden_size)
+    example_inputs = (torch.randn(num_patches, hidden_size),)
+    circle_model = tico.convert(quantized_model, example_inputs)
+
+    # Save the Circle model
+    filename = "quantized_vision_patch_merger.circle"
+    circle_model.save(filename)
+    print(f"Circle model saved as '{filename}'")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py
new file mode 100644
index 00000000..ebfb25cc
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register(
+    "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionPatchMerger",
+)
+class QuantQwen3VLVisionPatchMerger(QuantModuleBase):
+    """
+    Quantization wrapper for the Qwen3VLVisionPatchMerger module.
+
+    This module wraps the patch merger that transforms vision features to
+    language model input dimensions through a 2-layer MLP structure.
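+
+    Quantized data flow (one observer per tensor boundary; see forward()):
+        x (obs_hidden) -> norm/reshape (obs_post_norm) -> linear_fc1 (obs_fc1)
+        -> act_fn (obs_act) -> linear_fc2 (obs_output)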
+ """ + + def __init__( + self, + fp_merger: nn.Module, + *, + qcfg: Optional[PTQConfig] = None, + fp_name: Optional[str] = None, + ): + super().__init__(qcfg, fp_name=fp_name) + + self.hidden_size = fp_merger.hidden_size + self.use_postshuffle_norm = fp_merger.use_postshuffle_norm + + assert hasattr(fp_merger, "norm") and isinstance(fp_merger.norm, nn.LayerNorm) + assert hasattr(fp_merger, "linear_fc1") and isinstance( + fp_merger.linear_fc1, nn.Linear + ) + assert hasattr(fp_merger, "linear_fc2") and isinstance( + fp_merger.linear_fc2, nn.Linear + ) + assert hasattr(fp_merger, "act_fn") and isinstance(fp_merger.act_fn, nn.GELU) + + # --- Wrap submodules via PTQWrapper ---------------------------------- + norm_cfg = qcfg.child("norm") if qcfg else None + fc1_cfg = qcfg.child("linear_fc1") if qcfg else None + fc2_cfg = qcfg.child("linear_fc2") if qcfg else None + act_cfg = qcfg.child("act_fn") if qcfg else None + + self.norm = PTQWrapper( + fp_merger.norm, + qcfg=norm_cfg, + fp_name=f"{fp_name}.norm", + ) + + self.linear_fc1 = PTQWrapper( + fp_merger.linear_fc1, + qcfg=fc1_cfg, + fp_name=f"{fp_name}.linear_fc1", + ) + + self.act_fn = PTQWrapper( + fp_merger.act_fn, + qcfg=act_cfg, + fp_name=f"{fp_name}.act_fn", + ) + + self.linear_fc2 = PTQWrapper( + fp_merger.linear_fc2, + qcfg=fc2_cfg, + fp_name=f"{fp_name}.linear_fc2", + ) + + # --- Observers ------------------------------------------------------ + mk = self._make_obs + self.obs_hidden = mk("hidden") + self.obs_post_norm = mk("post_norm") + self.obs_fc1 = mk("fc1") + self.obs_act = mk("act") + self.obs_output = mk("output") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with fake quantization. + + Args: + x: Input tensor of shape (num_patches, hidden_size) + + Returns: + Transformed features of shape (num_patches, out_hidden_size) + """ + # Quantize input activation + x = self._fq(x, self.obs_hidden) + + # Apply LayerNorm (with optional reshape based on use_postshuffle_norm) + if self.use_postshuffle_norm: + # Reshape to (N, hidden_size) before norm + x = self.norm(x.view(-1, self.hidden_size)).view(-1, self.hidden_size) + else: + x = x.view(-1, self.hidden_size) + + # Quantize post-norm activation + x = self._fq(x, self.obs_post_norm) + + # Apply first linear layer + x = self._fq(self.linear_fc1(x), self.obs_fc1) + + # Apply GELU activation + x = self._fq(self.act_fn(x), self.obs_act) + + # Apply second linear layer (projection to language model dimension) + x = self._fq(self.linear_fc2(x), self.obs_output) + + return x + + def _all_observers(self) -> Iterable: + """Yield all observers from this module and wrapped submodules.""" + # Local observers + yield from ( + self.obs_hidden, + self.obs_post_norm, + self.obs_fc1, + self.obs_act, + self.obs_output, + ) + + # Observers from wrapped submodules + for module in (self.norm, self.linear_fc1, self.act_fn, self.linear_fc2): + yield from module.wrapped._all_observers() diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index 8eb896bd..f67cc0ba 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py @@ -45,6 +45,7 @@ "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_attn", "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp", "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed", + "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_merger", # add future core wrappers here )