From 28a5bfeabdc71f540947e023aa56a5fc446e12d1 Mon Sep 17 00:00:00 2001
From: seongwoo
Date: Fri, 13 Feb 2026 17:03:39 +0900
Subject: [PATCH] [DRAFT] Unroll all heads

This draft unrolls all heads: instead of operating on batched
(B, n_h, S, H) tensors, the forward pass now loops over each KV head and
the query heads that share it, applying RoPE, the KV-cache concat, QK^T,
the mask add, softmax, and the attention-value matmul per head on
(B, S, H) slices, then stacks the per-head results back together.

TICO-DCO-1.0-Signed-off-by: seongwoo
---
 .../wrapq/wrappers/llama/quant_attn.py | 133 +++++++++---------
 1 file changed, 68 insertions(+), 65 deletions(-)

diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn.py b/tico/quantization/wrapq/wrappers/llama/quant_attn.py
index babdeed2..d40152a1 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_attn.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_attn.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import copy
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -52,6 +52,7 @@ def __init__(
             cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads
         )
         self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads
+        self.n_kv = cfg.num_key_value_heads
 
         # ---- Wrap q k v o projections via PTQWrapper ---------------
         q_cfg = qcfg.child("q_proj") if qcfg else None
@@ -139,7 +140,7 @@ def __init__(
             mask.triu_(1)
         self.register_buffer("causal_mask_template", mask, persistent=False)
 
-    def _rot(self, t, o_x1, o_x2, o_neg, o_cat):
+    def _rot(self, t, o_x1, o_x2, o_cat):
         x1, x2 = torch.chunk(t, 2, dim=-1)
         x1 = self._fq(x1, o_x1)
         x2 = self._fq(x2, o_x2)
@@ -151,34 +152,25 @@ def _concat_kv(
         past: Optional[Tuple[torch.Tensor, torch.Tensor]],
         k_new: torch.Tensor,
         v_new: torch.Tensor,
+        h_idx: int,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Concat along sequence dim (dim=2): (B, n_kv, S, H)."""
+        """Concat the given KV head along the sequence dim (dim=1): (B, S, H)."""
         if past is None:
             return k_new, v_new
         past_k, past_v = past
-        k = torch.cat([past_k, k_new], dim=2)
-        v = torch.cat([past_v, v_new], dim=2)
+        k = torch.cat([past_k[:, h_idx, :, :], k_new], dim=1)
+        v = torch.cat([past_v[:, h_idx, :, :], v_new], dim=1)
         return k, v
 
-    def _apply_rope(self, q, k, cos, sin, unsqueeze_dim: int = 1):
-        cos_u = cos.unsqueeze(unsqueeze_dim)
-        sin_u = sin.unsqueeze(unsqueeze_dim)
-
-        q_half = self._rot(
-            q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
-        )
-        q_cos = self._fq(q * cos_u, self.obs_q_cos)
-        q_sin = self._fq(q_half * sin_u, self.obs_q_sin)
-        q_rot = self._fq(q_cos + q_sin, self.obs_q_rot)
-
-        k_half = self._rot(
-            k, self.obs_k_x1, self.obs_k_x2, self.obs_k_neg, self.obs_k_cat
-        )
-        k_cos = self._fq(k * cos_u, self.obs_k_cos)
-        k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
-        k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)
+    def _apply_rope(
+        self, t, cos, sin, obs_x1, obs_x2, obs_cat, obs_cos, obs_sin, obs_rot
+    ):
+        t_half = self._rot(t, obs_x1, obs_x2, obs_cat)
+        t_cos = self._fq(t * cos, obs_cos)
+        t_sin = self._fq(t_half * sin, obs_sin)
+        t_rot = self._fq(t_cos + t_sin, obs_rot)
 
-        return q_rot, k_rot
+        return t_rot
 
     def forward(
         self,
@@ -194,36 +186,23 @@ def forward(
         B, S, _ = hidden.shape
         H = self.head_dim
 
-        # Projections
-        q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_h, S, H)
-        k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
-        v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2)  # (B, n_kv, S, H)
+        kv_rep = self.kv_rep
+        q = self.q_proj(hidden).view(B, S, -1, H)  # (B, S, n_h, H)
+        k = self.k_proj(hidden).view(B, S, -1, H)  # (B, S, n_kv, H)
+        v = self.v_proj(hidden).view(B, S, -1, H)  # (B, S, n_kv, H)
 
-        # Rope tables
+        # RoPE tables
         cos, sin = position_embeddings
         cos = self._fq(cos, self.obs_cos)
         sin = self._fq(sin, self.obs_sin)
-        q_rot, k_rot = self._apply_rope(q, k, cos, sin, unsqueeze_dim=1)
 
-        # --- build/update KV for attention & present_key_value -------------
-        present_key_value: Tuple[torch.Tensor, torch.Tensor]
-
-        # TODO Revisit cache logic
-        # HF Cache path (if available)
-        if use_cache and hasattr(past_key_value, "update"):
-            k_total, v_total = past_key_value.update(k_rot, v)
-            present_key_value = (k_total, v_total)
-            k_for_attn, v_for_attn = k_total, v_total
-        else:
-            # Tuple or None path
-            pkv_tuple = past_key_value if isinstance(past_key_value, tuple) else None
-            k_for_attn, v_for_attn = self._concat_kv(pkv_tuple, k_rot, v)
-            present_key_value = (k_for_attn, v_for_attn)
+        # --- KV for attention & present_key_value -------------
+        present_key_value: List[Tuple[torch.Tensor, torch.Tensor]] = []
 
         # Build causal mask if needed
         if attention_mask is None or attention_mask.dtype == torch.bool:
-            q_len = q_rot.size(2)
-            k_len = k_for_attn.size(2)
+            q_len = q.size(1)
+            k_len = k.size(1)
             assert isinstance(self.causal_mask_template, torch.Tensor)
             attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
                 hidden_states.device
@@ -232,44 +211,68 @@ def forward(
         attn_weights_parts = []
         attn_out_parts = []
-
-        n_kv_h = k_for_attn.size(1)
-        kv_rep = self.kv_rep
-
-        # TODO Consider attaching a separate observer to each computation.
-        for kv_i in range(n_kv_h):
-            # k_h, v_h: (B, 1, K, H)
-            k_i = k_for_attn[:, kv_i : kv_i + 1, :, :]
-            v_i = v_for_attn[:, kv_i : kv_i + 1, :, :]
+        for kv_i in range(self.n_kv):
+            # k_i, v_i: (B, S, H); (B, K, H) after the cache concat below
+            k_i = k[:, :, kv_i, :]
+            v_i = v[:, :, kv_i, :]
+
+            k_i = self._apply_rope(
+                k_i,
+                cos,
+                sin,
+                self.obs_k_x1,
+                self.obs_k_x2,
+                self.obs_k_cat,
+                self.obs_k_cos,
+                self.obs_k_sin,
+                self.obs_k_rot,
+            )
+            k_i, v_i = self._concat_kv(past_key_value, k_i, v_i, kv_i)
 
             for rep_i in range(kv_rep):
-                # q_h: (B, 1, S, H)
                 q_idx = kv_i * kv_rep + rep_i
-                q_i = q_rot[:, q_idx : q_idx + 1, :, :]
-
-                # logits: (B, 1, S, K)
+                # q_i: (B, S, H)
+                q_i = q[:, :, q_idx, :]
+                q_i = self._apply_rope(
+                    q_i,
+                    cos,
+                    sin,
+                    self.obs_q_x1,
+                    self.obs_q_x2,
+                    self.obs_q_cat,
+                    self.obs_q_cos,
+                    self.obs_q_sin,
+                    self.obs_q_rot,
+                )
+
+                # logits: (B, S, K)
                 logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)
 
                 # mask add
-                logits_i = self._fq(logits_i + attention_mask, self.obs_mask_add)
+                logits_i = self._fq(
+                    logits_i + attention_mask.view(1, q_i.size(1), k_i.size(1)),
+                    self.obs_mask_add,
+                )
 
                 # softmax
                 attn_i = torch.softmax(logits_i, -1, dtype=torch.float32).to(q_i.dtype)
                 attn_i = self._fq(attn_i, self.obs_softmax)
 
-                # out: (B, 1, S, H)
+                # out: (B, S, H)
                 out_i = self._fq(attn_i @ v_i, self.obs_attn_out)
 
                 attn_weights_parts.append(attn_i)
                 attn_out_parts.append(out_i)
 
-        # concat heads back: (B, n_h, S, K) / (B, n_h, S, H)
+        # stack heads back
+        # (B, n_h, S, K)
         attn_weights = self._fq(
-            torch.cat(attn_weights_parts, dim=1), self.obs_attn_weights
+            torch.stack(attn_weights_parts, dim=1), self.obs_attn_weights
        )
-        attn_out_h = self._fq(torch.cat(attn_out_parts, dim=1), self.obs_attn_out_h)
+        # (B, n_h, S, H)
+        attn_out_h = self._fq(torch.stack(attn_out_parts, dim=1), self.obs_attn_out_h)
 
-        # Attention output
-        attn_out = attn_out_h.transpose(1, 2).reshape(B, S, -1)  # (B, S, n_h * H)
+        # Attention output: (B, S, n_h * H)
+        attn_out = attn_out_h.transpose(1, 2).reshape(B, S, -1)
 
         # Final projection
         out = self.o_proj(attn_out)
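
Note (not part of the patch): a minimal standalone sketch in plain PyTorch, with made-up sizes and without the fake-quant observers, KV cache, or o_proj, checking that the unrolled per-head loop matches the batched GQA attention it replaces.

    import torch

    # Illustrative sizes only; the real module reads these from the model config.
    B, S, H = 2, 5, 8
    n_h, n_kv = 4, 2              # query heads / KV heads (GQA)
    kv_rep = n_h // n_kv

    q = torch.randn(B, S, n_h, H)
    k = torch.randn(B, S, n_kv, H)
    v = torch.randn(B, S, n_kv, H)
    mask = torch.full((S, S), float("-inf")).triu_(1)   # causal mask, like causal_mask_template

    # Batched reference: repeat each KV head kv_rep times, work on (B, n_h, S, H).
    qb = q.transpose(1, 2)
    kb = k.transpose(1, 2).repeat_interleave(kv_rep, dim=1)
    vb = v.transpose(1, 2).repeat_interleave(kv_rep, dim=1)
    ref = torch.softmax(qb @ kb.transpose(-2, -1) + mask, dim=-1) @ vb

    # Unrolled: outer loop over KV heads, inner loop over the query heads that
    # share them, each step on (B, S, H) slices as in the patched forward.
    outs = []
    for kv_i in range(n_kv):
        k_i, v_i = k[:, :, kv_i, :], v[:, :, kv_i, :]
        for rep_i in range(kv_rep):
            q_i = q[:, :, kv_i * kv_rep + rep_i, :]
            attn_i = torch.softmax(q_i @ k_i.transpose(-2, -1) + mask, dim=-1)
            outs.append(attn_i @ v_i)
    out = torch.stack(outs, dim=1)                      # (B, n_h, S, H)

    assert torch.allclose(ref, out, atol=1e-6)

Stacking the per-head (B, S, H) outputs on dim=1 reproduces the (B, n_h, S, H) layout that attn_out_h has before the final transpose/reshape; the removed TODO about attaching a separate observer to each computation suggests this per-head layout is the point of the unrolling, though the draft does not wire that up yet.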