tico/quantization/wrapq/wrappers/llama/quant_attn.py (14 additions, 13 deletions)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 from typing import Optional, Tuple

 import torch
@@ -52,12 +53,6 @@ def __init__(
         )
         self.kv_rep = cfg.num_attention_heads // cfg.num_key_value_heads

-        # Constant scale (1/√d)
-        self.scale_t = torch.tensor(
-            float(getattr(fp_attn, "scaling", self.head_dim**-0.5))
-        )
-        self.obs_scale = self._make_obs("scale")
-
         # ---- Wrap q k v o projections via PTQWrapper ---------------
         q_cfg = qcfg.child("q_proj") if qcfg else None
         k_cfg = qcfg.child("k_proj") if qcfg else None
@@ -81,7 +76,7 @@ def __init__(
             fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj"
         )
         self.k_proj = PTQWrapper(
-            fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
+            copy.deepcopy(fp_attn.k_proj), qcfg=k_cfg, fp_name=f"{fp_name}.k_proj"
         )
         self.v_proj = PTQWrapper(
             fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj"
@@ -90,6 +85,17 @@ def __init__(
             fp_attn.o_proj, qcfg=o_cfg, fp_name=f"{fp_name}.o_proj"
         )

+        # Constant scale (1/√d)
+        scale_t = torch.tensor(
+            float(getattr(fp_attn, "scaling", self.head_dim**-0.5))
+        )
+        # Merge scale_t into k_proj (otherwise, merge it into q_proj)
+        with torch.no_grad():
+            lin = self.k_proj.wrapped.module
+            lin.weight.mul_(scale_t)
+            if lin.bias is not None:
+                lin.bias.mul_(scale_t)
+
         mk = self._make_obs
         self.obs_hidden = mk("hidden")

@@ -119,7 +125,6 @@ def __init__(

         # Masking & attention math
         self.obs_causal_mask = mk("causal_mask")
-        self.obs_logits_raw = mk("logits_raw")
         self.obs_logits = mk("logits")
         self.obs_mask_add = mk("mask_add")
         self.obs_softmax = mk("softmax")
@@ -226,9 +231,7 @@ def forward(
         v_rep = v_for_attn.repeat_interleave(self.kv_rep, dim=1) # (B, n_h, K, H)

         # Attention logits: q @ k^T
-        logits_raw = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits_raw)
-        scale = self._fq(self.scale_t, self.obs_scale)
-        logits = self._fq(logits_raw * scale, self.obs_logits)
+        logits = self._fq(q_rot @ k_rep.transpose(-2, -1), self.obs_logits)

         # Build causal mask if needed
         if attention_mask is None or attention_mask.dtype == torch.bool:
@@ -265,7 +268,6 @@ def _all_observers(self):
         # local first
         yield from (
             self.obs_hidden,
-            self.obs_scale,
             self.obs_cos,
             self.obs_sin,
             self.obs_causal_mask,
@@ -283,7 +285,6 @@ def _all_observers(self):
             self.obs_k_cos,
             self.obs_k_sin,
             self.obs_k_rot,
-            self.obs_logits_raw,
             self.obs_logits,
             self.obs_mask_add,
             self.obs_softmax,
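The net effect of this diff: the constant attention scale 1/√d is no longer applied (and fake-quantized) as a separate multiply on the logits; it is folded into the weights and bias of the wrapped `k_proj`, so the `obs_scale` and `obs_logits_raw` observers drop out of the attention graph. The sketch below is a minimal equivalence check in plain PyTorch, not the wrapq API; the sizes and tensor names are illustrative. The `copy.deepcopy` mirrors the PR's `copy.deepcopy(fp_attn.k_proj)`: the fold mutates the projection's parameters in place, and copying first appears to keep the original FP module untouched.

```python
# Minimal sketch (plain PyTorch, not the wrapq API): folding a constant
# attention scale into the key projection is numerically equivalent to
# applying it to the logits.
import copy

import torch

torch.manual_seed(0)

hidden, head_dim = 16, 8
scale = head_dim ** -0.5                    # the constant 1/sqrt(d) attention scale

k_proj = torch.nn.Linear(hidden, head_dim)  # stand-in for fp_attn.k_proj
q = torch.randn(4, head_dim)                # stand-in for the (rotated) queries
x = torch.randn(4, hidden)                  # hidden states feeding k_proj

# Before: scale applied explicitly to the logits.
logits_ref = (q @ k_proj(x).transpose(-2, -1)) * scale

# After: scale folded into a deep-copied k_proj, as the diff does.
k_folded = copy.deepcopy(k_proj)
with torch.no_grad():
    k_folded.weight.mul_(scale)
    if k_folded.bias is not None:
        k_folded.bias.mul_(scale)
logits_folded = q @ k_folded(x).transpose(-2, -1)

assert torch.allclose(logits_ref, logits_folded, rtol=1e-5, atol=1e-5)
print("max abs diff:", (logits_ref - logits_folded).abs().max().item())
```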