595 changes: 595 additions & 0 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tico/quantization/wrapq/examples/quantize_with_gptq.py
@@ -42,7 +42,6 @@
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase


# Token-budget presets for activation calibration
TOKENS: dict[str, int] = {
# Smoke test (<1 min turnaround on CPU/GPU)
@@ -65,6 +64,7 @@
TRAIN_SPLIT = "train"
TEST_SPLIT = "test"


# -------------------------------------------------------------------------
# 1. Helper — copy GPTQ (scale, zp) into PTQ observers
# -------------------------------------------------------------------------
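For context, the helper announced by the banner above copies GPTQ's per-layer (scale, zero-point) results into the PTQ observers, so activation calibration starts from GPTQ's weight quantization. A rough sketch of the idea follows; the dict shapes and the observer fields (scale, zero_point) are assumptions, since the helper body sits outside the visible hunk.

import torch

def copy_gptq_qparams(gptq_results: dict, ptq_observers: dict) -> None:
    """Copy per-layer (scale, zero-point) pairs from a GPTQ run into matching PTQ observers.

    Hypothetical structures: gptq_results[name] -> (scale, zero_point) tensors,
    ptq_observers[name] -> an observer exposing scale / zero_point attributes.
    """
    for name, obs in ptq_observers.items():
        if name not in gptq_results:
            continue  # e.g. activation observers keep their own calibrated ranges
        scale, zero_point = gptq_results[name]
        obs.scale = scale.clone()
        obs.zero_point = zero_point.clone().to(torch.int32)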
47 changes: 47 additions & 0 deletions tico/quantization/wrapq/quantizer.py
@@ -81,9 +81,56 @@ def _wrap_supported(
Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
"""
assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
        try:
            return PTQWrapper(root, qcfg=qcfg, fp_name="model")
        except NotImplementedError:
            print("No special wrapper for the model; wrapping using the general case.")

# Case A: HuggingFace-style transformers: model.model.layers
lm = getattr(root, "model", None)

        embeddings = (
            getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None
        )
if isinstance(embeddings, nn.Module):
child_scope = "model.embeddings"
child_cfg = qcfg.child(child_scope)
wrapped = self._try_wrap(
embeddings,
child_cfg,
fp_name=child_scope,
raise_on_fail=self.strict_wrap,
)
lm.embed_tokens = wrapped # type: ignore[union-attr]

        model_norm = (
            getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None
        )
if isinstance(model_norm, nn.Module):
child_scope = "model.norm"
child_cfg = qcfg.child(child_scope)
wrapped = self._try_wrap(
model_norm,
child_cfg,
fp_name=child_scope,
raise_on_fail=self.strict_wrap,
)
lm.norm = wrapped # type: ignore[union-attr]

lm_head = getattr(root, "lm_head", None) if isinstance(lm, nn.Module) else None
if isinstance(lm_head, nn.Module):
child_scope = "lm_head"
child_cfg = qcfg.child(child_scope)
wrapped = self._try_wrap(
lm_head,
child_cfg,
fp_name=child_scope,
raise_on_fail=self.strict_wrap,
)
root.lm_head = wrapped

layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
if isinstance(layers, nn.ModuleList):
new_list = nn.ModuleList()
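Since no model-specific wrapper exists in this fallback path, the code above assumes the standard HuggingFace causal-LM layout: root.model.embed_tokens, root.model.norm, root.model.layers, and root.lm_head. A quick way to confirm a checkpoint exposes that layout (a sketch; the TinyLlama checkpoint is only an example):

# Inspect the attribute layout the general-case wrapping expects.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
lm = model.model                       # the "Case A" sub-module (root.model)
print(type(lm.embed_tokens).__name__)  # Embedding     -> wrapped as "model.embeddings"
print(type(lm.norm).__name__)          # LlamaRMSNorm  -> wrapped as "model.norm"
print(len(lm.layers))                  # decoder blocks, wrapped one by one
print(type(model.lm_head).__name__)    # Linear        -> wrapped as "lm_head"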
14 changes: 10 additions & 4 deletions tico/quantization/wrapq/utils/metrics.py
@@ -89,11 +89,17 @@ def perplexity(
device = _resolve_device(device, model)
input_ids_full = input_ids_full.to(device)


if max_length is None:
-        assert hasattr(model, "config")
-        model_config = model.config
-        if hasattr(model.config, "text_config"):
-            model_config = model.config.text_config
+        if hasattr(model, "config"):
+            assert hasattr(model, "config")
+            model_config = model.config
+        else:
+            assert hasattr(model.wrapped, "config")
+            model_config = model.wrapped.config
+
+        if hasattr(model_config, "text_config"):
+            model_config = model_config.text_config
assert hasattr(model_config, "max_position_embeddings")
assert isinstance(model_config.max_position_embeddings, int)
max_length = model_config.max_position_embeddings
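The new branch covers models that do not expose .config directly (e.g. a wrapper that keeps the original module under .wrapped) and multimodal configs that nest the language-model settings under text_config. Condensed into a helper for illustration only (the function name is mine, not the repo's):

def resolve_text_config(model):
    """Return the language-model config whether model is a bare HF model or a wrapper."""
    cfg = model.config if hasattr(model, "config") else model.wrapped.config
    # Multimodal configs nest the LM settings under text_config.
    return getattr(cfg, "text_config", cfg)

# The max_length fallback in perplexity() then amounts to:
# max_length = resolve_text_config(model).max_position_embeddings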
15 changes: 7 additions & 8 deletions tico/quantization/wrapq/wrappers/llama/quant_attn.py
@@ -86,9 +86,7 @@ def __init__(
)

# Constant scale (1/√d)
-        scale_t = torch.tensor(
-            float(getattr(fp_attn, "scaling", self.head_dim**-0.5))
-        )
+        scale_t = torch.tensor(float(getattr(fp_attn, "scaling", self.head_dim**-0.5)))
        # merge scale_t into k_proj (otherwise merge it into q_proj)
with torch.no_grad():
lin = self.k_proj.wrapped.module
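The 1/√d scale is folded into k_proj here (the in-place weight update sits below the visible hunk). This is safe because scaling every key by s scales the attention logits by s, so softmax(Q·Kᵀ·s) equals softmax(Q·(sK)ᵀ). A standalone check of that identity (illustrative, not repo code):

import torch

torch.manual_seed(0)
d = 8
q = torch.randn(2, 5, d)  # (batch, seq, head_dim)
k = torch.randn(2, 5, d)
scale = d ** -0.5

# Scale applied at matmul time vs. folded into the keys (i.e. into k_proj's weight).
logits_matmul_scale = torch.matmul(q, k.transpose(-1, -2)) * scale
logits_folded_scale = torch.matmul(q, (k * scale).transpose(-1, -2))

print(torch.allclose(logits_matmul_scale, logits_folded_scale, atol=1e-6))  # True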
@@ -161,8 +159,9 @@ def _concat_kv(
return k, v

def _apply_rope(self, q, k, cos, sin, unsqueeze_dim: int = 1):
-        cos_u = cos.unsqueeze(unsqueeze_dim)
-        sin_u = sin.unsqueeze(unsqueeze_dim)
+        cos_u, sin_u = cos, sin
+        # cos_u = cos.unsqueeze(unsqueeze_dim)
+        # sin_u = sin.unsqueeze(unsqueeze_dim)

q_half = self._rot(
q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
@@ -201,8 +200,8 @@ def forward(

# Rope tables
cos, sin = position_embeddings
-        cos = self._fq(cos, self.obs_cos)
-        sin = self._fq(sin, self.obs_sin)
+        # cos = self._fq(cos, self.obs_cos)
+        # sin = self._fq(sin, self.obs_sin)
q_rot, k_rot = self._apply_rope(q, k, cos, sin, unsqueeze_dim=1)

# --- build/update KV for attention & present_key_value -------------
@@ -228,7 +227,7 @@
attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
hidden_states.device
)
-            attention_mask = self._fq(attention_mask, self.obs_causal_mask)
+        attention_mask = self._fq(attention_mask, self.obs_causal_mask)

attn_weights_parts = []
attn_out_parts = []
45 changes: 34 additions & 11 deletions tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py
@@ -108,6 +108,9 @@ def __init__(
qcfg=post_attention_layernorm,
fp_name=f"{fp_name}.post_attention_layernorm",
)
self.obs_causal_mask = self._make_obs("causal_mask")
self.obs_cos = self._make_obs("cos")
self.obs_sin = self._make_obs("sin")

# Static causal mask template ---------------------------------------
assert hasattr(fp_layer.self_attn, "config") and hasattr(
@@ -166,6 +169,21 @@ def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor:
assert isinstance(self.causal_mask_template, torch.Tensor)
return self.causal_mask_template[..., :seq_len, :seq_len].to(device)

def get_attention_mask_for(self, x):
L = x.size(1)
attention_mask = self._slice_causal(L, x.device)
return attention_mask

def get_position_embeddings_for(self, hidden_states):
return (
self.rope_cos_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
self.rope_sin_template.to(
dtype=hidden_states.dtype, device=hidden_states.device
),
)

def forward(
self,
hidden_states: torch.Tensor,
@@ -186,17 +204,16 @@ def forward(
hidden_states = self.input_layernorm(hidden_states)

if attention_mask is None or attention_mask.dtype == torch.bool:
-            L = hidden_states.size(1)
-            attention_mask = self._slice_causal(L, hidden_states.device)
-
-        position_embeddings = (
-            self.rope_cos_template.to(
-                dtype=hidden_states.dtype, device=hidden_states.device
-            ),
-            self.rope_sin_template.to(
-                dtype=hidden_states.dtype, device=hidden_states.device
-            ),
-        )
+            attention_mask = self.get_attention_mask_for(hidden_states)
+        attention_mask = self._fq(attention_mask, self.obs_causal_mask)
+
+        if position_embeddings is None:
+            position_embeddings = self.get_position_embeddings_for(hidden_states)
+        cos, sin = position_embeddings
+        position_embeddings = (
+            self._fq(cos.unsqueeze(1), self.obs_cos),
+            self._fq(sin.unsqueeze(1), self.obs_sin),
+        )

attn_out = self.self_attn(
hidden_states=hidden_states,
@@ -242,6 +259,12 @@ def forward(

    # Local mask/RoPE observers first, then recurse into children
def _all_observers(self):
yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin)
yield from self.self_attn._all_observers()
yield from self.mlp._all_observers()
yield self.obs_mlp_residual_out

def copy_quantizers(self, model):
self.obs_causal_mask = model.obs_causal_mask
self.obs_cos = model.obs_cos
self.obs_sin = model.obs_sin