15 changes: 15 additions & 0 deletions test/quantization/wrapq/wrappers/llama/test_quant_decoder_layer.py
@@ -16,7 +16,9 @@
import unittest

import torch
from tico.quantization.config.ptq import PTQConfig

from tico.quantization.wrapq.dtypes import DType
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
QuantLlamaDecoderLayer,
@@ -94,3 +96,16 @@ def test_forward_diff(self):
self.assertGreater(diff, 0.0)
self.assertLess(diff, 0.5)
self.assertEqual(fp_out.shape, q_out.shape)

def test_dtype_override(self):
# mlp_residual_out is currently the only observer created directly by QuantLlamaDecoderLayer;
# if more observers are added to QuantLlamaDecoderLayer,
# the overrides in cfg will also need to be expanded accordingly
cfg = PTQConfig(
default_dtype=DType.int(16),
overrides={
"mlp_residual_out": {"dtype": DType.uint(8)},
},
)
qcustom = QuantLlamaDecoderLayer(self.fp_layer, qcfg=cfg)
self.assertEqual(qcustom.obs_mlp_residual_out.dtype, DType.uint(8))
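
Note on the override map: the keys in overrides are observer names, so the test above only needs the single "mlp_residual_out" entry. Below is a minimal sketch, assuming only the PTQConfig/DType API exercised by this test, of how the map would grow if further observers were ever added to QuantLlamaDecoderLayer; the commented-out "attn_residual_out" key is purely hypothetical and only illustrates the pattern.

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.dtypes import DType

# Every observer falls back to the int16 default unless its name appears in overrides.
cfg = PTQConfig(
    default_dtype=DType.int(16),
    overrides={
        "mlp_residual_out": {"dtype": DType.uint(8)},  # exists today
        # "attn_residual_out": {"dtype": DType.uint(8)},  # hypothetical future observer
    },
)
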
10 changes: 9 additions & 1 deletion tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
@@ -93,6 +93,9 @@ def __init__(
assert hasattr(fp_layer, "post_attention_layernorm") and isinstance(
fp_layer.post_attention_layernorm, torch.nn.Module
)

self.obs_mlp_residual_out = self._make_obs("mlp_residual_out")

self.input_layernorm = PTQWrapper(
fp_layer.input_layernorm, qcfg=input_norm, fp_name=f"{fp_name}.input_norm"
)
@@ -163,7 +166,11 @@ def forward(
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

hidden_states = (
residual + hidden_states
) # residual/hidden_states are assumed to be quantized
hidden_states = self._fq(hidden_states, self.obs_mlp_residual_out)

# Return type policy:
# - If use_cache: always return (hidden_states, present_key_value)
@@ -182,3 +189,4 @@ def forward(
def _all_observers(self):
yield from self.self_attn._all_observers()
yield from self.mlp._all_observers()
yield self.obs_mlp_residual_out
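
Since _all_observers() now also yields obs_mlp_residual_out, anything that walks the wrapper's observer set picks up the new observer automatically. A minimal sketch, assuming a layer wrapped as in test_dtype_override above (qlayer = QuantLlamaDecoderLayer(fp_layer, qcfg=cfg)), of checking the dtype each observer ends up with:

# With the config above, every yielded observer should report the int16 default,
# except obs_mlp_residual_out, which the override pins to uint8.
for obs in qlayer._all_observers():
    print(obs.dtype)
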