-
Notifications
You must be signed in to change notification settings - Fork 24
[DRAFT][quantization] Introduce Qwen3VLVisionMLP wrapper #484
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
ec90278
ssfdf
stamalakhov 4ac91ce
[quantization] Introduce wrapper for GELUTanh
594fe9d
[quantization] Introduce Qwen3VLVisionMLP wrapper
stamalakhov 268adb1
Merge branch 'main' into vision_mlp
stamalakhov 4173a26
Apply suggestions from code review
stamalakhov File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
146 changes: 146 additions & 0 deletions
146
test/quantization/wrapq/wrappers/nn/test_quant_gelutanh.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,146 @@ | ||
| # Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import importlib.util | ||
| import unittest | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from tico.quantization.config.ptq import PTQConfig | ||
| from tico.quantization.wrapq.dtypes import DType | ||
| from tico.quantization.wrapq.mode import Mode | ||
| from tico.quantization.wrapq.wrappers.nn.quant_gelutanh import QuantGELUTanh | ||
|
|
||
| trans_spec = importlib.util.find_spec("transformers") | ||
| skip_msg = "transformers not installed — skipping TestQuantGELUTanh tests" | ||
|
|
||
|
|
||
| @unittest.skipUnless(trans_spec, skip_msg) | ||
| class TestQuantGELUTanh(unittest.TestCase): | ||
| gelutanh: type | ||
|
|
||
| @classmethod | ||
| def setUpClass(cls): | ||
| import transformers | ||
|
|
||
| cls.gelutanh = transformers.activations.GELUTanh | ||
|
|
||
| def setUp(self): | ||
| torch.manual_seed(0) | ||
| self.x = torch.randn(128, 4) * 3 # wider than N(0,1) for better tanh coverage | ||
| self.fp_gelu_tanh = self.gelutanh() | ||
| self.qgelu_tanh = QuantGELUTanh(self.fp_gelu_tanh) # default uint8 | ||
|
|
||
| def test_mode_transitions(self): | ||
| """Test quantization mode transitions: NO_QUANT → CALIB → QUANT""" | ||
| self.assertIs(self.qgelu_tanh._mode, Mode.NO_QUANT) | ||
| self.qgelu_tanh.enable_calibration() | ||
| self.assertIs(self.qgelu_tanh._mode, Mode.CALIB) | ||
| _ = self.qgelu_tanh(self.x) # collect stats | ||
| self.qgelu_tanh.freeze_qparams() | ||
| self.assertIs(self.qgelu_tanh._mode, Mode.QUANT) | ||
|
|
||
| def test_quantised_output(self): | ||
| """ | ||
| Test that quantized output is acceptably close to FP32 reference. | ||
| After calibration and freeze, quantized output should: | ||
| - Differ from FP reference (quantization actually applied) | ||
| - Stay within reasonable error bounds | ||
| """ | ||
| self.qgelu_tanh.enable_calibration() | ||
| _ = self.qgelu_tanh(self.x) | ||
| self.qgelu_tanh.freeze_qparams() | ||
|
|
||
| with torch.no_grad(): | ||
| q_out = self.qgelu_tanh(self.x) | ||
| fp_out = self.gelutanh()(self.x) | ||
|
|
||
| diff = (q_out - fp_out).abs().mean().item() | ||
| self.assertGreater(diff, 0.0) # not identical (quantization applied) | ||
| self.assertLess(diff, 0.3) # acceptably close (same tolerance as SiLU) | ||
|
|
||
| def test_dtype_override(self): | ||
| """ | ||
| PTQConfig overrides should propagate to observers created by QuantGELUTanh. | ||
| Test that different dtypes can be applied to intermediate activations. | ||
| """ | ||
| cfg = PTQConfig( | ||
| default_dtype=DType.uint(8), | ||
| overrides={ | ||
| "tanh": {"dtype": DType.uint(4)}, | ||
| "mul": {"dtype": DType.uint(4)}, | ||
| }, | ||
| ) | ||
| qgelu_custom = QuantGELUTanh(self.fp_gelu_tanh, qcfg=cfg) | ||
|
|
||
| # Check that overrides were applied | ||
| self.assertEqual(qgelu_custom.obs_tanh.dtype, DType.uint(4)) | ||
| self.assertEqual(qgelu_custom.obs_mul.dtype, DType.uint(4)) | ||
|
|
||
| def test_activation_stats_collected(self): | ||
| """ | ||
| Test that activation statistics are properly collected during calibration. | ||
| All three observers (act_in, tanh, mul) should collect statistics. | ||
| """ | ||
| self.qgelu_tanh.enable_calibration() | ||
|
|
||
| # Run forward pass to collect stats | ||
| _ = self.qgelu_tanh(self.x) | ||
|
|
||
| # Check that activation observers have collected stats | ||
| self.assertTrue( | ||
| self.qgelu_tanh.obs_act_in.has_qparams | ||
| or self.qgelu_tanh.obs_act_in.min_val.numel() > 0 | ||
| ) | ||
| self.assertTrue( | ||
| self.qgelu_tanh.obs_tanh.has_qparams | ||
| or self.qgelu_tanh.obs_tanh.min_val.numel() > 0 | ||
| ) | ||
| self.assertTrue( | ||
| self.qgelu_tanh.obs_mul.has_qparams | ||
| or self.qgelu_tanh.obs_mul.min_val.numel() > 0 | ||
| ) | ||
|
|
||
| # Freeze and check qparams exist | ||
| self.qgelu_tanh.freeze_qparams() | ||
| self.assertTrue(self.qgelu_tanh.obs_act_in.has_qparams) | ||
| self.assertTrue(self.qgelu_tanh.obs_tanh.has_qparams) | ||
| self.assertTrue(self.qgelu_tanh.obs_mul.has_qparams) | ||
|
|
||
| def test_no_quant_matches_reference(self): | ||
| """ | ||
| In NO_QUANT mode, output should match FP32 reference exactly | ||
| (up to numerical tolerances). | ||
| """ | ||
| # Create fresh wrapper that stays in NO_QUANT mode | ||
| qgelu = QuantGELUTanh(self.fp_gelu_tanh) | ||
|
|
||
| with torch.no_grad(): | ||
| q_out = qgelu(self.x) | ||
| fp_out = self.gelutanh()(self.x) | ||
|
|
||
| self.assertIs(qgelu._mode, Mode.NO_QUANT) | ||
| self.assertTrue(torch.allclose(q_out, fp_out, atol=1e-6, rtol=1e-6)) | ||
|
|
||
| def test_registration_in_registry(self): | ||
| """ | ||
| Test that GELUTanh is properly registered in the wrapper registry. | ||
| """ | ||
| from tico.quantization.wrapq.wrappers.nn.quant_gelutanh import QuantGELUTanh | ||
| from tico.quantization.wrapq.wrappers.registry import lookup | ||
|
|
||
| # Verify GELUTanh maps to QuantGELUTanh | ||
| wrapper_cls = lookup(self.gelutanh) | ||
| self.assertIs(wrapper_cls, QuantGELUTanh) |
87 changes: 87 additions & 0 deletions
87
test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_mlp.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import pathlib | ||
| import tempfile | ||
| import unittest | ||
| import warnings | ||
|
|
||
| import tico | ||
|
|
||
| import torch | ||
| from tico.quantization.wrapq.mode import Mode | ||
| from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp import ( | ||
| QuantQwen3VLVisionMLP, | ||
| ) | ||
| from transformers.activations import GELUTanh | ||
|
|
||
|
|
||
| class DummyMLP(torch.nn.Module): | ||
| """Tiny stand-in for HF LlamaMLP (hidden=4, inter=8).""" | ||
|
|
||
| def __init__(self): | ||
| super().__init__() | ||
| self.linear_fc1 = torch.nn.Linear(4, 8) | ||
| self.linear_fc2 = torch.nn.Linear(8, 4) | ||
| self.act_fn = GELUTanh() | ||
|
|
||
| def forward(self, x): | ||
| return self.linear_fc2(self.act_fn(self.linear_fc1(x))) | ||
|
|
||
|
|
||
| class TestQuantQwenVisionMLP(unittest.TestCase): | ||
| def setUp(self): | ||
| torch.manual_seed(0) | ||
| self.fp32 = DummyMLP() | ||
| self.quant = QuantQwen3VLVisionMLP(self.fp32) | ||
| self.x = torch.randn(32, 4) | ||
|
|
||
| def test_mode_and_forward(self): | ||
| # calibration | ||
| self.quant.enable_calibration() | ||
| _ = self.quant(self.x) | ||
| self.quant.freeze_qparams() | ||
| self.assertIs(self.quant._mode, Mode.QUANT) | ||
|
|
||
| # forward diff | ||
| with torch.no_grad(): | ||
| q = self.quant(self.x) | ||
| f = self.fp32(self.x) | ||
| diff = (q - f).abs().mean().item() | ||
| self.assertLess(diff, 0.7) # loose bound | ||
| self.assertGreater(diff, 0.0) | ||
|
|
||
|
|
||
| class TestSubgraphExport(unittest.TestCase): | ||
| def setUp(self): | ||
| torch.manual_seed(0) | ||
| self.mlp_int8 = QuantQwen3VLVisionMLP(DummyMLP()).eval() | ||
| self.x = torch.randn(16, 4) | ||
|
|
||
| def test_calib_quant_export(self): | ||
| # calib | ||
| self.mlp_int8.enable_calibration() | ||
| _ = self.mlp_int8(self.x) | ||
| self.mlp_int8.freeze_qparams() | ||
|
|
||
| self.assertIs(self.mlp_int8._mode, Mode.QUANT) | ||
|
|
||
| # export | ||
| with tempfile.TemporaryDirectory() as td: | ||
| path = pathlib.Path(td) / "mlp.circle" | ||
| with warnings.catch_warnings(): | ||
| warnings.filterwarnings("ignore", category=UserWarning) | ||
| exported = tico.convert(self.mlp_int8, (self.x[:1],)) | ||
| exported.save(path) | ||
| self.assertTrue(path.exists()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
95 changes: 95 additions & 0 deletions
95
tico/quantization/wrapq/examples/quantize_qwen_vision_mlp.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| # Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import pathlib | ||
|
|
||
| import torch | ||
| from transformers import AutoModelForVision2Seq | ||
|
|
||
| from tico.quantization import convert, prepare | ||
| from tico.quantization.config.ptq import PTQConfig | ||
| from tico.quantization.evaluation.metric import compute_peir | ||
| from tico.quantization.evaluation.utils import plot_two_outputs | ||
| from tico.quantization.wrapq.mode import Mode | ||
| from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp import ( | ||
| QuantQwen3VLVisionMLP, | ||
| ) | ||
| from tico.utils.utils import SuppressWarning | ||
|
|
||
| # ------------------------------------------------------------------------- | ||
| # 0. Load a Qwen3-VL model (text tower) + tokenizer | ||
| # ------------------------------------------------------------------------- | ||
| name = "Qwen/Qwen3-VL-2B-Instruct" | ||
| model = AutoModelForVision2Seq.from_pretrained( | ||
| name, | ||
| device_map="cpu", | ||
| trust_remote_code=True, | ||
| cache_dir="/mnt/storage/transformers_cache", | ||
| ) | ||
| model.eval() | ||
|
|
||
| # ------------------------------------------------------------------------- | ||
| # 1. Replace layer-0’s mlp with QuantQwen3VLVisionMLP | ||
| # ------------------------------------------------------------------------- | ||
| orig_mlp = model.model.visual.blocks[0].mlp | ||
| mlp_q = prepare(orig_mlp, PTQConfig()) | ||
| mlp_q.eval() | ||
| assert isinstance(mlp_q.wrapped, QuantQwen3VLVisionMLP) | ||
|
|
||
| inp_shape = (orig_mlp.intermediate_size, orig_mlp.hidden_size) | ||
| # ------------------------------------------------------------------------- | ||
| # 2. calibration | ||
| # ------------------------------------------------------------------------- | ||
| examples = [ | ||
| torch.randn(inp_shape), | ||
| torch.randn(inp_shape), | ||
| torch.randn(inp_shape), | ||
| ] | ||
|
|
||
| with torch.no_grad(): | ||
| for example in examples: | ||
| _ = mlp_q(example) | ||
|
|
||
| convert(mlp_q) | ||
| assert mlp_q._mode is Mode.QUANT, "Quantization mode should be active now." | ||
|
|
||
| # ------------------------------------------------------------------------- | ||
| # 3. Quick diff check (INT-sim vs FP32) | ||
| # ------------------------------------------------------------------------- | ||
| hidden = examples[0] | ||
|
|
||
| with torch.no_grad(): | ||
| int8_out = mlp_q(hidden) | ||
| fp_out = orig_mlp(hidden) | ||
|
|
||
| print("┌───────────── Quantization Error Summary ─────────────") | ||
| print(f"│ Mean |diff|: {(int8_out - fp_out).abs().mean().item():.6f}") | ||
| print(f"│ PEIR : {compute_peir(fp_out, int8_out) * 100:.6f} %") | ||
| print("└──────────────────────────────────────────────────────") | ||
| print(plot_two_outputs(fp_out, int8_out)) | ||
|
|
||
| # ------------------------------------------------------------------------- | ||
| # 4. Export the quantized block | ||
| # ------------------------------------------------------------------------- | ||
| import tico | ||
|
|
||
| save_path = pathlib.Path("qwen3vl_vision_mlp.q.circle") | ||
|
|
||
| example = torch.randn(inp_shape) | ||
|
|
||
| with SuppressWarning(UserWarning, ".*"): | ||
| cm = tico.convert(mlp_q, (example,)) | ||
| cm.save(save_path) | ||
|
|
||
| print(f"Quantized Circle model saved to {save_path.resolve()}") |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All changes in
mini_vqa_evalwill be removed in the final PR.