[quantization] Introduce wrapper for Qwen3VLVisionRotaryEmbedding#496
Open
dvsav wants to merge 1 commit intoSamsung:mainfrom
Open
[quantization] Introduce wrapper for Qwen3VLVisionRotaryEmbedding#496dvsav wants to merge 1 commit intoSamsung:mainfrom
dvsav wants to merge 1 commit intoSamsung:mainfrom
Conversation
This change introduces QuantQwen3VLVisionRotaryEmbedding wrapper to support post-training quantization of Qwen3VLVisionRotaryEmbedding module. TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
Contributor
Author
For ReviewersBelow is the source code of # transformers/models/qwen3_vl/modeling_qwen3_vl.py
class Qwen3VLVisionRotaryEmbedding(nn.Module):
inv_freq: torch.Tensor # fix linting for `register_buffer`
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.theta = theta
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This change introduces
QuantQwen3VLVisionRotaryEmbeddingwrapper to support post-training quantization ofQwen3VLVisionRotaryEmbeddingmodule.Why?
Qwen3VLVisionRotaryEmbeddingmodule is used in the image encoder of Qwen model.Trying to quantize
Qwen3VLVisionRotaryEmbeddingvia PTQ generates exceptionPTQQuantizer: no quantization wrapper for Qwen3VLVisionRotaryEmbedding.What
This change introduces:
QuantQwen3VLVisionRotaryEmbedding(tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_rotary_embedding.py).class TestQuantQwen3VLVisionRotaryEmbedding(test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_rotary_embedding.py) - skipped iftransformerspackage is not installed.tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_rotary_embeddingin_CORE_MODULES(tico/quantization/wrapq/wrappers/registry.py).Qwen3VLVisionRotaryEmbeddingquantization and conversion to Circle (tico/quantization/wrapq/examples/qwen/quantize_qwen_vision_rotary_embedding.py).Unit Tests
Unit tests results with coverage information:
Coverage info (irrelevant files skipped):