From 8030686f8cae9ded93942689b1c6ed0965d6681c Mon Sep 17 00:00:00 2001
From: "s.malakhov"
Date: Mon, 2 Feb 2026 10:35:58 +0300
Subject: [PATCH] [quantization] Introduce a wrapper for nn.Embedding

This commit introduces a wrapper for nn.Embedding.

TICO-DCO-1.0-Signed-off-by: s.malakhov
---
 .../wrapq/wrappers/nn/test_quant_embedding.py | 82 +++++++++++++++++++
 .../wrapq/wrappers/nn/quant_embedding.py      | 73 +++++++++++++++++
 2 files changed, 155 insertions(+)
 create mode 100644 test/quantization/wrapq/wrappers/nn/test_quant_embedding.py
 create mode 100644 tico/quantization/wrapq/wrappers/nn/quant_embedding.py

diff --git a/test/quantization/wrapq/wrappers/nn/test_quant_embedding.py b/test/quantization/wrapq/wrappers/nn/test_quant_embedding.py
new file mode 100644
index 00000000..8f06daf1
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/nn/test_quant_embedding.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+import torch.nn.functional as F
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.wrappers.nn.quant_embedding import QuantEmbedding
+
+
+class TestQuantEmbedding(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(0)
+        num_embeddings = 4  # vocab_size
+        embedding_dim = 2  # inner_dim
+        seq_len = 128
+        self.fp32 = torch.nn.Embedding(num_embeddings, embedding_dim)
+        self.x = torch.randint(0, num_embeddings, (seq_len, num_embeddings))
+
+        self.q_emb = QuantEmbedding(self.fp32)
+
+    def test_mode_transitions(self):
+        self.assertIs(self.q_emb._mode, Mode.NO_QUANT)
+
+        # Calibration (the static weight range is re-collected immediately)
+        self.q_emb.enable_calibration()
+        _ = self.q_emb(self.x)
+        self.assertIs(self.q_emb._mode, Mode.CALIB)
+
+        self.q_emb.freeze_qparams()
+        self.assertIs(self.q_emb._mode, Mode.QUANT)
+
+    def test_quantised_output_close(self):
+        self.q_emb.enable_calibration()
+        _ = self.q_emb(self.x)
+        self.q_emb.freeze_qparams()
+
+        with torch.no_grad():
+            q_out = self.q_emb(self.x)
+            fp_out = F.embedding(self.x, self.fp32.weight)
+
+        diff = (fp_out - q_out).abs().mean().item()
+        self.assertGreater(diff, 0.0)
+        self.assertLess(diff, 0.4)
+
+    def test_weight_stats_survive(self):
+        self.q_emb.enable_calibration()
+        self.q_emb.weight_obs.compute_qparams()
+        self.assertTrue(hasattr(self.q_emb.weight_obs, "_cached_scale"))
+        pre_scale = self.q_emb.weight_obs._cached_scale.clone()
+
+        # Run another calibration cycle (no new activation data)
+        self.q_emb.enable_calibration()
+        self.q_emb.freeze_qparams()
+
+        post_scale = self.q_emb.weight_obs._cached_scale
+        self.assertTrue(torch.allclose(pre_scale, post_scale))
+
+    def test_dtype_override(self):
+        cfg = PTQConfig(
+            default_dtype=DType.uint(8),
+            overrides={
+                "act_out": {"dtype": DType.uint(4)},
+            },
+        )
+        qcustom = QuantEmbedding(self.fp32, qcfg=cfg)
+        self.assertEqual(qcustom.act_out_obs.dtype, DType.uint(4))
diff --git a/tico/quantization/wrapq/wrappers/nn/quant_embedding.py b/tico/quantization/wrapq/wrappers/nn/quant_embedding.py
new file mode 100644
index 00000000..79ca99ac
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/nn/quant_embedding.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.qscheme import QScheme
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+
+
+@try_register("torch.nn.Embedding")
+class QuantEmbedding(QuantModuleBase):
+    """Per-channel weight fake-quant, eager-output activation fake-quant."""
+
+    def __init__(
+        self,
+        fp: nn.Embedding,
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+        self.weight_obs = self._make_obs(
+            "weight",
+            qscheme=QScheme.PER_CHANNEL_ASYMM,  # per-tensor quantization breaks the model
+            channel_axis=0,  # weight ~ (vocab_size, inner_dim), so one scale per embedding row
+        )
+        self.act_out_obs = self._make_obs("act_out")
+        self.module = fp
+
+    def enable_calibration(self) -> None:
+        super().enable_calibration()
+        # Immediately capture the fixed weight range
+        self.weight_obs.collect(self.module.weight)
+
+    def forward(self, x: torch.Tensor):
+
+        # x holds int64 indices, so no input-activation quantization is needed
+        w = self.module.weight
+        if self._mode is Mode.QUANT:
+            w = self.weight_obs.fake_quant(w)
+
+        y = torch.nn.functional.embedding(
+            x,
+            w,
+            self.module.padding_idx,
+            self.module.max_norm,
+            self.module.norm_type,
+            self.module.scale_grad_by_freq,
+            self.module.sparse,
+        )
+
+        return self._fq(y, self.act_out_obs)
+
+    def _all_observers(self):
+        return (self.act_out_obs, self.weight_obs)
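
Note for reviewers (not part of the patch): below is a minimal usage sketch of the calibrate/freeze flow exercised by the tests above. The QuantEmbedding API and module path are taken from the new files; the table sizes and the input indices are made-up toy values for illustration only.

    import torch
    from tico.quantization.wrapq.wrappers.nn.quant_embedding import QuantEmbedding

    # Wrap an existing fp32 embedding table (hypothetical toy sizes).
    fp32_emb = torch.nn.Embedding(num_embeddings=100, embedding_dim=16)
    q_emb = QuantEmbedding(fp32_emb)

    # 1) Calibration: enable_calibration() also captures the fixed weight range,
    #    so one forward pass over representative indices is enough to observe
    #    the output activations.
    q_emb.enable_calibration()
    _ = q_emb(torch.randint(0, 100, (1, 32)))

    # 2) Freeze the quantization parameters and run fake-quant inference.
    q_emb.freeze_qparams()
    with torch.no_grad():
        y = q_emb(torch.randint(0, 100, (1, 32)))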