I've been trying to figure out why the winsorization added in #52 actually helps, and I think I've worked it out: The winsorization indirectly mitigates the influence of ablation on dimensions affected by so-called massive activations.
Massive activations are a phenomenon where specific tokens, sometimes in a position-dependent way, trigger extremely large activations in specific hidden-state dimensions: https://github.com/locuslab/massive-activations
These activations aren't context-sensitive (they take similar values regardless of the input), tend to appear suddenly in early-to-middle layers, and persist through several layers before eventually being attenuated.
Due to their magnitude they can bias the computed mean and end up dominating the ablation direction. More importantly, however, ablating them can be catastrophic for the model's performance!
Winsorization mitigates this by clamping the magnitude of any massive activations to a more typical value. This reduces their impact on the mean direction, but it also reduces the impact of ablation on those "special" dimensions.
However, I think a better and more direct way to deal with them is to simply zero out those dimensions in the directions used for ablation (across all layers), so that those dimensions are left completely untouched.
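As a minimal sketch (assuming the ablation directions are stored as tensors whose last axis is the hidden dimension, and that a list of detected dimensions is available; the names here are illustrative, not the repo's API), zeroing them out could look like this:

```python
import torch

def zero_massive_dims(direction: torch.Tensor, massive_dims: list[int]) -> torch.Tensor:
    """Zero the detected massive-activation dimensions in an ablation direction.

    `direction` is assumed to have the hidden dimension as its last axis,
    e.g. shape (hidden_size,) or (num_layers, hidden_size).
    """
    direction = direction.clone()
    direction[..., massive_dims] = 0.0
    # Optionally re-normalize so the ablation projection still uses a unit
    # direction; whether this is desired depends on how the projection is set up.
    return direction / direction.norm(dim=-1, keepdim=True)
```

Since the zeroed components stay zero after normalization, a projection-based ablation (h - (h·r̂)r̂) then leaves those dimensions untouched by construction.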
The more difficult part is actually detecting the affected dimensions. There's no real way to infer what tokens might trigger massive activations, so the only way to find them is to examine the hidden state for each token across different prompts.
Detection is heuristic: if a dimension of the hidden state (for a particular token in a particular layer) has abs_value > 100 and abs_value > 1000 * abs_layer_median, it's considered a massive activation. Since massive activations characteristically persist for several layers, detection can be made more robust by also requiring large activations in the next ~4 layers, using a slightly less strict criterion (e.g. abs_value > 200 * abs_layer_median) to account for attenuation.
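A rough sketch of this heuristic (a sketch only: the exact reduction for abs_layer_median, here the median absolute value over a layer's full hidden states, and the function/parameter names are my assumptions, not the repo's API):

```python
import torch

def detect_massive_dims(
    hidden_states: list[torch.Tensor],   # one (seq_len, hidden_size) tensor per layer
    abs_threshold: float = 100.0,        # abs_value > 100
    median_ratio: float = 1000.0,        # abs_value > 1000 * abs_layer_median
    persist_ratio: float = 200.0,        # relaxed criterion for follow-up layers
    persist_layers: int = 4,             # how many subsequent layers to check
) -> set[int]:
    """Heuristically flag hidden-state dimensions exhibiting massive activations."""
    flagged: set[int] = set()
    num_layers = len(hidden_states)
    # Median absolute activation per layer (assumed reduction over all tokens/dims).
    medians = [h.abs().median() for h in hidden_states]

    for layer in range(num_layers):
        h_abs = hidden_states[layer].abs()
        candidates = (h_abs > abs_threshold) & (h_abs > median_ratio * medians[layer])
        if not candidates.any():
            continue
        token_idx, dim_idx = candidates.nonzero(as_tuple=True)
        for tok, dim in zip(token_idx.tolist(), dim_idx.tolist()):
            # Require the activation to stay large (allowing for attenuation)
            # in the next few layers before flagging the dimension.
            follow_up = range(layer + 1, min(layer + 1 + persist_layers, num_layers))
            if all(hidden_states[l][tok, dim].abs() > persist_ratio * medians[l]
                   for l in follow_up):
                flagged.add(dim)
    return flagged
```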
I've implemented this locally and it seems to help a lot: for Gemma-3-12b-it, dimension 2339 exhibits a lot of massive activations, and simply zeroing it out in the directions leads to improvements in refusals / KL divergence similar to what I've observed with winsorization (though I don't have enough data to confidently say it's a complete replacement). The full detection identifies a further three massive-activation dimensions for this model. Qwen3-4B-Instruct-2507 has two, at dimensions 4 and 396, which trigger for a token from its chat template.
Detection can be integrated into get_residuals using per-layer hooks, since we already have to set output_hidden_states=True anyway. But a small dedicated detection pass with a prompt containing the most common trigger tokens might also be a good idea (since regular prompts may not contain all of the most common triggers). I've currently implemented a combination of both, and I think it's reasonably clean.
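For the dedicated pass, one hypothetical way to drive the detector above via output_hidden_states=True (model, tokenizer, and the example prompt are illustrative placeholders, not names from this repo):

```python
import torch

@torch.no_grad()
def detect_from_prompt(model, tokenizer, prompt: str) -> set[int]:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model(**inputs, output_hidden_states=True)
    # out.hidden_states is a tuple with one (1, seq_len, hidden_size) tensor per
    # layer, plus the embedding output at index 0, which we skip here.
    per_layer = [h[0] for h in out.hidden_states[1:]]
    return detect_massive_dims(per_layer)

# Example (hypothetical): run on text produced by the model's chat template,
# since template tokens are common triggers for massive activations.
# chat_text = tokenizer.apply_chat_template(
#     [{"role": "user", "content": "Hello."}], tokenize=False)
# massive_dims = detect_from_prompt(model, tokenizer, chat_text)
```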
Would you be interested in a PR?