133 changes: 68 additions & 65 deletions tico/quantization/wrapq/wrappers/llama/quant_attn.py
@@ -13,7 +13,7 @@
# limitations under the License.

import copy
from typing import Optional, Tuple
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
@@ -52,6 +52,7 @@ def __init__(
cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads
)
self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads
self.n_kv = cfg.num_key_value_heads

# ---- Wrap q k v o projections via PTQWrapper ---------------
q_cfg = qcfg.child("q_proj") if qcfg else None
@@ -139,7 +140,7 @@ def __init__(
mask.triu_(1)
self.register_buffer("causal_mask_template", mask, persistent=False)

def _rot(self, t, o_x1, o_x2, o_neg, o_cat):
def _rot(self, t, o_x1, o_x2, o_cat):
x1, x2 = torch.chunk(t, 2, dim=-1)
x1 = self._fq(x1, o_x1)
x2 = self._fq(x2, o_x2)
@@ -151,34 +152,25 @@ def _concat_kv(
past: Optional[Tuple[torch.Tensor, torch.Tensor]],
k_new: torch.Tensor,
v_new: torch.Tensor,
h_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Concat along sequence dim (dim=2): (B, n_kv, S, H)."""
"""Concat along sequence dim (dim=1): (B, S, H) given head index."""
if past is None:
return k_new, v_new
past_k, past_v = past
k = torch.cat([past_k, k_new], dim=2)
v = torch.cat([past_v, v_new], dim=2)
k = torch.cat([past_k[:, h_idx, :, :], k_new], dim=1)
v = torch.cat([past_v[:, h_idx, :, :], v_new], dim=1)
return k, v

def _apply_rope(self, q, k, cos, sin, unsqueeze_dim: int = 1):
cos_u = cos.unsqueeze(unsqueeze_dim)
sin_u = sin.unsqueeze(unsqueeze_dim)

q_half = self._rot(
q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
)
q_cos = self._fq(q * cos_u, self.obs_q_cos)
q_sin = self._fq(q_half * sin_u, self.obs_q_sin)
q_rot = self._fq(q_cos + q_sin, self.obs_q_rot)

k_half = self._rot(
k, self.obs_k_x1, self.obs_k_x2, self.obs_k_neg, self.obs_k_cat
)
k_cos = self._fq(k * cos_u, self.obs_k_cos)
k_sin = self._fq(k_half * sin_u, self.obs_k_sin)
k_rot = self._fq(k_cos + k_sin, self.obs_k_rot)
def _apply_rope(
self, t, cos, sin, obs_x1, obs_x2, obs_cat, obs_cos, obs_sin, obs_rot
):
t_half = self._rot(t, obs_x1, obs_x2, obs_cat)
t_cos = self._fq(t * cos, obs_cos)
t_sin = self._fq(t_half * sin, obs_sin)
t_rot = self._fq(t_cos + t_sin, obs_rot)

return q_rot, k_rot
return t_rot

def forward(
self,
@@ -194,36 +186,23 @@ def forward(
B, S, _ = hidden.shape
H = self.head_dim

# Projections
q = self.q_proj(hidden).view(B, S, -1, H).transpose(1, 2) # (B, n_h, S, H)
k = self.k_proj(hidden).view(B, S, -1, H).transpose(1, 2) # (B, n_kv, S, H)
v = self.v_proj(hidden).view(B, S, -1, H).transpose(1, 2) # (B, n_kv, S, H)
kv_rep = self.kv_rep
q = self.q_proj(hidden).view(B, S, -1, H) # (B, S, n_h, H)
k = self.k_proj(hidden).view(B, S, -1, H) # (B, S, n_kv, H)
v = self.v_proj(hidden).view(B, S, -1, H) # (B, S, n_kv, H)

# Rope tables
cos, sin = position_embeddings
cos = self._fq(cos, self.obs_cos)
sin = self._fq(sin, self.obs_sin)
q_rot, k_rot = self._apply_rope(q, k, cos, sin, unsqueeze_dim=1)

# --- build/update KV for attention & present_key_value -------------
present_key_value: Tuple[torch.Tensor, torch.Tensor]

# TODO Revisit cache logic
# HF Cache path (if available)
if use_cache and hasattr(past_key_value, "update"):
k_total, v_total = past_key_value.update(k_rot, v)
present_key_value = (k_total, v_total)
k_for_attn, v_for_attn = k_total, v_total
else:
# Tuple or None path
pkv_tuple = past_key_value if isinstance(past_key_value, tuple) else None
k_for_attn, v_for_attn = self._concat_kv(pkv_tuple, k_rot, v)
present_key_value = (k_for_attn, v_for_attn)
# --- KV for attention & present_key_value -------------
present_key_value: List[List[torch.Tensor]] = []

# Build causal mask if needed
if attention_mask is None or attention_mask.dtype == torch.bool:
q_len = q_rot.size(2)
k_len = k_for_attn.size(2)
q_len = q.size(1)
k_len = k.size(1)
assert isinstance(self.causal_mask_template, torch.Tensor)
attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
hidden_states.device
@@ -232,44 +211,68 @@ def forward(

attn_weights_parts = []
attn_out_parts = []

n_kv_h = k_for_attn.size(1)
kv_rep = self.kv_rep

# TODO Consider attaching a separate observer to each computation.
for kv_i in range(n_kv_h):
# k_h, v_h: (B, 1, K, H)
k_i = k_for_attn[:, kv_i : kv_i + 1, :, :]
v_i = v_for_attn[:, kv_i : kv_i + 1, :, :]
for kv_i in range(self.n_kv):
# k_i, v_i: (B, S, H); (B, K, H) after concat with the cache
k_i = k[:, :, kv_i, :]
v_i = v[:, :, kv_i, :]

k_i = self._apply_rope(
k_i,
cos,
sin,
self.obs_k_x1,
self.obs_k_x2,
self.obs_k_cat,
self.obs_k_cos,
self.obs_k_sin,
self.obs_k_rot,
)
k_i, v_i = self._concat_kv(past_key_value, k_i, v_i, kv_i)
for rep_i in range(kv_rep):
# q_h: (B, 1, S, H)
q_idx = kv_i * kv_rep + rep_i
q_i = q_rot[:, q_idx : q_idx + 1, :, :]

# logits: (B, 1, S, K)
# q_i: (B, S, H)
q_i = q[:, :, q_idx, :]
q_i = self._apply_rope(
q_i,
cos,
sin,
self.obs_q_x1,
self.obs_q_x2,
self.obs_q_cat,
self.obs_q_cos,
self.obs_q_sin,
self.obs_q_rot,
)

# logits: (B, S, K)
logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)

# mask add
logits_i = self._fq(logits_i + attention_mask, self.obs_mask_add)
logits_i = self._fq(
logits_i + attention_mask.view(1, q_i.size(1), k_i.size(1)),
self.obs_mask_add,
)

# softmax
attn_i = torch.softmax(logits_i, -1, dtype=torch.float32).to(q_i.dtype)
attn_i = self._fq(attn_i, self.obs_softmax)

# out: (B, 1, S, H)
# out: (B, S, H)
out_i = self._fq(attn_i @ v_i, self.obs_attn_out)

attn_weights_parts.append(attn_i)
attn_out_parts.append(out_i)

# concat heads back: (B, n_h, S, K) / (B, n_h, S, H)
# stack per-head results back
# (B, n_h, S, K)
attn_weights = self._fq(
torch.cat(attn_weights_parts, dim=1), self.obs_attn_weights
torch.stack(attn_weights_parts, dim=1), self.obs_attn_weights
)
attn_out_h = self._fq(torch.cat(attn_out_parts, dim=1), self.obs_attn_out_h)
# (B, n_h, S, H)
attn_out_h = self._fq(torch.stack(attn_out_parts, dim=1), self.obs_attn_out_h)

# Attention output
attn_out = attn_out_h.transpose(1, 2).reshape(B, S, -1) # (B, S, n_h * H)
# Attention output: (B, S, n_h * H)
attn_out = attn_out_h.transpose(1, 2).reshape(B, S, -1)

# Final projection
out = self.o_proj(attn_out)
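
For reference, the rotate-half RoPE that `_rot` and `_apply_rope` compute reduces to the sketch below once the fake-quant observer calls are stripped out. The function names here are illustrative, and the `(-x2, x1)` tail is an assumption: the end of `_rot` is collapsed in this view, so this follows the standard rotate-half formulation.

```python
import torch

def rotate_half(t: torch.Tensor) -> torch.Tensor:
    # Split the head dim in two, as _rot does with torch.chunk, and
    # recombine as (-x2, x1) -- the standard rotate-half trick (assumed).
    x1, x2 = torch.chunk(t, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # t, cos and sin are assumed to broadcast to (B, S, H); this mirrors
    # t_cos + t_sin in _apply_rope, minus the observers.
    return t * cos + rotate_half(t) * sin
```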
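
The causal mask follows the additive-template pattern: `__init__` registers an upper-triangular buffer of large negative values, and `forward` slices it to the current query/key lengths before adding it to the logits. A minimal sketch, assuming the template is filled with the dtype minimum (the construction line is not shown in this hunk) and a plain 2-D buffer:

```python
import torch

def make_causal_template(max_len: int, dtype=torch.float32) -> torch.Tensor:
    # Large negative values strictly above the diagonal, zeros elsewhere
    # (fill value assumed; the wrapper's buffer may also carry extra leading dims).
    mask = torch.full((max_len, max_len), torch.finfo(dtype).min, dtype=dtype)
    mask.triu_(1)  # keep only the strict upper triangle
    return mask

# Usage: slice to the current lengths and add to logits of shape (B, S, K),
# as forward() does with causal_mask_template[..., :q_len, :k_len].
template = make_causal_template(8)
mask = template[:4, :4]  # prefill case: q_len == k_len == 4
```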
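
Finally, the per-head loop in `forward` is the usual grouped-query attention mapping: query head `kv_i * kv_rep + rep_i` attends with KV head `kv_i`, and the per-head results are stacked back and merged to `(B, S, n_h * H)`. A stripped-down sketch without observers, RoPE, or KV-cache handling; the `1/sqrt(H)` scaling shown is standard scaled dot-product attention and is not visible in the hunk above:

```python
import torch

def grouped_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      mask: torch.Tensor) -> torch.Tensor:
    # q: (B, S, n_h, H); k, v: (B, S, n_kv, H); mask broadcasts to (B, S, S).
    B, S, n_h, H = q.shape
    n_kv = k.shape[2]
    kv_rep = n_h // n_kv
    outs = []
    for kv_i in range(n_kv):
        k_i = k[:, :, kv_i, :]                                 # (B, S, H)
        v_i = v[:, :, kv_i, :]                                 # (B, S, H)
        for rep_i in range(kv_rep):
            q_i = q[:, :, kv_i * kv_rep + rep_i, :]            # (B, S, H)
            logits = (q_i @ k_i.transpose(-2, -1)) / H ** 0.5  # (B, S, S)
            attn = torch.softmax(logits + mask, dim=-1)
            outs.append(attn @ v_i)                            # (B, S, H)
    # Stack to (B, n_h, S, H), then merge heads to (B, S, n_h * H),
    # mirroring the stack / transpose / reshape at the end of forward().
    return torch.stack(outs, dim=1).transpose(1, 2).reshape(B, S, n_h * H)
```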