From b0a5b821f5c07aaf0fae61201aaf045b038955f7 Mon Sep 17 00:00:00 2001
From: "s.malakhov"
Date: Tue, 3 Feb 2026 09:21:03 +0300
Subject: [PATCH] [quantization] Decoder output quantization

This PR ensures the output of the `decoder` layer is quantized: a new
`mlp_residual_out` observer fake-quantizes the final residual sum in
`QuantLlamaDecoderLayer.forward`, and its dtype can be overridden via
the `PTQConfig` overrides.

TICO-DCO-1.0-Signed-off-by: s.malakhov
---
 .../wrappers/llama/test_quant_decoder_layer.py  | 15 +++++++++++++++
 .../wrapq/wrappers/llama/quant_decoder_layer.py | 10 +++++++++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/test/quantization/wrapq/wrappers/llama/test_quant_decoder_layer.py b/test/quantization/wrapq/wrappers/llama/test_quant_decoder_layer.py
index 65377b80..8e73eca9 100644
--- a/test/quantization/wrapq/wrappers/llama/test_quant_decoder_layer.py
+++ b/test/quantization/wrapq/wrappers/llama/test_quant_decoder_layer.py
@@ -16,7 +16,9 @@
 import unittest
 
 import torch
 
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.dtypes import DType
 from tico.quantization.wrapq.mode import Mode
 from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
     QuantLlamaDecoderLayer,
@@ -94,3 +96,16 @@ def test_forward_diff(self):
         self.assertGreater(diff, 0.0)
         self.assertLess(diff, 0.5)
         self.assertEqual(fp_out.shape, q_out.shape)
+
+    def test_dtype_override(self):
+        # mlp_residual_out is the only observer currently created in QuantLlamaDecoderLayer.
+        # If more observers are added to QuantLlamaDecoderLayer,
+        # the cfg overrides will also need to be expanded.
+        cfg = PTQConfig(
+            default_dtype=DType.int(16),
+            overrides={
+                "mlp_residual_out": {"dtype": DType.uint(8)},
+            },
+        )
+        qcustom = QuantLlamaDecoderLayer(self.fp_layer, qcfg=cfg)
+        self.assertEqual(qcustom.obs_mlp_residual_out.dtype, DType.uint(8))
diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
index d68d4617..404863e6 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
@@ -93,6 +93,9 @@ def __init__(
         assert hasattr(fp_layer, "post_attention_layernorm") and isinstance(
             fp_layer.post_attention_layernorm, torch.nn.Module
         )
+
+        self.obs_mlp_residual_out = self._make_obs("mlp_residual_out")
+
         self.input_layernorm = PTQWrapper(
             fp_layer.input_layernorm, qcfg=input_norm, fp_name=f"{fp_name}.input_norm"
         )
@@ -163,7 +166,11 @@ def forward(
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
+
+        hidden_states = (
+            residual + hidden_states
+        )  # residual and hidden_states are assumed to already be quantized
+        hidden_states = self._fq(hidden_states, self.obs_mlp_residual_out)
 
         # Return type policy:
         # - If use_cache: always return (hidden_states, present_key_value)
@@ -182,3 +189,4 @@ def forward(
     def _all_observers(self):
         yield from self.self_attn._all_observers()
         yield from self.mlp._all_observers()
+        yield self.obs_mlp_residual_out
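
Reviewer note (illustrative sketch, not part of the patch): the snippet below
mirrors the added test and shows how a caller routes the new mlp_residual_out
observer to a different dtype than the layer-wide default via PTQConfig
overrides. `fp_layer` stands in for an already constructed float
LlamaDecoderLayer and is hypothetical; all other names follow the diff above.

    from tico.quantization.config.ptq import PTQConfig
    from tico.quantization.wrapq.dtypes import DType
    from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
        QuantLlamaDecoderLayer,
    )

    # Layer-wide default stays at 16-bit; only the decoder-output observer
    # (mlp_residual_out) is forced down to uint8.
    cfg = PTQConfig(
        default_dtype=DType.int(16),
        overrides={"mlp_residual_out": {"dtype": DType.uint(8)}},
    )

    # fp_layer: an existing float LlamaDecoderLayer instance (hypothetical here).
    q_layer = QuantLlamaDecoderLayer(fp_layer, qcfg=cfg)
    assert q_layer.obs_mlp_residual_out.dtype == DType.uint(8)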