perf: optimize abliteration matrix op #46
Conversation
|
This was actually itself a performance optimization. Note that for MoE models, there can be many MLP matrices per layer, as many as there are experts. So for a 128-expert MoE, pre-computing the projector saves 128 matrix-vector multiplications per layer (and although their total rank is lower than that of the full projector matrix, the number of coefficient multiplications is still smaller). I am also somewhat hesitant in general to complicate the ablation logic in any way unless the resulting gains are substantial, because reliably testing that logic is difficult, so it needs to be obviously correct, especially since it's about to become more complex anyway with #43.
The proof of correctness is really trivial: it's one line and follows from the associativity of matrix multiplication. Instead of computing (r r^T) W, we compute r (r^T W).
I took a look at #43 and I'm 99% sure my changes are compatible. I'm also OK with waiting for it to be merged and then rebasing. Lmk what you think!
|
Have you benchmarked this with a large MoE model to test how much worse the performance is because of the issue I described above? |
Tried Qwen1.5-MoE-A2.7B on a Quadro RTX 6000. Without the optimization, twice the VRAM is used for the abliteration op (32 MiB vs. 16 MiB with it) and it takes 4x the time, due to the issue you described. I'm going to try Phi-3.5-MoE-instruct on an A100 sometime later this week, but this looks OK to me.
After optimization
Before optimization
|
Wait, the new logic is also faster? How does that work, given that the precomputed projector involves fewer coefficient multiplications?
Forming the dense projector r r^T costs O(d^2), and applying it to a (d, k) weight matrix is a (d, d) x (d, k) matrix-matrix product, i.e. O(d^2 k). The new approach computes r^T W, which is O(d k), and then the outer product r (r^T W), which is another O(d k), so the per-matrix cost drops from O(d^2 k) to O(d k).
The real cost isn't in computing the projector, it's in multiplying the dense projection matrix into every expert matrix; matrix-matrix multiplies tend to be very expensive. This way, we can avoid constructing the projector matrix and doing matrix-matrix multiplies altogether, giving us some time and space back. Let me know if this clears it up; I'm not an expert on linear algebra, so I'm learning here as well :)
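For anyone following along, here is a minimal standalone sketch of the two approaches; the dimensions, the tolerance, and the variable names are illustrative and not taken from the actual heretic code:

```python
import torch

d, k = 1024, 4096                      # illustrative hidden size and MLP width
r = torch.randn(d)
r = r / r.norm()                       # unit refusal direction
W = torch.randn(d, k)                  # one (d, k) weight matrix

# Old approach: materialize the dense (d, d) projector, then a matrix-matrix
# multiply against every weight matrix, costing O(d^2 k) per matrix.
projector = torch.outer(r, r)
W_dense = W - projector @ W

# New approach: r^T W is a (k,) vector, and its outer product with r gives the
# same (d, k) update for O(d k) work, without ever building the projector.
W_rank1 = W - torch.outer(r, r @ W)

assert torch.allclose(W_dense, W_rank1, atol=1e-5)
```

Both branches rely on the same identity from the PR description, (r r^T) W = r (r^T W); only the order of evaluation differs.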
|
Ah yes, sorry, I just crawled out of bed and had the dimensionalities confused there. I literally just tried to do the same analysis in my head and got the reverse results, but your math is certainly correct 😄 We're of course merging this then, let me just do a quick review. |
p-e-w
left a comment
Just some comments regarding the explanations. They also show how incredibly subtle this stuff is, which is why I always hesitate to modify such code.
src/heretic/model.py
Outdated
layer_refusal_direction,
).to(self.model.dtype)
# We use the property (r r^T) W = r (r^T W) to avoid computing
# the O(d^2) projector matrix and the O(d^3) matrix multiplication.
The multiplication is actually O(d^2 k), not O(d^3), since W is (d, k) rather than square.
src/heretic/model.py
Outdated
# Calculate the projection scalars: (r^T W)
# hat_r is (d, 1), matrix is (d, k) -> result is (k,)
r_transpose_W = torch.matmul(hat_r_device, matrix)
It's inconsistent to use hat_r above but not hat_r_transpose here, even though the paper uses the same hat notation in both places.
I suggest removing the hat_ prefix everywhere, as it just complicates things and serves no explanatory purpose in the code.
src/heretic/model.py
Outdated
hat_r_device = hat_r.to(matrix.device)

# Calculate the projection scalars: (r^T W)
# hat_r is (d, 1), matrix is (d, k) -> result is (k,)
No, that doesn't work: the shapes you describe aren't compatible, since the inner dimensions must match in matrix multiplication.
What actually happens is that torch.matmul prepends a 1 to the dimensions of a 1-D tensor, so hat_r is treated as (1, d), not (d, 1).
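For reference, a tiny shape check of that promotion rule (the shapes here are arbitrary, not taken from the code under review):

```python
import torch

d, k = 8, 3
r = torch.randn(d)                 # 1-D tensor, shape (d,)
W = torch.randn(d, k)

# matmul promotes the 1-D argument to (1, d), multiplies, then removes the
# prepended dimension again, leaving a (k,) result.
print(torch.matmul(r, W).shape)    # torch.Size([3])

# An explicit (d, 1) tensor really is incompatible with a (d, k) matrix.
try:
    torch.matmul(r.unsqueeze(1), W)
except RuntimeError as err:
    print("incompatible shapes:", err)
```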
src/heretic/model.py
Outdated
r_transpose_W = torch.matmul(hat_r_device, matrix)

# Calculate the update matrix: r * (r^T W)
# Outer product of (d, 1) and (1, k) -> result is (d, k)
Not quite, actually. What you are describing is matrix multiplication. The outer product takes two column vectors of lengths d and k and produces the (d, k) matrix whose (i, j) entry is the product of the i-th entry of the first vector and the j-th entry of the second.
Mmm, technically the outer product already includes a transposition of the second vector, though at this point it's just semantics. I get that the comment is a bit misleading, though, so I'll be clearer about why the outer product is used here.
The outer product is formally not a type of matrix multiplication at all, though it happens to be equal to a matrix multiplication. The inputs are not transposed, and passing tensors of shapes (d, 1) and (1, k) is undefined, and would raise an error in PyTorch.
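A small illustration of that distinction, assuming torch.outer is the function in question (shapes are hypothetical):

```python
import torch

u = torch.randn(4)                  # shape (d,)
v = torch.randn(3)                  # shape (k,)

# torch.outer takes two 1-D tensors and returns the (d, k) matrix u_i * v_j.
outer = torch.outer(u, v)

# It coincides numerically with a column vector times a row vector, but the
# inputs themselves are 1-D and are never transposed.
assert torch.allclose(outer, u.unsqueeze(1) @ v.unsqueeze(0))

# Passing 2-D tensors of shapes (d, 1) and (1, k) is rejected.
try:
    torch.outer(u.unsqueeze(1), v.unsqueeze(0))
except RuntimeError as err:
    print(err)   # complains that 1-D tensors are expected
```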
src/heretic/model.py
Outdated
layer_refusal_direction,
).to(self.model.dtype)
# We use the property (r r^T) W = r (r^T W) to avoid computing
# the O(d^2 k) projector matrix and the O(d^3) matrix multiplication.
I think those complexities are still not correct. Computing the projector takes O(d^2), and multiplying it into a (d, k) matrix takes O(d^2 k), not O(d^3).
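To make those counts concrete, a back-of-the-envelope tally (d and k are hypothetical, chosen purely for illustration):

```python
d, k = 4096, 11008        # hypothetical hidden size and per-expert MLP width

projector_build = d * d   # r r^T: one multiply per entry of the (d, d) matrix
dense_apply = d * d * k   # (d, d) @ (d, k): d multiplies per output element
rank1_apply = 2 * d * k   # r^T W plus the outer product r (r^T W)

# Roughly d / 2 times fewer multiplications per weight matrix.
print(dense_apply / rank1_apply)   # 2048.0
```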
src/heretic/model.py
Outdated
).to(self.model.dtype)
# We use the property (r r^T) W = r (r^T W) to avoid computing
# the O(d^2 k) projector matrix and the O(d^3) matrix multiplication.
# W_new = W - r * (r^T W)
There is also an ablation weight, so this formula is incomplete; it should be W_new = W - weight * r (r^T W).
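For clarity, the complete update including that weight would look roughly like this; the standalone helper and its name are illustrative, not the actual heretic code:

```python
import torch


def ablate(W: torch.Tensor, r: torch.Tensor, weight: float) -> torch.Tensor:
    # W: (d, k) weight matrix, r: (d,) unit refusal direction, weight: scalar.
    # W_new = W - weight * r (r^T W)
    return W - weight * torch.outer(r, torch.matmul(r, W))
```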
src/heretic/model.py
Outdated
r_device = r.to(matrix.device)

# Calculate the projection scalars: (r^T W)
# r is (1, d), matrix is (d, k) -> result is (k,)
r is (d,); torch.matmul internally transforms it to (1, d).
src/heretic/model.py
Outdated
return torch.cat(logprobs, dim=0)

def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
def stream_cresponse(self, chat: list[dict[str, str]]) -> str:
|
Okay, this is going in. Thanks!
* perf: optimize abliteration matrix op
* refactor: comments and var names correspond with arditi
* refactor: fix comments and improve var notation
* fix: accidental line change and improve comments
---------
Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
|
Unfortunately, I had to revert this PR, because it breaks in some cases. I'd be happy to merge a fixed version of this PR again, but this late in the release cycle I can't take such risks, so for now, this has to be reverted.
* Add files via upload
* perf: optimize abliteration matrix op (#46)
* perf: optimize abliteration matrix op
* refactor: comments and var names correspond with arditi
* refactor: fix comments and improve var notation
* fix: accidental line change and improve comments
---------
Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
* Fix line endings to LF
* Add hybrid approach for GPT-OSS compatibility
  - Check for LoRA adapters before attempting LoRA abliteration
  - Fall back to direct weight modification for nn.Parameter (GPT-OSS)
  - Ensures compatibility across all model architectures
* Fix projector bug, update print statement, revert README
* Revert README changes to match upstream
* Fix import sorting for ruff
* Fix reload_model for evaluate_model, add type hints and validation
* Apply ruff formatting
* Replace load_in_4bit with quantization enum
* Fix precision loss: use FP32 refusal direction directly
* Move r assignment into non-LoRA path
* Fix linting: apply ruff formatting
* Add auto-merge for LoRA adapters on save/upload
* Fix linting: apply ruff formatting
* Implement CPU-based merge for 4-bit models with OOM fallback
* Remove use_lora flag (LoRA always on), add user prompt for 4-bit export
* Fix: PEFT target_modules expects module names without path prefix
* Fix linting: apply ruff formatting
* Add LoRA fallback and fix quantization_config handling
  - Add try/except around LoRA initialization with fallback to direct weight modification
  - Only pass quantization_config when not None (fixes gpt-oss loading)
  - Use simple forward pass instead of generate() for model test (avoids chat template issues)
  - Reset non-LoRA models by reloading in reload_model()
  - Check self.use_lora before accessing LoRA adapters in abliterate()
* Add 8-bit quantization support via bitsandbytes
  - Add BNB_8BIT option to QuantizationMethod enum
  - Add --load-in-8bit CLI support (auto via pydantic-settings)
  - Update documentation in config.py and config.default.toml
  - Useful for mid-range VRAM (12-16 GB) as balance between memory and numeric stability
* Improve LoRA merge warning and fix linting
* Apply final ruff formatting
* Fix CI: apply ruff import sorting
* Use tiny model for CI efficiency
* Fix import sorting in test_lora.py
* Fix formatting in test_lora.py
* feat: Show merge warning for all models (requires high RAM)
* style: Apply ruff fixes
* Fix undefined Style import in main.py
* Fix(model): Support MoE/3D tensors and enforce dtype safety in abliterate
* Fix(ci): Format model.py with ruff
* Fix(main): Remove invalid style argument from prompt_select and unused import
* Fix logic errors, memory leak, and redundant merges in main.py
* Fix linting and formatting issues (isort, ruff)
* chore: Simplify .gitattributes as requested
* refactor: Remove defensive try-except around LoRA initialization
* chore: Update uv.lock with peft and bitsandbytes
* chore: Regenerate uv.lock to include missing peft dependency
* style: Fix import sorting (isort) for CI compliance
* style: Simplify .gitattributes to single line as requested
* Address PR #60 feedback: Remove caching, fix LoRA reload, global LoRA usage, style fixes
* Address PR review comments: clarify code, fix quantization, rename method
  - Add explanatory comments for warning suppression and gc behavior
  - Remove redundant gc.collect() calls (empty_cache handles it)
  - Fix output message order (ask merge strategy before 'Uploading...')
  - Add comment explaining 8-bit quantization doesn't need compute_dtype
  - Remove extra newline after dtype comment
  - Add future-proofing note for hybrid layer support (#43)
  - Remove leftover comment in get_merged_model
  - Delete test_lora.py (debug script, not a real test)
  - Add comment explaining needs_reload flag purpose
  - Extract quantization config into _get_quantization_config() helper
  - Rename reload_model() to reset_model_for_trial() for clarity
  - Fix reload_model to respect quantization config (fixes evaluate_model bug)
  - Remove unused gc import
* Restore gc.collect() before empty_cache() for large models
* refactor: Remove LoRA fallback remnants, simplify code
  - Remove use_lora flag (always true since LoRA is always applied)
  - Remove isinstance(PeftModel) check in get_merged_model() (always true)
  - Simplify reset_model_for_trial() by removing defensive try/except
  - Remove redundant gc.collect() calls (empty_cache handles GC)
  - Remove unused gc import from main.py
* Address p-e-w review feedback: rename reset_model, remove loaded_model_name, fix type hints, remove GPT-OSS MoE, update assertion
* Restore skip logic for non-LoRA modules and fix 4-bit base_layer.weight access
* Remove defensive lora_A check per review - get_layer_modules already filters
* Fix try_add: nest component init inside Module check, add assert for unexpected types
* Add note about module.weight assumption for type checking
* Change 'Reloading model' to 'Resetting model' in logging
---------
Co-authored-by: accemlcc <accemlcc@users.noreply.github.com>
Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
Co-authored-by: Hager <Michael.Hager@bruker.com>
Found a possible small VRAM optimization when hacking around. We want to compute M - weight * v (v^T M), where v is the refusal direction for a layer. Instead of building the projector matrix in O(d^2) memory and computing projector @ matrix, which is O(d^2 * k), we can use the identity (v v^T) M = v (v^T M) to apply the transformation directly to the weights. I made the variable names and comments correspond to the original Arditi et al. paper just for clarity.
Was able to save about 2 MB of VRAM on my very low-power 4060 with Qwen-0.6B, and it should increase for larger models. If my math is correct, the projector matrix would be around 256 MB for a 70B model, which is a decent chunk saved, especially with the cost of VRAM these days! There was a very slight increase in computation time, possibly from CPU overhead from Python (I have no idea tbh), but it should be negligible.
Before optimization
After optimization
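A rough sanity check of the 256 MB estimate from the description, assuming a hidden size of 8192 (typical for 70B-class models) and a float32 projector:

```python
d = 8192                              # assumed hidden size of a 70B-class model
projector_bytes = d * d * 4           # float32: 4 bytes per element
print(projector_bytes / 2**20)        # 256.0 (MiB)
```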