
[quantization] Introduce wrapper for Qwen3VLVisionPatchMerger #493

Open

dvsav wants to merge 1 commit into Samsung:main from dvsav:quant_vision_patch_merger

Conversation


dvsav (Contributor) commented Feb 13, 2026

This change introduces the QuantQwen3VLVisionPatchMerger wrapper to support post-training quantization (PTQ) of the Qwen3VLVisionPatchMerger module.

Why?

The Qwen3VLVisionPatchMerger module is used in the image-encoder part of the Qwen3-VL model.
Attempting to quantize Qwen3VLVisionPatchMerger via PTQ raises the exception "PTQQuantizer: no quantization wrapper for Qwen3VLVisionPatchMerger".

What

This change introduces:

  • Class QuantQwen3VLVisionPatchMerger (tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py).
  • Unit tests: class TestQuantQwen3VLVisionPatchMerger (test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py); skipped if the transformers package is not installed.
  • A new entry, tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_merger, in _CORE_MODULES (tico/quantization/wrapq/wrappers/registry.py).
  • An example of Qwen3VLVisionPatchMerger quantization and conversion to Circle (tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_patch_merger.py); a sketch of this flow is given right after this list.
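
For context, the example script follows the standard PTQ flow: wrap the FP32 module, run representative inputs through it so observers collect activation statistics, then freeze the quantization parameters and convert to Circle. The sketch below only illustrates that flow; the wrapper constructor signature, calibration loop, and conversion call are assumptions for illustration, not the verified wrapq API (see quantize_qwen_vision_patch_merger.py in this PR for the actual calls).

# Hypothetical sketch of the quantization flow; the exact wrapq/tico API
# details (constructor signature, conversion entry point) are assumed.
import torch
import tico
from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_merger import (
    QuantQwen3VLVisionPatchMerger,
)

fp32_merger = ...  # a Qwen3VLVisionPatchMerger taken from a loaded Qwen3-VL model
quant_merger = QuantQwen3VLVisionPatchMerger(fp32_merger)  # signature assumed

# Calibration: feed representative activations so observers record value ranges.
# Shapes are illustrative: 64 patch tokens with a per-patch hidden size of 1152.
calib_input = torch.randn(64, 1152)
quant_merger(calib_input)

# Freeze the collected quantization parameters and convert to Circle (API assumed).
circle_model = tico.convert(quant_merger, (calib_input,))
circle_model.save("patch_merger.q.circle")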

Unit Tests

Unit test results with coverage information:

$ coverage run -m pytest test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py -v
======================================================================================= test session starts ========================================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 10 items                                                                                                                                                                                 

test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_activation_stats_collected PASSED                                        [ 10%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_different_batch_sizes      PASSED                                        [ 20%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_forward_diff               PASSED                                        [ 30%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_mode_transitions           PASSED                                        [ 40%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_module_override            PASSED                                        [ 50%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_observer_count             PASSED                                        [ 60%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_output_shape               PASSED                                        [ 70%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_registration_in_registry   PASSED                                        [ 80%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_subgraph_export            PASSED                                        [ 90%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_patch_merger.py::TestQuantQwen3VLVisionPatchMerger::test_use_postshuffle_norm       PASSED                                        [100%]

================================================================================== 10 passed, 2 warnings in 8.50s ==================================================================================

Coverage info (irrelevant files skipped):

$ coverage report -m
Name                                                                   Stmts   Miss  Cover   Missing
----------------------------------------------------------------------------------------------------
...
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py      45      0   100%
...
----------------------------------------------------------------------------------------------------
TOTAL                                                                   10274   4932    52%

dvsav commented Feb 16, 2026

For Reviewers

Below is the source code of the Qwen3VLVisionPatchMerger module, which can be used to check the correctness of the QuantQwen3VLVisionPatchMerger implementation:

# transformers/models/qwen3_vl/modeling_qwen3_vl.py
# (imports added here for readability; Qwen3VLVisionConfig comes from the
#  accompanying configuration_qwen3_vl.py)
import torch
from torch import nn

class Qwen3VLVisionPatchMerger(nn.Module):
    def __init__(self, config: Qwen3VLVisionConfig, use_postshuffle_norm=False) -> None:
        super().__init__()
        # Merged hidden size: spatial_merge_size**2 patches are fused into one token.
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.use_postshuffle_norm = use_postshuffle_norm
        # LayerNorm over the merged size (post-shuffle) or the per-patch size (default).
        self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, eps=1e-6)
        self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size)
        self.act_fn = nn.GELU()
        self.linear_fc2 = nn.Linear(self.hidden_size, config.out_hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize, then flatten groups of patches into rows of merged hidden size.
        x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size)
        # Two-layer MLP with GELU projects to out_hidden_size.
        x = self.linear_fc2(self.act_fn(self.linear_fc1(x)))
        return x
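
As a quick sanity check on the shapes the quantized wrapper must preserve, the reference module can be exercised stand-alone. The SimpleNamespace below is a hypothetical stand-in for Qwen3VLVisionConfig, with illustrative field values:

import torch
from types import SimpleNamespace

# Hypothetical stand-in for Qwen3VLVisionConfig; field values are illustrative.
config = SimpleNamespace(hidden_size=1152, spatial_merge_size=2, out_hidden_size=2048)

merger = Qwen3VLVisionPatchMerger(config)
x = torch.randn(64, config.hidden_size)  # 64 patch tokens (divisible by 2**2)
out = merger(x)
print(out.shape)  # torch.Size([16, 2048]): 64 / 2**2 merged tokens, out_hidden_size each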

