
[quantization] Fold scale #465

Merged
mhs4670go merged 1 commit into Samsung:main from stamalakhov:merge_scale on Feb 5, 2026

Conversation

@stamalakhov
Contributor

This PR folds the static scale into `k_proj` to reduce the number of `Mul` operations.
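For context, a minimal sketch of the idea, using an illustrative `nn.Linear` standing in for the wrapped `k_proj` module (not the PR's actual code): pre-multiplying the projection's weight (and bias) by a static scale produces the same output as applying the scale at runtime, so the extra `Mul` node disappears from the graph.

```python
import torch

# Minimal sketch of scale folding (the Linear below is illustrative,
# not the PR's actual wrapped k_proj module).
lin = torch.nn.Linear(64, 64)
scale_t = torch.tensor(0.125)  # static scale known ahead of time

x = torch.randn(2, 64)
ref = scale_t * lin(x)  # original graph: k_proj followed by a Mul

with torch.no_grad():
    # Fold the scale into the parameters:
    # scale * (x @ W.T + b) == x @ (scale * W).T + scale * b
    lin.weight.mul_(scale_t)
    if lin.bias is not None:
        lin.bias.mul_(scale_t)

print(torch.allclose(ref, lin(x), atol=1e-6))  # True: same output, no Mul
```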

Running `quantize_full_qmodel_with_gptq.py` from #436 for `SmolLM`:

┌── Wikitext-2 original test perplexity ─────────────
│ FP32 :    17.40
└───────────────────────────────────────────
Applying GPTQ …
Quantizing layers: 100%|██████████| 30/30 [01:01<00:00, 2.04s/layer]
Wrapping layers with PTQWrapper …
Calibrating PTQ observers …
100%|██████████| 128/128 [00:42<00:00, 3.02it/s]

Calculating perplexities …
PPL:  99%|█████████▊| 148/149 [00:37<00:00, 3.96it/s]

┌── Wikitext-2 test perplexity ─────────────
│ int16 :    27.83
└───────────────────────────────────────────

which is almost the same perplexity as before this change.

./ccex test -k quantization.wrapq.wrappers.llama.test_quant_attn

RUN unit tests with -k quantization.wrapq.wrappers.llama.test_quant_attn ...
test_cache_grows_across_multiple_single_token_steps (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_cache_mock_object_update (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_cache_tuple_concat_prefill_then_decode (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_forward_diff (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_forward_with_float_attention_mask (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_mask_slicing_with_cache_q_len_lt_k_len (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_mode_transitions (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_per_projection_override (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok
test_use_cache_no_past_equivalence (quantization.wrapq.wrappers.llama.test_quant_attn.TestQuantLlamaAttention) ... ok

----------------------------------------------------------------------
Ran 9 tests in 0.132s

OK

which includes `test_forward_diff`, which estimates the output difference.
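The shape of such a check is roughly the following (an illustrative sketch only, not the actual `test_forward_diff` implementation; the function and parameter names here are made up):

```python
import torch

def assert_forward_diff_small(quant_fwd, ref_fwd, hidden, atol=0.05):
    """Illustrative check: the quantized module's output should stay close
    to the FP32 reference output. Not the actual test's code."""
    with torch.no_grad():
        q_out = quant_fwd(hidden)
        ref_out = ref_fwd(hidden)
    diff = (q_out - ref_out).abs().max().item()
    assert diff < atol, f"max abs diff {diff:.4f} exceeds {atol}"
```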

Draft: #436
TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com

@stamalakhov stamalakhov requested a review from mhs4670go February 4, 2026 14:29
@stamalakhov stamalakhov self-assigned this Feb 4, 2026
Comment on lines 94 to 96
self.k_proj.wrapped.module.weight = torch.nn.Parameter(
    self.k_proj.wrapped.module.weight * scale_t
)
@mhs4670go (Contributor)

How about going with in-place updates? It would be safer in a PyTorch environment: mutating the parameters under `torch.no_grad()` keeps the existing `Parameter` objects, so references held elsewhere remain valid and the update is not recorded by autograd.

with torch.no_grad():
    # Fold the static scale into k_proj's parameters in place.
    lin = self.k_proj.wrapped.module
    lin.weight.mul_(scale_t)
    if lin.bias is not None:
        lin.bias.mul_(scale_t)

@stamalakhov (Contributor Author)

@mhs4670go
Fixed.

This PR folds the static scale into `k_proj` to reduce the number of `Mul` operations.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
@mhs4670go (Contributor) left a comment

LGTM

@mhs4670go mhs4670go merged commit f1e2c2f into Samsung:main Feb 5, 2026
7 checks passed
@stamalakhov stamalakhov deleted the merge_scale branch February 5, 2026 06:03