From 20d536bd87796ad4c6e001f0c37fb8115c805b38 Mon Sep 17 00:00:00 2001
From: "d.savchenkov"
Date: Thu, 12 Feb 2026 12:58:20 +0300
Subject: [PATCH] [quantization] Introduce wrapper for Qwen3VLVisionPatchEmbed

This change introduces the QuantQwen3VLVisionPatchEmbed wrapper to support
post-training quantization of the Qwen3VLVisionPatchEmbed module.

TICO-DCO-1.0-Signed-off-by: d.savchenkov
---
 .../qwen_vl/test_quant_vision_patch_embed.py  | 236 ++++++++++++++++++
 .../qwen/quantize_qwen_vision_patch_embed.py  | 102 ++++++++
 .../qwen_vl/quant_vision_patch_embed.py       | 113 +++++++++
 tico/quantization/wrapq/wrappers/registry.py  |   1 +
 4 files changed, 452 insertions(+)
 create mode 100644 test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_embed.py
 create mode 100755 tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_embed.py
 create mode 100644 tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_embed.py

diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_embed.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_embed.py
new file mode 100644
index 00000000..5e547708
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_embed.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
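+
+# A minimal sketch of the flow exercised by these tests (same names and
+# synthetic input shape as below; `fp_patch_embed` is assumed to be a float
+# Qwen3VLVisionPatchEmbed instance):
+#
+#     q_patch = QuantQwen3VLVisionPatchEmbed(fp_patch_embed)  # Mode.NO_QUANT
+#     q_patch.enable_calibration()                            # Mode.CALIB
+#     _ = q_patch(torch.randn(2, 3, 4, 32, 32))               # collect stats
+#     q_patch.freeze_qparams()                                # Mode.QUANT
+#     q_out = q_patch(torch.randn(2, 3, 4, 32, 32))           # fake-quantized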
+
+import importlib.util
+import unittest
+
+import torch
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.nn.quant_conv3d import QuantConv3d
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed import (
+    QuantQwen3VLVisionPatchEmbed,
+)
+
+
+trans_spec = importlib.util.find_spec("transformers")
+skip_msg = "transformers not installed — skipping Qwen3VLVisionPatchEmbed tests"
+
+
+@unittest.skipUnless(trans_spec, skip_msg)
+class TestQuantQwen3VLVisionPatchEmbed(unittest.TestCase):
+    fp_patch_embed: torch.nn.Module
+    hidden_size: int
+
+    @classmethod
+    def setUpClass(cls):
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+            Qwen3VLVisionConfig,
+        )
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLVisionPatchEmbed,
+        )
+
+        cfg = Qwen3VLVisionConfig(
+            hidden_size=64,  # Smaller for testing
+            spatial_merge_size=2,
+            temporal_patch_size=2,
+            patch_size=16,  # 32x32 test inputs split evenly into 16x16 patches
+        )
+
+        cls.fp_patch_embed = Qwen3VLVisionPatchEmbed(cfg)
+        cls.hidden_size = cfg.hidden_size
+
+    def test_mode_transitions(self):
+        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)
+        self.assertIs(q_patch._mode, Mode.NO_QUANT)
+
+        q_patch.enable_calibration()
+        self.assertIs(q_patch._mode, Mode.CALIB)
+
+        # Run a forward pass during calibration
+        x = torch.randn(2, 3, 4, 32, 32)
+        _ = q_patch(x)
+
+        q_patch.freeze_qparams()
+        self.assertIs(q_patch._mode, Mode.QUANT)
+
+    def test_forward_diff(self):
+        """
+        Test that the quantized output is acceptably close to the FP32 reference.
+        After calibration and freeze, the quantized output should:
+        - Differ from the FP reference (quantization actually applied)
+        - Stay within reasonable error bounds
+        """
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)
+        q_patch.enable_calibration()
+
+        # Calibrate with multiple inputs
+        for _ in range(4):
+            x = torch.randn(2, 3, 4, 32, 32)
+            _ = q_patch(x)
+
+        q_patch.freeze_qparams()
+
+        x = torch.randn(2, 3, 4, 32, 32)
+        with torch.no_grad():
+            q_out = q_patch(x)
+            fp_out = self.fp_patch_embed(x)
+
+        diff = (fp_out - q_out).abs().mean().item()
+        self.assertGreater(diff, 0.0)  # not identical
+        self.assertLess(diff, 0.4)  # acceptably close
+        self.assertEqual(fp_out.shape, q_out.shape)
+
+    def test_proj_override(self):
+        """
+        PTQConfig overrides should propagate to the wrapped Conv3d layer.
+        """
+        cfg = PTQConfig(
+            default_dtype=DType.uint(8),
+            overrides={
+                "proj": {
+                    "weight": {"dtype": DType.uint(4)},
+                    "act_in": {"dtype": DType.uint(4)},
+                    "act_out": {"dtype": DType.uint(4)},
+                }
+            },
+        )
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed, qcfg=cfg)
+        q_conv3d = q_patch.proj.wrapped
+
+        self.assertIsInstance(q_conv3d, QuantConv3d)
+        self.assertEqual(q_conv3d.obs_weight.dtype, DType.uint(4))
+        self.assertEqual(q_conv3d.obs_act_in.dtype, DType.uint(4))
+        self.assertEqual(q_conv3d.obs_act_out.dtype, DType.uint(4))
+
+    def test_activation_stats_collected(self):
+        """
+        Test that activation statistics are properly collected during calibration.
+        Both local observers and wrapped Conv3d observers should collect stats.
+        """
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)
+        q_patch.enable_calibration()
+
+        # Run forward pass to collect stats
+        x = torch.randn(2, 3, 4, 32, 32)
+        _ = q_patch(x)
+
+        # Check that local observers have collected stats
+        self.assertTrue(q_patch.obs_hidden.min_val.numel() > 0)
+        self.assertTrue(q_patch.obs_output.min_val.numel() > 0)
+
+        # Check that wrapped Conv3d observers have collected stats
+        q_conv3d = q_patch.proj.wrapped
+        self.assertTrue(q_conv3d.obs_act_in.min_val.numel() > 0)
+        self.assertTrue(q_conv3d.obs_act_out.min_val.numel() > 0)
+        self.assertTrue(q_conv3d.obs_weight.min_val.numel() > 0)
+
+        # Freeze and check qparams exist
+        q_patch.freeze_qparams()
+        self.assertTrue(q_patch.obs_hidden.has_qparams)
+        self.assertTrue(q_patch.obs_output.has_qparams)
+        self.assertTrue(q_conv3d.obs_act_in.has_qparams)
+        self.assertTrue(q_conv3d.obs_act_out.has_qparams)
+        self.assertTrue(q_conv3d.obs_weight.has_qparams)
+
+    def test_observer_count(self):
+        """
+        Test that the wrapper has the correct number of observers:
+        - 2 local observers (obs_hidden, obs_output)
+        - 3 observers from the wrapped Conv3d (obs_weight, obs_act_in, obs_act_out)
+        """
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)
+
+        observers = list(q_patch._all_observers())
+        self.assertEqual(len(observers), 5)  # 2 local + 3 from Conv3d
+
+    def test_registration_in_registry(self):
+        """
+        Test that Qwen3VLVisionPatchEmbed is properly registered in the wrapper registry.
+        """
+        from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed import (
+            QuantQwen3VLVisionPatchEmbed,
+        )
+        from tico.quantization.wrapq.wrappers.registry import lookup
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLVisionPatchEmbed,
+        )
+
+        # Verify Qwen3VLVisionPatchEmbed maps to QuantQwen3VLVisionPatchEmbed
+        wrapper_cls = lookup(Qwen3VLVisionPatchEmbed)
+        self.assertIs(wrapper_cls, QuantQwen3VLVisionPatchEmbed)
+
+    def test_output_shape(self):
+        """Test that output shape is correct after patch embedding."""
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)
+        q_patch.enable_calibration()
+
+        x = torch.randn(2, 3, 4, 32, 32)
+        _ = q_patch(x)
+
+        q_patch.freeze_qparams()
+
+        with torch.no_grad():
+            q_out = q_patch(x)
+            fp_out = self.fp_patch_embed(x)
+
+        self.assertEqual(q_out.shape, fp_out.shape)
+
+    def test_multiple_calibration_steps(self):
+        """
+        Test that running multiple calibration iterations works correctly.
+        Statistics should be accumulated across multiple forward passes.
+        """
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)
+        q_patch.enable_calibration()
+
+        # Run multiple calibration steps
+        for _ in range(5):
+            x = torch.randn(2, 3, 4, 32, 32)
+            _ = q_patch(x)
+
+        q_patch.freeze_qparams()
+
+        # Verify that all observers have quantization parameters
+        self.assertTrue(q_patch.obs_hidden.has_qparams)
+        self.assertTrue(q_patch.obs_output.has_qparams)
+        self.assertTrue(q_patch.proj.wrapped.obs_act_in.has_qparams)
+        self.assertTrue(q_patch.proj.wrapped.obs_act_out.has_qparams)
+        self.assertTrue(q_patch.proj.wrapped.obs_weight.has_qparams)
+
+    def test_different_batch_sizes(self):
+        """
+        Test that quantization works correctly with different batch sizes.
+        """
+        q_patch = QuantQwen3VLVisionPatchEmbed(self.fp_patch_embed)
+        q_patch.enable_calibration()
+
+        # Calibrate with one batch size
+        calibrate_batch = torch.randn(2, 3, 4, 32, 32)
+        for _ in range(3):
+            _ = q_patch(calibrate_batch)
+        q_patch.freeze_qparams()
+
+        # Test with different batch sizes
+        for batch_size in [1, 2, 4]:
+            x = torch.randn(batch_size, 3, 4, 32, 32)
+            with torch.no_grad():
+                q_out = q_patch(x)
+                fp_out = self.fp_patch_embed(x)
+
+            self.assertEqual(q_out.shape, fp_out.shape)
+            diff = (fp_out - q_out).abs().mean().item()
+            self.assertLess(diff, 0.4)
diff --git a/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_embed.py b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_embed.py
new file mode 100755
index 00000000..5edd106c
--- /dev/null
+++ b/tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_embed.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import importlib.util
+import sys
+
+import torch
+
+import tico
+import tico.quantization
+import tico.quantization.config.ptq
+
+# Check if transformers is available
+trans_spec = importlib.util.find_spec("transformers")
+if trans_spec is None:
+    print(
+        "Error: transformers package not installed. Cannot run the Qwen3VLVisionPatchEmbed example."
+    )
+    sys.exit(1)
+
+from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
+from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionPatchEmbed
+
+
+def generate_calibration_data(num_batches: int, sample_shape) -> list:
+    """Generate random calibration batches for PTQ."""
+    calibration_data = []
+    for _ in range(num_batches):
+        x = torch.randn(sample_shape)
+        calibration_data.append(x)
+    return calibration_data
+
+
+def main():
+    # Create the vision patch embed model
+    cfg = Qwen3VLVisionConfig(
+        in_channels=3,
+        hidden_size=1024,
+        temporal_patch_size=2,
+        patch_size=16,
+    )
+    model = Qwen3VLVisionPatchEmbed(cfg)
+    model.eval()
+
+    # Qwen3VLVisionPatchEmbed(
+    #   (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
+    # )
+    assert model.proj.in_channels == 3
+    assert model.proj.out_channels == 1024
+    assert model.proj.kernel_size == (2, 16, 16)
+    assert model.proj.stride == (2, 16, 16)
+
+    # Generate calibration data
+    # Input shape: (batch_size, in_channels, depth, height, width)
+    # Example: (2, 3, 8, 224, 224) - 2 videos, RGB, 8 frames, 224x224 resolution
+    calibration_data = generate_calibration_data(
+        num_batches=20, sample_shape=(2, 3, 8, 224, 224)
+    )
+
+    # Configure PTQ
+    ptq_config = tico.quantization.config.ptq.PTQConfig()
+
+    # Prepare the model for quantization
+    prepared_model = tico.quantization.prepare(
+        model, ptq_config, inplace=True  # Transform the model in place
+    )
+
+    # Calibrate the model (collect statistics)
+    with torch.no_grad():
+        for batch in calibration_data:
+            prepared_model(batch)
+
+    # Convert to quantized model
+    quantized_model = tico.quantization.convert(prepared_model, inplace=True)
+
+    # Convert to Circle format
+    # example_inputs shape: (batch_size, in_channels, depth, height, width)
+    example_inputs = (torch.randn(2, 3, 8, 224, 224),)
+    circle_model = tico.convert(quantized_model, example_inputs)
+
+    # Save the Circle model
+    filename = "quantized_vision_patch_embed.circle"
+    circle_model.save(filename)
+    print(f"Circle model saved as '{filename}'")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_embed.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_embed.py
new file mode 100644
index 00000000..88e642e2
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_embed.py
@@ -0,0 +1,113 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
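+
+"""
+Dataflow sketch (a summary of ``QuantQwen3VLVisionPatchEmbed.forward`` below,
+where num_patches = numel(input) / (C * temporal_patch_size * patch_size**2)):
+
+    input (B, C, D, H, W)
+      -> fake-quant via obs_hidden
+      -> view(-1, C, temporal_patch_size, patch_size, patch_size)
+      -> proj: QuantConv3d (obs_weight / obs_act_in / obs_act_out)
+      -> fake-quant via obs_output
+      -> view(-1, embed_dim)  # (num_patches, embed_dim)
+"""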
+
+from typing import Iterable, Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register(
+    "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionPatchEmbed",
+)
+class QuantQwen3VLVisionPatchEmbed(QuantModuleBase):
+    """
+    Quantization wrapper for the Qwen3VLVisionPatchEmbed module.
+
+    This module wraps the Conv3d patch embedding layer that converts raw video
+    frames into patch embeddings for the vision transformer.
+    """
+
+    def __init__(
+        self,
+        fp_patch_embed: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        self.patch_size = fp_patch_embed.patch_size
+        self.temporal_patch_size = fp_patch_embed.temporal_patch_size
+        self.in_channels = fp_patch_embed.in_channels
+        self.embed_dim = fp_patch_embed.embed_dim
+
+        assert hasattr(fp_patch_embed, "proj") and isinstance(
+            fp_patch_embed.proj, nn.Conv3d
+        ), "Qwen3VLVisionPatchEmbed must expose a Conv3d `proj` layer"
+
+        # Wrap the Conv3d projection layer via PTQWrapper.
+        # This will use the QuantConv3d wrapper (registered in the registry).
+        proj_cfg = qcfg.child("proj") if qcfg else None
+        self.proj = PTQWrapper(
+            fp_patch_embed.proj,
+            qcfg=proj_cfg,
+            fp_name=f"{fp_name}.proj" if fp_name else "proj",
+        )
+
+        # Observer for the input activation (raw video frames)
+        self.obs_hidden = self._make_obs("hidden")
+
+        # Observer for the output activation (patch embeddings)
+        self.obs_output = self._make_obs("output")
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass with fake quantization.
+
+        Args:
+            hidden_states: Input tensor of shape
+                (batch_size, channels, depth, height, width) holding raw video
+                frames (RGB: channels=3).
+
+        Returns:
+            Flattened 2D tensor of patch embeddings with shape
+            (num_patches, embed_dim), where num_patches = batch_size * T' * H' * W'.
+        """
+        # Quantize input activation
+        hidden = self._fq(hidden_states, self.obs_hidden)
+
+        # Reshape input to (num_patches, C, temporal_patch_size, patch_size, patch_size).
+        # This flattens the batch and spatio-temporal dimensions into a single
+        # patch dimension.
+        hidden = hidden.view(
+            -1,
+            self.in_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )
+
+        # Apply Conv3d patch embedding (quantized via PTQWrapper)
+        # Output: (num_patches, embed_dim, 1, 1, 1)
+        hidden = self.proj(hidden)
+
+        # Quantize intermediate output (after Conv3d, before reshape)
+        hidden = self._fq(hidden, self.obs_output)
+
+        # Reshape output to (num_patches, embed_dim)
+        hidden = hidden.view(-1, self.embed_dim)
+
+        return hidden
+
+    def _all_observers(self) -> Iterable:
+        """Yield all observers from this module and wrapped submodules."""
+        # Local observers
+        yield from (self.obs_hidden, self.obs_output)
+
+        # Observers from the wrapped Conv3d layer
+        yield from self.proj.wrapped._all_observers()
diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py
index 6a0c2b83..8eb896bd 100644
--- a/tico/quantization/wrapq/wrappers/registry.py
+++ b/tico/quantization/wrapq/wrappers/registry.py
@@ -44,6 +44,7 @@
     ## qwen_vl ##
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_attn",
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp",
+    "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed",
     # add future core wrappers here
 )
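
Shape walk-through for the example script (a sketch using its config values:
in_channels=3, temporal_patch_size=2, patch_size=16, hidden_size=1024):

    x = torch.randn(2, 3, 8, 224, 224)  # 2 videos, RGB, 8 frames, 224x224
    # view(-1, 3, 2, 16, 16): 2*3*8*224*224 / (3*2*16*16) = 1568 patches
    # proj Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
    #     -> (1568, 1024, 1, 1, 1)
    # view(-1, 1024) -> (1568, 1024) patch embeddings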