diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py
new file mode 100644
index 00000000..ea6bbbc8
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_text_rotary_embedding.py
@@ -0,0 +1,377 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib.util
+import unittest
+
+import torch
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_rotary_embedding import (
+    QuantQwen3VLTextRotaryEmbedding,
+)
+
+
+trans_spec = importlib.util.find_spec("transformers")
+skip_msg = "transformers not installed — skipping Qwen3VLTextRotaryEmbedding tests"
+
+
+@unittest.skipUnless(trans_spec, skip_msg)
+class TestQuantQwen3VLTextRotaryEmbedding(unittest.TestCase):
+    fp_rope: torch.nn.Module
+    hidden_size: int
+    head_dim: int
+
+    @classmethod
+    def setUpClass(cls):
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+            Qwen3VLTextConfig,
+        )
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLTextRotaryEmbedding,
+        )
+
+        # Use smaller config for testing
+        cfg = Qwen3VLTextConfig(
+            hidden_size=32,  # Smaller for testing
+            num_attention_heads=4,
+            max_position_embeddings=512,
+        )
+        cls.fp_rope = Qwen3VLTextRotaryEmbedding(cfg)
+        cls.hidden_size = cfg.hidden_size
+        cls.head_dim = (
+            getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
+        )  # 8
+
+    def test_mode_transitions(self):
+        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        self.assertIs(q_rope._mode, Mode.NO_QUANT)
+
+        q_rope.enable_calibration()
+        self.assertIs(q_rope._mode, Mode.CALIB)
+
+        # Run forward pass during calibration
+        x = torch.randn(2, 64, self.head_dim)
+        position_ids = torch.arange(64).unsqueeze(0).expand(2, -1)
+        _ = q_rope(x, position_ids)
+
+        q_rope.freeze_qparams()
+        self.assertIs(q_rope._mode, Mode.QUANT)
+
+    def test_quantised_output_close(self):
+        """
+        Test that quantized outputs (cos, sin) are acceptably close to FP32 reference.
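+        Calibrates on sequence lengths 32/64/128, then compares the quantized cos/sin
+        against the FP module at length 64.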
+        """
+        torch.manual_seed(42)
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        q_rope.enable_calibration()
+
+        # Calibrate with different sequence lengths
+        for seq_len in [32, 64, 128]:
+            x = torch.randn(2, seq_len, self.head_dim)
+            position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+            _ = q_rope(x, position_ids)
+
+        q_rope.freeze_qparams()
+
+        seq_len = 64
+        x = torch.randn(2, seq_len, self.head_dim)
+        position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+
+        with torch.no_grad():
+            q_cos, q_sin = q_rope(x, position_ids)
+            fp_cos, fp_sin = self.fp_rope(x, position_ids)
+
+        diff_cos = (fp_cos - q_cos).abs().mean().item()
+        diff_sin = (fp_sin - q_sin).abs().mean().item()
+
+        self.assertGreater(diff_cos, 0.0)  # not identical
+        self.assertGreater(diff_sin, 0.0)
+        self.assertLess(diff_cos, 0.4)  # acceptably close
+        self.assertLess(diff_sin, 0.4)
+        self.assertEqual(fp_cos.shape, q_cos.shape)
+        self.assertEqual(fp_sin.shape, q_sin.shape)
+
+    def test_output_shape(self):
+        """
+        Test that output shapes are correct: (batch_size, seq_len, head_dim)
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        q_rope.enable_calibration()
+
+        seq_len = 64
+        x = torch.randn(2, seq_len, self.head_dim)
+        position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+        _ = q_rope(x, position_ids)
+
+        q_rope.freeze_qparams()
+
+        with torch.no_grad():
+            q_cos, q_sin = q_rope(x, position_ids)
+
+        expected_shape = (2, seq_len, self.head_dim)
+        self.assertEqual(q_cos.shape, expected_shape)
+        self.assertEqual(q_sin.shape, expected_shape)
+
+    def test_output_range(self):
+        """
+        Test that cos and sin outputs are in valid range [-1, 1].
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        q_rope.enable_calibration()
+
+        seq_len = 64
+        x = torch.randn(2, seq_len, self.head_dim)
+        position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+        _ = q_rope(x, position_ids)
+
+        q_rope.freeze_qparams()
+
+        with torch.no_grad():
+            q_cos, q_sin = q_rope(x, position_ids)
+
+        # Check ranges (with some tolerance for quantization error)
+        self.assertLessEqual(q_cos.max(), 1.01)
+        self.assertGreaterEqual(q_cos.min(), -1.01)
+        self.assertLessEqual(q_sin.max(), 1.01)
+        self.assertGreaterEqual(q_sin.min(), -1.01)
+
+    def test_different_sequence_lengths(self):
+        """
+        Test that quantization works correctly with different sequence lengths.
+        Calibrate with maximum length to cover full range.
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        q_rope.enable_calibration()
+
+        # Calibrate with MAXIMUM length
+        max_seq_len = 256
+        for _ in range(3):
+            x = torch.randn(2, max_seq_len, self.head_dim)
+            position_ids = torch.arange(max_seq_len).unsqueeze(0).expand(2, -1)
+            _ = q_rope(x, position_ids)
+
+        q_rope.freeze_qparams()
+
+        # Test with different lengths
+        for seq_len in [32, 64, 128, 256]:
+            x = torch.randn(2, seq_len, self.head_dim)
+            position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+
+            with torch.no_grad():
+                q_cos, q_sin = q_rope(x, position_ids)
+                fp_cos, fp_sin = self.fp_rope(x, position_ids)
+
+            diff_cos = (fp_cos - q_cos).abs().mean().item()
+            diff_sin = (fp_sin - q_sin).abs().mean().item()
+
+            self.assertLess(diff_cos, 0.4)
+            self.assertLess(diff_sin, 0.4)
+            self.assertEqual(q_cos.shape[0], 2)
+            self.assertEqual(q_cos.shape[1], seq_len)
+            self.assertEqual(q_cos.shape[2], self.head_dim)
+
+    def test_dtype_override(self):
+        """
+        PTQConfig overrides should affect the observers.
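+        Here the "cos" and "sin" observers are overridden to uint4 while the default stays uint8.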
+        """
+        cfg = PTQConfig(
+            default_dtype=DType.uint(8),
+            overrides={
+                "cos": {"dtype": DType.uint(4)},
+                "sin": {"dtype": DType.uint(4)},
+            },
+        )
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope, qcfg=cfg)
+
+        self.assertEqual(q_rope.obs_cos.dtype, DType.uint(4))
+        self.assertEqual(q_rope.obs_sin.dtype, DType.uint(4))
+
+    def test_activation_stats_collected(self):
+        """
+        Test that activation statistics are properly collected during calibration.
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        q_rope.enable_calibration()
+
+        # Run forward pass to collect stats
+        seq_len = 64
+        x = torch.randn(2, seq_len, self.head_dim)
+        position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+        _ = q_rope(x, position_ids)
+
+        # Check that observers have collected stats
+        self.assertTrue(
+            q_rope.obs_cos.has_qparams or q_rope.obs_cos.min_val.numel() > 0
+        )
+        self.assertTrue(
+            q_rope.obs_sin.has_qparams or q_rope.obs_sin.min_val.numel() > 0
+        )
+
+        # Freeze and check qparams exist
+        q_rope.freeze_qparams()
+        self.assertTrue(q_rope.obs_cos.has_qparams)
+        self.assertTrue(q_rope.obs_sin.has_qparams)
+
+    def test_observer_count(self):
+        """
+        Test that the wrapper has the correct number of observers.
+        6 observers: inv_freq, freqs, freqs_mrope, emb, cos, sin
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+
+        observers = list(q_rope._all_observers())
+        self.assertEqual(len(observers), 6)
+
+    def test_registration_in_registry(self):
+        """
+        Test that Qwen3VLTextRotaryEmbedding is properly registered.
+        """
+        from tico.quantization.wrapq.wrappers.qwen_vl.quant_text_rotary_embedding import (
+            QuantQwen3VLTextRotaryEmbedding,
+        )
+        from tico.quantization.wrapq.wrappers.registry import lookup
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLTextRotaryEmbedding,
+        )
+
+        wrapper_cls = lookup(Qwen3VLTextRotaryEmbedding)
+        self.assertIs(wrapper_cls, QuantQwen3VLTextRotaryEmbedding)
+
+    def test_no_learnable_parameters(self):
+        """
+        Test that the wrapper has no learnable parameters (only buffers).
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+
+        # Check that there are no parameters
+        params = list(q_rope.parameters())
+        self.assertEqual(len(params), 0)
+
+        # Check that inv_freq is a buffer, not a parameter
+        self.assertIsInstance(q_rope.inv_freq, torch.Tensor)
+        self.assertIn("inv_freq", q_rope._buffers)
+
+    def test_cos_sin_relationship(self):
+        """
+        Test that cos² + sin² = 1 (unit circle property).
+        Quantization error should be small enough to preserve this property approximately.
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        q_rope.enable_calibration()
+
+        seq_len = 64
+        x = torch.randn(2, seq_len, self.head_dim)
+        position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+        _ = q_rope(x, position_ids)
+
+        q_rope.freeze_qparams()
+
+        with torch.no_grad():
+            q_cos, q_sin = q_rope(x, position_ids)
+
+        # Check unit circle property
+        unit_circle = q_cos.pow(2) + q_sin.pow(2)
+        # Allow some deviation due to quantization error
+        self.assertGreaterEqual(unit_circle.min(), 0.95)
+        self.assertLessEqual(unit_circle.max(), 1.05)
+
+    def test_different_batch_sizes(self):
+        """
+        Test that quantization works correctly with different batch sizes.
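+        Calibration uses batch size 2; inference is then checked at batch sizes 1, 2, and 4.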
+        """
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        q_rope.enable_calibration()
+
+        seq_len = 64
+        # Calibrate with batch size 2
+        for _ in range(3):
+            x = torch.randn(2, seq_len, self.head_dim)
+            position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+            _ = q_rope(x, position_ids)
+
+        q_rope.freeze_qparams()
+
+        # Test with different batch sizes
+        for batch_size in [1, 2, 4]:
+            x = torch.randn(batch_size, seq_len, self.head_dim)
+            position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
+
+            with torch.no_grad():
+                q_cos, q_sin = q_rope(x, position_ids)
+                fp_cos, fp_sin = self.fp_rope(x, position_ids)
+
+            diff_cos = (fp_cos - q_cos).abs().mean().item()
+            diff_sin = (fp_sin - q_sin).abs().mean().item()
+
+            self.assertLess(diff_cos, 0.4)
+            self.assertLess(diff_sin, 0.4)
+            self.assertEqual(q_cos.shape[0], batch_size)
+
+    def test_mrope_semantic_equivalence(self):
+        """
+        Test that QuantQwen3VLTextRotaryEmbedding.apply_interleaved_mrope produces
+        identical output to the original Qwen3VLTextRotaryEmbedding.apply_interleaved_mrope.
+        """
+        torch.manual_seed(42)
+
+        # Create test freqs tensor
+        batch_size = 2
+        seq_len = 64
+        head_dim = self.head_dim
+        freqs = torch.randn(3, batch_size, seq_len, head_dim // 2)
+
+        # Call original implementation
+        freqs_t_original = self.fp_rope.apply_interleaved_mrope(
+            freqs, self.fp_rope.mrope_section
+        )
+
+        # Call new implementation
+        q_rope = QuantQwen3VLTextRotaryEmbedding(self.fp_rope)
+        freqs_t_new = q_rope.apply_interleaved_mrope(freqs, q_rope.mrope_section)
+
+        # Compare outputs
+        self.assertEqual(freqs_t_original.shape, freqs_t_new.shape)
+
+        # Check that the values match within a tight tolerance
+        torch.testing.assert_close(
+            freqs_t_original,
+            freqs_t_new,
+            rtol=1e-5,
+            atol=1e-5,
+            msg="MRoPE implementations produce different outputs",
+        )
+
+        # Also check with different input shapes
+        test_configs = [
+            (1, 32, head_dim),  # Single sample, shorter sequence
+            (4, 128, head_dim),  # Larger batch, longer sequence
+            (2, 256, head_dim),  # Very long sequence
+        ]
+
+        for bs, sl, hd in test_configs:
+            freqs = torch.randn(3, bs, sl, hd // 2)
+
+            freqs_t_original = self.fp_rope.apply_interleaved_mrope(
+                freqs, self.fp_rope.mrope_section
+            )
+            freqs_t_new = q_rope.apply_interleaved_mrope(freqs, q_rope.mrope_section)
+
+            self.assertEqual(freqs_t_original.shape, freqs_t_new.shape)
+            self.assertTrue(
+                torch.equal(freqs_t_original, freqs_t_new),
+                f"MRoPE implementations differ for shape (3, {bs}, {sl}, {hd//2})",
+            )
diff --git a/tico/quantization/wrapq/examples/qwen/quantize_qwen_text_rotary_embedding.py b/tico/quantization/wrapq/examples/qwen/quantize_qwen_text_rotary_embedding.py
new file mode 100644
index 00000000..daa441c5
--- /dev/null
+++ b/tico/quantization/wrapq/examples/qwen/quantize_qwen_text_rotary_embedding.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
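+"""Example: post-training quantization of Qwen3VLTextRotaryEmbedding.
+
+Calibrates the wrapped rotary embedding over several sequence lengths, freezes the
+collected statistics, and exports the quantized module as a Circle model.
+"""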
+
+import importlib.util
+import sys
+
+import torch
+import torch.nn as nn
+
+import tico
+import tico.quantization
+import tico.quantization.config.ptq
+
+# Check if transformers is available
+trans_spec = importlib.util.find_spec("transformers")
+if trans_spec is None:
+    print(
+        "Error: transformers package not installed. Cannot test Qwen3VLTextRotaryEmbedding."
+    )
+    sys.exit(1)
+
+from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLTextConfig
+from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextRotaryEmbedding
+
+
+def generate_calibration_data(batch_size: int, sequence_lengths: list, head_dim: int):
+    """Generate calibration data for PTQ (batch_size rounds over the given sequence lengths)."""
+    calibration_data = []
+    for _ in range(batch_size):
+        for seq_len in sequence_lengths:
+            # x tensor: shape (2, seq_len, head_dim)
+            x = torch.randn(2, seq_len, head_dim)
+            # position_ids: shape (2, seq_len)
+            position_ids = torch.arange(seq_len).unsqueeze(0).expand(2, -1)
+            calibration_data.append((x, position_ids))
+    return calibration_data
+
+
+def main():
+    # Create the text rotary embedding model
+    # Use typical Qwen3-VL dimensions
+    cfg = Qwen3VLTextConfig(
+        hidden_size=2048,  # Typical for Qwen3-VL 2B
+        num_attention_heads=16,
+        max_position_embeddings=4096,
+    )
+    model = Qwen3VLTextRotaryEmbedding(cfg)
+    model.eval()
+
+    # Qwen3VLTextRotaryEmbedding(
+    #   (inv_freq): buffer of shape [head_dim // 2]  # [64] for head_dim=128
+    # )
+    head_dim = (
+        getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads
+    )
+    print(
+        f"Config: hidden_size={cfg.hidden_size}, num_attention_heads={cfg.num_attention_heads}"
+    )
+    print(f"head_dim={head_dim}, inv_freq.shape={model.inv_freq.shape}")
+
+    # Generate calibration data
+    # Calibrate with various sequence lengths to capture the full dynamic range
+    # Important: include the maximum sequence length that will be used at inference
+    calibration_data = generate_calibration_data(
+        batch_size=20,
+        sequence_lengths=[128, 256, 512, 1024, 2048, 4096],
+        head_dim=head_dim,
+    )
+
+    # Configure PTQ
+    ptq_config = tico.quantization.config.ptq.PTQConfig()
+
+    # Prepare the model for quantization
+    prepared_model = tico.quantization.prepare(
+        model, ptq_config, inplace=True  # Transform the model in place
+    )
+
+    # Calibrate the model (collect statistics)
+    with torch.no_grad():
+        for x, position_ids in calibration_data:
+            _ = prepared_model(x, position_ids)
+
+    # Convert to quantized model
+    quantized_model = tico.quantization.convert(prepared_model, inplace=True)
+
+    # Convert to Circle format
+    # example_inputs: tuple containing (x, position_ids)
+    example_seq_len = 256
+    example_x = torch.randn(2, example_seq_len, head_dim)
+    example_position_ids = torch.arange(example_seq_len).unsqueeze(0).expand(2, -1)
+    circle_model = tico.convert(quantized_model, (example_x, example_position_ids))
+
+    # Save the Circle model
+    filename = "quantized_text_rotary_embedding.circle"
+    circle_model.save(filename)
+    print(f"Circle model saved as '{filename}'")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_rotary_embedding.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_rotary_embedding.py
new file mode 100644
index 00000000..7e37deae
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_rotary_embedding.py
@@ -0,0 +1,223 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register(
+    "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRotaryEmbedding",
+)
+class QuantQwen3VLTextRotaryEmbedding(QuantModuleBase):
+    """
+    Quantization wrapper for the Qwen3VLTextRotaryEmbedding module.
+
+    This module generates MRoPE (multimodal rotary positional embedding) cos/sin values
+    for text attention. Every floating-point intermediate is fake-quantized through its
+    own observer.
+    """
+
+    def __init__(
+        self,
+        fp_rope: nn.Module,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        assert hasattr(fp_rope, "config")
+        assert hasattr(fp_rope, "inv_freq")
+        assert hasattr(fp_rope, "mrope_section")
+        assert hasattr(fp_rope, "attention_scaling")
+
+        self.config = fp_rope.config
+        self.mrope_section = fp_rope.mrope_section
+        self.attention_scaling = fp_rope.attention_scaling
+
+        # Copy the inv_freq buffer to the wrapper
+        self.register_buffer("inv_freq", fp_rope.inv_freq.clone())
+
+        # Observers for all intermediate tensor values
+        mk = self._make_obs
+        self.obs_inv_freq = mk("inv_freq")  # Constant buffer
+        self.obs_freqs = mk("freqs")  # After matrix multiplication
+        self.obs_freqs_mrope = mk("freqs_mrope")  # After MRoPE
+        self.obs_emb = mk("emb")  # After concatenation
+        self.obs_cos = mk("cos")  # Final cosine output
+        self.obs_sin = mk("sin")  # Final sine output
+
+    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
+        """
+        Forward pass with fake quantization.
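+        Each floating-point intermediate (inv_freq, freqs, MRoPE freqs, emb, cos, sin)
+        is passed through its observer via self._fq.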
+
+        Args:
+            x: Input tensor (used only for device/dtype)
+            position_ids: Position identifiers (batch_size, seq_len)
+
+        Returns:
+            (cos, sin): Tuple of rotary embeddings
+        """
+        # Expand position_ids for MRoPE
+        if position_ids.ndim == 2:
+            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
+
+        # Expand inv_freq for batched computation
+        # Shape: (3, batch_size, head_dim//2, 1)
+        inv_freq_expanded = (
+            self.inv_freq[None, None, :, None]
+            .float()
+            .expand(3, position_ids.shape[1], -1, 1)
+        )
+        inv_freq_expanded = self._fq(inv_freq_expanded, self.obs_inv_freq)
+
+        # Reshape position_ids for matrix multiplication
+        # Shape: (3, batch_size, 1, seq_len)
+        position_ids_expanded = position_ids[:, :, None, :].float()
+
+        # Compute frequencies via matrix multiplication
+        # Shape: (3, batch_size, seq_len, head_dim//2)
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
+
+        # Force float32 for precision
+        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(
+            2, 3
+        )
+        freqs = self._fq(freqs, self.obs_freqs)
+
+        # Apply interleaved MRoPE
+        freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
+        freqs = self._fq(freqs, self.obs_freqs_mrope)
+
+        # Concatenate frequencies (duplicate for sin/cos)
+        # Shape: (batch_size, seq_len, head_dim)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        emb = self._fq(emb, self.obs_emb)
+
+        # Compute cos and sin
+        cos = emb.cos() * self.attention_scaling
+        sin = emb.sin() * self.attention_scaling
+
+        # Quantize final outputs
+        cos = self._fq(cos, self.obs_cos)
+        sin = self._fq(sin, self.obs_sin)
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+    def apply_interleaved_mrope(self, freqs, mrope_section):
+        """
+        Apply interleaved MRoPE to 3D rotary embeddings.
+        Reorganizes frequency layout from chunked [TTT...HHH...WWW] to
+        interleaved [THWTHWTHW...TT], preserving frequency continuity.
+
+        Args:
+            freqs: (3, bs, seq_len, head_dim // 2)
+            mrope_section: (3,)
+
+        Returns:
+            freqs_t: (bs, seq_len, head_dim // 2)
+
+        Design Note:
+            This implementation uses slice_copy, index_select, and cat to avoid
+            the not-yet-supported slice_scatter with step=3 and the unsupported
+            in-place operator index_put.default.
+        """
+        # Start with T dimension (will keep some, replace some)
+        freqs_t_base = freqs[0]
+
+        # For each dimension (H, W), extract frequency bands to be interleaved
+        h_w_bands = []
+
+        for dim, offset in enumerate((1, 2), start=1):  # H, W dimensions
+            length = mrope_section[dim] * 3
+            indices = torch.arange(offset, length, 3, device=freqs.device)
+
+            # Select frequency bands from H/W dimensions
+            # freqs[dim] has shape (bs, seq_len, head_dim//2)
+            # index_select on last dim: (bs, seq_len, num_selected)
+            freqs_bands = freqs[dim].index_select(dim=-1, index=indices)
+            h_w_bands.append(freqs_bands)
+
+        # Now build the interleaved output.
+        # The T row provides the default value at every index; selected indices
+        # are replaced with the H/W bands extracted above.
+
+        # The interleaving pattern is T0, H1, W2, T3, H4, W5, ..., falling back
+        # to T wherever the H or W section (given by mrope_section) has run out.
+
+        # Build the output by slicing and concatenating
+        # Strategy: slice the T row into per-index chunks, substitute H/W bands, concatenate
+
+        chunks = []
+
+        # Total length in the last dimension
+        total_len = freqs_t_base.shape[-1]
+
+        for i in range(total_len):
+            # Determine which dimension this position belongs to
+            # (i % 3 == 0 -> T, i % 3 == 1 -> H, i % 3 == 2 -> W, with T as fallback)
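+            # e.g. mrope_section == [2, 1, 1] (so head_dim // 2 == 4) yields [T0, H1, W2, T3]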
+            mod = i % 3
+
+            if mod == 0:
+                # T dimension position - take from T
+                # Slice just this index from T
+                chunk = freqs_t_base[..., i : i + 1]
+                chunks.append(chunk)
+            elif mod == 1:
+                # H dimension position - take from H
+                # Calculate which band this is
+                band_idx = (i - 1) // 3
+                if band_idx < h_w_bands[0].shape[-1]:
+                    chunk = h_w_bands[0][..., band_idx : band_idx + 1]
+                    chunks.append(chunk)
+                else:
+                    # Fallback to T if out of bounds
+                    chunk = freqs_t_base[..., i : i + 1]
+                    chunks.append(chunk)
+            else:  # mod == 2
+                # W dimension position - take from W
+                band_idx = (i - 2) // 3
+                if band_idx < h_w_bands[1].shape[-1]:
+                    chunk = h_w_bands[1][..., band_idx : band_idx + 1]
+                    chunks.append(chunk)
+                else:
+                    # Fallback to T if out of bounds
+                    chunk = freqs_t_base[..., i : i + 1]
+                    chunks.append(chunk)
+
+        # Concatenate all chunks
+        freqs_t = torch.cat(chunks, dim=-1)
+
+        return freqs_t
+
+    def _all_observers(self):
+        """Yield all observers."""
+        yield from (
+            self.obs_inv_freq,
+            self.obs_freqs,
+            self.obs_freqs_mrope,
+            self.obs_emb,
+            self.obs_cos,
+            self.obs_sin,
+        )
diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py
index 8eb896bd..852b8cbe 100644
--- a/tico/quantization/wrapq/wrappers/registry.py
+++ b/tico/quantization/wrapq/wrappers/registry.py
@@ -45,6 +45,7 @@
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_attn",
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_mlp",
     "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_embed",
+    "tico.quantization.wrapq.wrappers.qwen_vl.quant_text_rotary_embedding",
     # add future core wrappers here