diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn.py b/tico/quantization/wrapq/wrappers/llama/quant_attn.py
index a671cb2b..babdeed2 100644
--- a/tico/quantization/wrapq/wrappers/llama/quant_attn.py
+++ b/tico/quantization/wrapq/wrappers/llama/quant_attn.py
@@ -233,35 +233,34 @@ def forward(
         attn_weights_parts = []
         attn_out_parts = []
 
-        n_kv = k_for_attn.size(1)  # num_key_value_heads
-        kv_rep = self.kv_rep  # num_key_value_groups
+        n_kv_h = k_for_attn.size(1)
+        kv_rep = self.kv_rep
 
         # TODO Consider attaching a separate observer to each computation.
-        for i in range(n_kv):
-            # (B, 1, K, H)
-            k_i = k_for_attn[:, i : i + 1, :, :]
-            v_i = v_for_attn[:, i : i + 1, :, :]
-
-            # (B, G, S, H) where G=kv_rep
-            h0 = i * kv_rep
-            h1 = (i + 1) * kv_rep
-            q_i = q_rot[:, h0:h1, :, :]
-
-            # logits: (B, G, S, K)
-            logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)
-
-            # mask add: broadcast on head axis (1 -> G).
-            logits_i = self._fq(logits_i + attention_mask, self.obs_mask_add)
-
-            # softmax
-            attn_i = torch.softmax(logits_i, -1, dtype=torch.float32).to(q_i.dtype)
-            attn_i = self._fq(attn_i, self.obs_softmax)
-
-            # out: (B, G, S, H)
-            out_i = self._fq(attn_i @ v_i, self.obs_attn_out)
-
-            attn_weights_parts.append(attn_i)
-            attn_out_parts.append(out_i)
+        for kv_i in range(n_kv_h):
+            # k_i, v_i: (B, 1, K, H)
+            k_i = k_for_attn[:, kv_i : kv_i + 1, :, :]
+            v_i = v_for_attn[:, kv_i : kv_i + 1, :, :]
+            for rep_i in range(kv_rep):
+                # q_i: (B, 1, S, H)
+                q_idx = kv_i * kv_rep + rep_i
+                q_i = q_rot[:, q_idx : q_idx + 1, :, :]
+
+                # logits: (B, 1, S, K)
+                logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)
+
+                # mask add
+                logits_i = self._fq(logits_i + attention_mask, self.obs_mask_add)
+
+                # softmax
+                attn_i = torch.softmax(logits_i, -1, dtype=torch.float32).to(q_i.dtype)
+                attn_i = self._fq(attn_i, self.obs_softmax)
+
+                # out: (B, 1, S, H)
+                out_i = self._fq(attn_i @ v_i, self.obs_attn_out)
+
+                attn_weights_parts.append(attn_i)
+                attn_out_parts.append(out_i)
 
         # concat heads back: (B, n_h, S, K) / (B, n_h, S, H)
         attn_weights = self._fq(