tico/quantization/wrapq/wrappers/llama/quant_attn.py: 26 additions, 27 deletions
@@ -233,35 +233,34 @@ def forward(
         attn_weights_parts = []
         attn_out_parts = []
 
-        n_kv = k_for_attn.size(1)  # num_key_value_heads
-        kv_rep = self.kv_rep  # num_key_value_groups
+        n_kv_h = k_for_attn.size(1)
+        kv_rep = self.kv_rep
 
         # TODO Consider attaching a separate observer to each computation.
-        for i in range(n_kv):
-            # (B, 1, K, H)
-            k_i = k_for_attn[:, i : i + 1, :, :]
-            v_i = v_for_attn[:, i : i + 1, :, :]
-
-            # (B, G, S, H) where G=kv_rep
-            h0 = i * kv_rep
-            h1 = (i + 1) * kv_rep
-            q_i = q_rot[:, h0:h1, :, :]
-
-            # logits: (B, G, S, K)
-            logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)
-
-            # mask add: broadcast on head axis (1 -> G).
-            logits_i = self._fq(logits_i + attention_mask, self.obs_mask_add)
-
-            # softmax
-            attn_i = torch.softmax(logits_i, -1, dtype=torch.float32).to(q_i.dtype)
-            attn_i = self._fq(attn_i, self.obs_softmax)
-
-            # out: (B, G, S, H)
-            out_i = self._fq(attn_i @ v_i, self.obs_attn_out)
-
-            attn_weights_parts.append(attn_i)
-            attn_out_parts.append(out_i)
+        for kv_i in range(n_kv_h):
+            # k_i, v_i: (B, 1, K, H)
+            k_i = k_for_attn[:, kv_i : kv_i + 1, :, :]
+            v_i = v_for_attn[:, kv_i : kv_i + 1, :, :]
+            for rep_i in range(kv_rep):
+                # q_i: (B, 1, S, H)
+                q_idx = kv_i * kv_rep + rep_i
+                q_i = q_rot[:, q_idx : q_idx + 1, :, :]
+
+                # logits: (B, 1, S, K)
+                logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)
+
+                # mask add
+                logits_i = self._fq(logits_i + attention_mask, self.obs_mask_add)
+
+                # softmax
+                attn_i = torch.softmax(logits_i, -1, dtype=torch.float32).to(q_i.dtype)
+                attn_i = self._fq(attn_i, self.obs_softmax)
+
+                # out: (B, 1, S, H)
+                out_i = self._fq(attn_i @ v_i, self.obs_attn_out)
+
+                attn_weights_parts.append(attn_i)
+                attn_out_parts.append(out_i)
 
         # concat heads back: (B, n_h, S, K) / (B, n_h, S, H)
         attn_weights = self._fq(
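
To see the new per-head loop in isolation, here is a minimal, self-contained sketch of the attention pattern the added lines follow. It is an illustration only, not the PR's wrapper class: the self._fq(...) fake-quant/observer calls are omitted, the 1/sqrt(head_dim) scaling is assumed to happen before this point (it is not visible in this hunk), and the helper name gqa_attention_per_head and the toy shapes are assumptions for the example. It also checks the loop against the usual vectorized GQA formulation built with repeat_interleave.

# Minimal sketch of the per-KV-head / per-query-head attention loop from the
# added lines above. Observer/fake-quant calls are omitted; scaling is assumed
# to be applied to q_rot beforehand. Names and shapes here are illustrative.
import torch


def gqa_attention_per_head(q_rot, k, v, attention_mask, kv_rep):
    # q_rot: (B, n_q_heads, S, H), k/v: (B, n_kv_heads, K, H)
    # kv_rep = n_q_heads // n_kv_heads (query heads sharing one KV head)
    out_parts = []
    n_kv_h = k.size(1)
    for kv_i in range(n_kv_h):
        k_i = k[:, kv_i : kv_i + 1, :, :]  # (B, 1, K, H)
        v_i = v[:, kv_i : kv_i + 1, :, :]  # (B, 1, K, H)
        for rep_i in range(kv_rep):
            q_idx = kv_i * kv_rep + rep_i
            q_i = q_rot[:, q_idx : q_idx + 1, :, :]  # (B, 1, S, H)
            logits_i = q_i @ k_i.transpose(-2, -1)  # (B, 1, S, K)
            logits_i = logits_i + attention_mask  # mask broadcasts over the head axis
            attn_i = torch.softmax(logits_i, dim=-1, dtype=torch.float32).to(q_i.dtype)
            out_parts.append(attn_i @ v_i)  # (B, 1, S, H)
    return torch.cat(out_parts, dim=1)  # (B, n_q_heads, S, H)


# Sanity check: the loop matches the vectorized GQA formulation.
B, n_kv_h, kv_rep, S, K, H = 1, 2, 3, 4, 4, 8
q = torch.randn(B, n_kv_h * kv_rep, S, H)
k = torch.randn(B, n_kv_h, K, H)
v = torch.randn(B, n_kv_h, K, H)
mask = torch.zeros(B, 1, S, K)

looped = gqa_attention_per_head(q, k, v, mask, kv_rep)
k_rep = k.repeat_interleave(kv_rep, dim=1)
v_rep = v.repeat_interleave(kv_rep, dim=1)
vectorized = torch.softmax(q @ k_rep.transpose(-2, -1) + mask, dim=-1) @ v_rep
assert torch.allclose(looped, vectorized, atol=1e-5)

The unrolling in the PR keeps each matmul, mask add, and softmax as a separate tensor op that passes through its own fake-quant call; the TODO in the hunk suggests per-computation observers are still being considered.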