
perf: optimize abliteration matrix op #46

Merged
p-e-w merged 4 commits into p-e-w:master from red40maxxer:abliteration-matmul-optimization
Dec 2, 2025
Conversation

red40maxxer (Contributor) commented Nov 24, 2025

Found a possible small VRAM optimization while hacking around. We want to compute M - weight * v (v^T M), where v is the refusal direction for a layer. Instead of building the projector matrix (O(d^2) memory) and computing projector @ matrix (O(d^2 k) work), we can use the identity (v v^T) M = v (v^T M) to apply the transformation directly to the weights. I made the variable names and comments correspond to the original Arditi et al. paper just for clarity.
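
For illustration, a minimal sketch of the factored update (my own toy code, not the PR diff; `M`, `v`, and `weight` are stand-ins for the actual variable names in heretic):

```python
import torch

def abliterate_factored(M: torch.Tensor, v: torch.Tensor, weight: float) -> torch.Tensor:
    # M: (d, k) weight matrix; v: (d,) unit refusal direction.
    # The dense form would materialize P = v v^T (d x d) and compute M - weight * (P @ M).
    # The factored form never builds P:
    vT_M = v @ M                    # (d,) @ (d, k) -> (k,): projection scalars v^T M
    update = torch.outer(v, vT_M)   # rank-1 update v (v^T M), shape (d, k)
    return M - weight * update
```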

Was able to save about 2 MB of VRAM on my very low-power 4060 with Qwen3-0.6B, and it should increase for larger models. If my math is correct, the projector matrix would be around 256 MB for a 70B model, which is a decent chunk saved, especially with the cost of VRAM these days! There was a very slight increase in computation time, possibly from CPU overhead in Python (I have no idea tbh), but it should be negligible.
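
For reference, a rough back-of-the-envelope check (assuming a hidden size of $d = 8192$, typical for 70B-class models): the dense projector has $d^2 = 8192^2 \approx 6.7 \times 10^7$ entries, which is about 128 MiB in bf16 or 256 MiB in fp32, in line with the estimate above.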

Before optimization

GPU type: NVIDIA GeForce RTX 4060 Laptop GPU

Loading model Qwen/Qwen3-0.6B...
* Trying dtype auto... Ok
* Transformer model with 28 layers
* Abliterable components:
  * attn.o_proj: 1 matrices per layer
  * mlp.down_proj: 1 matrices per layer

Loading good prompts from mlabonne/harmless_alpaca...
* 400 prompts loaded

Loading bad prompts from mlabonne/harmful_behaviors...
* 400 prompts loaded

Loading good evaluation prompts from mlabonne/harmless_alpaca...
* 100 prompts loaded
* Obtaining first-token probability distributions...

Loading bad evaluation prompts from mlabonne/harmful_behaviors...
* 100 prompts loaded
* Counting model refusals...
* Initial refusals: 38/100

Calculating per-layer refusal directions...
* Obtaining residuals for good prompts...
* Obtaining residuals for bad prompts...

Running trial 1 of 200...
* Parameters:
  * direction_index = 14.25
  * attn.o_proj.max_weight = 1.40
  * attn.o_proj.max_weight_position = 21.88
  * attn.o_proj.min_weight = 1.15
  * attn.o_proj.min_weight_distance = 5.46
  * mlp.down_proj.max_weight = 1.21
  * mlp.down_proj.max_weight_position = 20.43
  * mlp.down_proj.min_weight = 0.72

  * mlp.down_proj.min_weight_distance = 15.00
* Reloading model...
* Abliterating...
  Abliteration logic took 0.0188s
  Peak VRAM overhead: 14.00 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.02
  * Counting model refusals...
  * Refusals: 17/100

Elapsed time: 8s
Estimated remaining time: 25m 13s

Running trial 2 of 200...
* Parameters:
  * direction_index = per layer
  * attn.o_proj.max_weight = 1.24
  * attn.o_proj.max_weight_position = 22.26
  * attn.o_proj.min_weight = 0.94
  * attn.o_proj.min_weight_distance = 1.12
  * mlp.down_proj.max_weight = 1.37
  * mlp.down_proj.max_weight_position = 26.44
  * mlp.down_proj.min_weight = 1.04
  * mlp.down_proj.min_weight_distance = 3.13
* Reloading model...
* Abliterating...
  Abliteration logic took 0.0017s
  Peak VRAM overhead: 14.00 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.06
  * Counting model refusals...
  * Refusals: 33/100

Elapsed time: 15s
Estimated remaining time: 25m 14s

Running trial 3 of 200...
* Parameters:
  * direction_index = per layer
  * attn.o_proj.max_weight = 0.90
  * attn.o_proj.max_weight_position = 21.58
  * attn.o_proj.min_weight = 0.53
  * attn.o_proj.min_weight_distance = 4.89
  * mlp.down_proj.max_weight = 1.47
  * mlp.down_proj.max_weight_position = 18.15
  * mlp.down_proj.min_weight = 0.19
  * mlp.down_proj.min_weight_distance = 13.89
* Reloading model...
* Abliterating...
  Abliteration logic took 0.0043s
  Peak VRAM overhead: 14.00 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.14
  * Counting model refusals...

After optimization

GPU type: NVIDIA GeForce RTX 4060 Laptop GPU

Loading model Qwen/Qwen3-0.6B...
* Trying dtype auto... Ok
* Transformer model with 28 layers
* Abliterable components:
  * attn.o_proj: 1 matrices per layer
  * mlp.down_proj: 1 matrices per layer

Loading good prompts from mlabonne/harmless_alpaca...
* 400 prompts loaded

Loading bad prompts from mlabonne/harmful_behaviors...
* 400 prompts loaded

Loading good evaluation prompts from mlabonne/harmless_alpaca...
* 100 prompts loaded
* Obtaining first-token probability distributions...

Loading bad evaluation prompts from mlabonne/harmful_behaviors...
* 100 prompts loaded
* Counting model refusals...
* Initial refusals: 38/100

Calculating per-layer refusal directions...
* Obtaining residuals for good prompts...
* Obtaining residuals for bad prompts...

Running trial 1 of 200...
* Parameters:
  * direction_index = per layer
  * attn.o_proj.max_weight = 1.17
  * attn.o_proj.max_weight_position = 17.45
  * attn.o_proj.min_weight = 0.17
  * attn.o_proj.min_weight_distance = 1.93
  * mlp.down_proj.max_weight = 1.29
  * mlp.down_proj.max_weight_position = 24.54
  * mlp.down_proj.min_weight = 0.58
  * mlp.down_proj.min_weight_distance = 5.00
* Reloading model...
* Abliterating...
  Abliteration logic took 0.0087s
  Peak VRAM overhead: 12.01 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.08
  * Counting model refusals...
  * Refusals: 30/100

Elapsed time: 8s
Estimated remaining time: 26m 32s

Running trial 2 of 200...
* Parameters:
  * direction_index = 19.83
  * attn.o_proj.max_weight = 1.49
  * attn.o_proj.max_weight_position = 23.52
  * attn.o_proj.min_weight = 0.85
  * attn.o_proj.min_weight_distance = 12.74
  * mlp.down_proj.max_weight = 0.81
  * mlp.down_proj.max_weight_position = 25.95
  * mlp.down_proj.min_weight = 0.52
  * mlp.down_proj.min_weight_distance = 8.02
* Reloading model...
* Abliterating...
  Abliteration logic took 0.0041s
  Peak VRAM overhead: 12.01 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.05
  * Counting model refusals...
  * Refusals: 29/100

Elapsed time: 16s
Estimated remaining time: 25m 56s

Running trial 3 of 200...
* Parameters:
  * direction_index = per layer
  * attn.o_proj.max_weight = 1.05
  * attn.o_proj.max_weight_position = 20.51
  * attn.o_proj.min_weight = 0.85
  * attn.o_proj.min_weight_distance = 15.69
  * mlp.down_proj.max_weight = 1.05
  * mlp.down_proj.max_weight_position = 23.62
  * mlp.down_proj.min_weight = 0.49
  * mlp.down_proj.min_weight_distance = 12.59
* Reloading model...
* Abliterating...
  Abliteration logic took 0.0087s
  Peak VRAM overhead: 12.01 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.11
  * Counting model refusals...

p-e-w (Owner) commented Nov 25, 2025

This was actually itself a performance optimization.

Note that for MoE models, there can be many MLP matrices per layer, as many as there are experts. So for a 128-expert MoE, pre-computing the projector saves 128 matrix-vector multiplications per layer (though of course their total rank is lower than when multiplying with the full projector matrix, the number of coefficient multiplications is still less).

I am also somewhat hesitant in general to complicate the ablation logic in any way unless the resulting gains are substantial, because reliably testing that logic is difficult so it needs to be obviously correct, especially since it's about to become more complex anyway with #43.

red40maxxer (Contributor Author) commented Nov 26, 2025

> This was actually itself a performance optimization.

> Note that for MoE models, there can be many MLP matrices per layer, as many as there are experts. So for a 128-expert MoE, pre-computing the projector saves 128 matrix-vector multiplications per layer (though of course their total rank is lower than when multiplying with the full projector matrix, the number of coefficient multiplications is still less).

> I am also somewhat hesitant in general to complicate the ablation logic in any way unless the resulting gains are substantial, because reliably testing that logic is difficult so it needs to be obviously correct, especially since it's about to become more complex anyway with #43.

The proof of correctness is really trivial: it's one line and follows from the associativity of matrix multiplication. Instead of computing projector <- r r^T and then projector @ W, we do r (r^T W). I think my comments do an OK job of showing how it corresponds to the abliteration formula in Arditi et al., but if it's not worth the potential future confusion we can move on.

I took a look at #43 and I'm 99% sure my changes are compatible; I'm also OK with waiting for it to be merged and then rebasing. Lmk what you think!
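
For what it's worth, a quick numerical sanity check of the associativity identity (a throwaway script, not part of the PR):

```python
import torch

d, k = 512, 2048
W = torch.randn(d, k)
r = torch.randn(d)
r = r / r.norm()  # unit refusal direction

dense = torch.outer(r, r) @ W      # (r r^T) W via the explicit projector
factored = torch.outer(r, r @ W)   # r (r^T W) without materializing r r^T

print(torch.allclose(dense, factored, atol=1e-5))  # True up to float error
```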

p-e-w (Owner) commented Nov 28, 2025

Have you benchmarked this with a large MoE model to test how much worse the performance is because of the issue I described above?

red40maxxer (Contributor Author) commented Dec 1, 2025

> Have you benchmarked this with a large MoE model to test how much worse the performance is because of the issue I described above?

Tried Qwen1.5-MoE-A2.7B on a Quadro RTX 6000. Without the optimization, the abliteration op uses twice the VRAM (32 MiB vs. 16 MiB with it) and takes about 4x as long, due to the issue you described.

I'm going to try Phi-3.5-MoE-instruct on an A100 sometime later this week, but this looks OK to me.

After optimization

root@e7bfffcc2a37:~/heretic# heretic Qwen/Qwen1.5-MoE-A2.7B
█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀  v1.0.1
█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░
▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀  https://github.com/p-e-w/heretic

GPU type: Quadro RTX 6000

Loading model Qwen/Qwen1.5-MoE-A2.7B...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.12s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
Ok
* Transformer model with 24 layers
* Abliterable components:
  * attn.o_proj: 1 matrices per layer
  * mlp.down_proj: 60 matrices per layer

Loading good prompts from mlabonne/harmless_alpaca...
* 400 prompts loaded

Loading bad prompts from mlabonne/harmful_behaviors...
* 400 prompts loaded

Loading good evaluation prompts from mlabonne/harmless_alpaca...
* 100 prompts loaded
* Obtaining first-token probability distributions...

Loading bad evaluation prompts from mlabonne/harmful_behaviors...
* 100 prompts loaded
* Counting model refusals...
* Initial refusals: 78/100

Calculating per-layer refusal directions...
* Obtaining residuals for good prompts...
* Obtaining residuals for bad prompts...

Running trial 1 of 200...
* Parameters:
  * direction_index = per layer
  * attn.o_proj.max_weight = 1.07
  * attn.o_proj.max_weight_position = 14.52
  * attn.o_proj.min_weight = 0.56
  * attn.o_proj.min_weight_distance = 10.60
  * mlp.down_proj.max_weight = 0.83
  * mlp.down_proj.max_weight_position = 13.94
  * mlp.down_proj.min_weight = 0.26
  * mlp.down_proj.min_weight_distance = 11.26
* Reloading model...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.03s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
* Abliterating...
  Abliteration logic took 0.1855s
  Peak VRAM overhead: 16.01 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.02
  * Counting model refusals...
  * Refusals: 79/100

Elapsed time: 4m 33s
Estimated remaining time: 15h 4m

Running trial 2 of 200...
* Parameters:
  * direction_index = 12.71
  * attn.o_proj.max_weight = 1.38
  * attn.o_proj.max_weight_position = 13.96
  * attn.o_proj.min_weight = 0.37
  * attn.o_proj.min_weight_distance = 2.46
  * mlp.down_proj.max_weight = 0.94
  * mlp.down_proj.max_weight_position = 19.43
  * mlp.down_proj.min_weight = 0.81
  * mlp.down_proj.min_weight_distance = 5.78
* Reloading model...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.02s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
* Abliterating...
  Abliteration logic took 0.1515s
  Peak VRAM overhead: 16.02 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.00
  * Counting model refusals...

Before optimization

root@e7bfffcc2a37:~/heretic# heretic Qwen/Qwen1.5-MoE-A2.7B
█░█░█▀▀░█▀▄░█▀▀░▀█▀░█░█▀▀  v1.0.1
█▀█░█▀▀░█▀▄░█▀▀░░█░░█░█░░
▀░▀░▀▀▀░▀░▀░▀▀▀░░▀░░▀░▀▀▀  https://github.com/p-e-w/heretic

GPU type: Quadro RTX 6000

Loading model Qwen/Qwen1.5-MoE-A2.7B...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.01s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
Ok
* Transformer model with 24 layers
* Abliterable components:
  * attn.o_proj: 1 matrices per layer
  * mlp.down_proj: 60 matrices per layer

Loading good prompts from mlabonne/harmless_alpaca...
* 400 prompts loaded

Loading bad prompts from mlabonne/harmful_behaviors...
* 400 prompts loaded

Loading good evaluation prompts from mlabonne/harmless_alpaca...
* 100 prompts loaded
* Obtaining first-token probability distributions...

Loading bad evaluation prompts from mlabonne/harmful_behaviors...
* 100 prompts loaded
* Counting model refusals...
* Initial refusals: 78/100

Calculating per-layer refusal directions...
* Obtaining residuals for good prompts...
* Obtaining residuals for bad prompts...

Running trial 1 of 200...
* Parameters:
  * direction_index = 20.01
  * attn.o_proj.max_weight = 1.10
  * attn.o_proj.max_weight_position = 22.42
  * attn.o_proj.min_weight = 0.89
  * attn.o_proj.min_weight_distance = 12.21
  * mlp.down_proj.max_weight = 0.98
  * mlp.down_proj.max_weight_position = 17.17
  * mlp.down_proj.min_weight = 0.61
  * mlp.down_proj.min_weight_distance = 11.01
* Reloading model...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.01s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
* Abliterating...
  Abliteration logic took 0.6206s
  Peak VRAM overhead: 32.01 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.02
  * Counting model refusals...
  * Refusals: 75/100

Elapsed time: 4m 34s
Estimated remaining time: 15h 8m

Running trial 2 of 200...
* Parameters:
  * direction_index = 17.47
  * attn.o_proj.max_weight = 1.16
  * attn.o_proj.max_weight_position = 17.80
  * attn.o_proj.min_weight = 0.80
  * attn.o_proj.min_weight_distance = 9.82
  * mlp.down_proj.max_weight = 1.47
  * mlp.down_proj.max_weight_position = 15.15
  * mlp.down_proj.min_weight = 1.45
  * mlp.down_proj.min_weight_distance = 9.48
* Reloading model...
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:08<00:00,  1.01s/it]
Some parameters are on the meta device because they were offloaded to the cpu.
* Abliterating...
  Abliteration logic took 0.7275s
  Peak VRAM overhead: 32.01 MiB
* Evaluating...
  * Obtaining first-token probability distributions...
  * KL divergence: 0.25
  * Counting model refusals...

p-e-w (Owner) commented Dec 1, 2025

Wait, the new logic is also faster? How does that work given that there are fewer coefficients to multiply?

red40maxxer (Contributor Author) commented Dec 2, 2025

> Wait, the new logic is also faster? How does that work given that there are fewer coefficients to multiply?

Forming the dense projector P = v v^T materializes a full d × d matrix out of what is really a rank-1 operation. Applying it then requires a (d × d) @ (d × k) multiply for every expert, which is O(d²·k) work per matrix and forces a large d² tensor through GPU memory.

The new approach v (v^T M) exploits the fact that the projector is rank-1. We use the associativity of matrix multiplication to factor the operation into two computationally cheap steps:

  1. v^T M, a vector-matrix multiply in O(d·k) time
  2. v (v^T M), an outer product, also O(d·k) time

The real cost isn't in forming the projector; it's in multiplying that dense matrix into every expert matrix, and matrix-matrix multiplies are very expensive. This way, we avoid constructing the projector matrix and doing matrix-matrix multiplies altogether, giving us some time and space back.
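
To put rough numbers on it (assuming $d = k = 4096$ for a single expert matrix): the dense path costs $d^2 \approx 1.7 \times 10^7$ multiplications to form the projector plus $d^2 k \approx 6.9 \times 10^{10}$ for the matrix-matrix product, while the factored path costs roughly $2dk \approx 3.4 \times 10^7$, about three orders of magnitude less.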

Let me know if this clears it up; I'm not an expert on linear algebra, so I'm learning here as well :)

p-e-w (Owner) commented Dec 2, 2025

Ah yes, sorry, I just crawled out of bed and had the dimensionalities confused there. I literally just tried to do the same analysis in my head and got the reverse results, but your math is certainly correct 😄

We're of course merging this then, let me just do a quick review.

p-e-w (Owner) left a comment

Just some comments regarding the explanations. They also show how incredibly subtle this stuff is, which is why I always hesitate to modify such code.

layer_refusal_direction,
).to(self.model.dtype)
# We use the property (r r^T) W = r (r^T W) to avoid computing
# the O(d^2) projector matrix and the O(d^3) matrix multiplication.
p-e-w (Owner):

The multiplication is actually $O(d^2 k)$ as you noted.


# Calculate the projection scalars: (r^T W)
# hat_r is (d, 1), matrix is (d, k) -> result is (k,)
r_transpose_W = torch.matmul(hat_r_device, matrix)
p-e-w (Owner):

It's inconsistent to use hat_r above but not hat_r_transpose here, even though the paper uses $\hat{r}$ for both cases.

p-e-w (Owner):

I suggest removing the hat_ prefix everywhere, as it just complicates things and serves no explanatory purpose in the code.

hat_r_device = hat_r.to(matrix.device)

# Calculate the projection scalars: (r^T W)
# hat_r is (d, 1), matrix is (d, k) -> result is (k,)
p-e-w (Owner):

No, that doesn't work. The shapes you describe aren't compatible. The inner dimensions must match in matrix multiplication.

What actually happens is that Torch prepends a 1 to the shape of 1-D vectors, so hat_r is effectively (1, d), not (d, 1).
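
(For illustration, a tiny snippet showing the shape behavior described above; not from the diff:)

```python
import torch

hat_r = torch.randn(8)        # shape (d,) -- a 1-D tensor
matrix = torch.randn(8, 16)   # shape (d, k)

# torch.matmul prepends a 1 to the 1-D argument, computes (1, d) @ (d, k),
# then removes the prepended dimension again:
print(torch.matmul(hat_r, matrix).shape)  # torch.Size([16]), i.e. (k,)
```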

r_transpose_W = torch.matmul(hat_r_device, matrix)

# Calculate the update matrix: r * (r^T W)
# Outer product of (d, 1) and (1, k) -> result is (d, k)
p-e-w (Owner):

Not quite, actually. What you are describing is matrix multiplication. The outer product takes two column vectors $a$ and $b$ and is equivalent to the matrix multiplication $ab^T$, but the outer product itself does not transpose the second vector. Both arguments to the outer product are column vectors.
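
(Again purely illustrative, with toy dimensions:)

```python
import torch

a = torch.randn(8)    # (d,) -- 1-D
b = torch.randn(16)   # (k,) -- 1-D

# torch.outer takes two 1-D tensors and returns the matrix a b^T:
print(torch.outer(a, b).shape)  # torch.Size([8, 16]), i.e. (d, k)
```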

red40maxxer (Contributor Author):

Mmm, technically the outer product already includes a transposition of the second vector, though at this point it's just semantics. I get that the comment is a bit misleading, though, so I'll be clearer about why the outer product is used here.

p-e-w (Owner):

The outer product is formally not a type of matrix multiplication at all, though it happens to be equal to one. The inputs are not transposed, and passing tensors of shapes (d, 1) and (1, k) to the outer product is undefined and would raise an error in PyTorch.

red40maxxer requested a review from p-e-w on December 2, 2025 01:33
layer_refusal_direction,
).to(self.model.dtype)
# We use the property (r r^T) W = r (r^T W) to avoid computing
# the O(d^2 k) projector matrix and the O(d^3) matrix multiplication.
p-e-w (Owner):

I think those complexities are still not correct. Computing the projector takes $d^2$ float multiplications, not $d^2 k$, and multiplying it with the matrix takes $d^2 k$ multiplications, not $d^3$.

).to(self.model.dtype)
# We use the property (r r^T) W = r (r^T W) to avoid computing
# the O(d^2 k) projector matrix and the O(d^3) matrix multiplication.
# W_new = W - r * (r^T W)
p-e-w (Owner):

We have a weight, so this formula is incomplete.

r_device = r.to(matrix.device)

# Calculate the projection scalars: (r^T W)
# r is (1, d), matrix is (d, k) -> result is (k,)
p-e-w (Owner):

$r$ is actually (d,). torch.matmul internally transforms it to (1, d).

return torch.cat(logprobs, dim=0)

def stream_chat_response(self, chat: list[dict[str, str]]) -> str:
def stream_cresponse(self, chat: list[dict[str, str]]) -> str:
p-e-w (Owner):

?

red40maxxer (Contributor Author):

🤦

p-e-w merged commit 60bd531 into p-e-w:master Dec 2, 2025
4 checks passed
p-e-w (Owner) commented Dec 2, 2025

Okay, this is going in. Thanks!

accemlcc added a commit to accemlcc/heretic-lora that referenced this pull request Dec 2, 2025
* perf: optimize abliteration matrix op

* refactor: comments and var names correspond with arditi

* refactor: fix comments and improve var notation

* fix: accidental line change and improve comments

---------

Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
p-e-w added a commit that referenced this pull request Dec 7, 2025
p-e-w added a commit that referenced this pull request Dec 7, 2025
p-e-w (Owner) commented Dec 7, 2025

@red40maxxer

Unfortunately, I had to revert this PR, because it breaks if matrix is actually a matrix stack (3D tensor), as is the case for gpt-oss. torch.outer doesn't broadcast (as stated in its documentation), so those cases aren't just analogous. This was pointed out in #72.

I'd be happy to merge a fixed version of this PR again, but this late in the release cycle I can't take such risks, so for now, this has to be reverted.
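
For anyone picking this up again, a possible broadcast-safe variant for a stacked matrix of shape (num_experts, d, k) could look like the sketch below (my own untested sketch, not code from this PR or the revert; the shapes are assumptions):

```python
import torch

def rank1_update_stacked(matrices: torch.Tensor, r: torch.Tensor, weight: float) -> torch.Tensor:
    # matrices: (e, d, k) stack of expert weight matrices; r: (d,) unit refusal direction.
    # einsum broadcasts over the expert dimension, unlike torch.outer.
    rT_W = torch.einsum("d,edk->ek", r, matrices)   # per-expert projection scalars r^T W
    update = torch.einsum("d,ek->edk", r, rT_W)     # per-expert rank-1 update r (r^T W)
    return matrices - weight * update
```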

p-e-w mentioned this pull request Dec 7, 2025
p-e-w pushed a commit that referenced this pull request Dec 14, 2025
* Add files via upload

* perf: optimize abliteration matrix op (#46)

* perf: optimize abliteration matrix op

* refactor: comments and var names correspond with arditi

* refactor: fix comments and improve var notation

* fix: accidental line change and improve comments

---------

Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>

* Fix line endings to LF

* Add hybrid approach for GPT-OSS compatibility

- Check for LoRA adapters before attempting LoRA abliteration
- Fall back to direct weight modification for nn.Parameter (GPT-OSS)
- Ensures compatibility across all model architectures

* Fix projector bug, update print statement, revert README

* Revert README changes to match upstream

* Fix import sorting for ruff

* Fix reload_model for evaluate_model, add type hints and validation

* Apply ruff formatting

* Replace load_in_4bit with quantization enum

* Fix precision loss: use FP32 refusal direction directly

* Move r assignment into non-LoRA path

* Fix linting: apply ruff formatting

* Add auto-merge for LoRA adapters on save/upload

* Fix linting: apply ruff formatting

* Implement CPU-based merge for 4-bit models with OOM fallback

* Remove use_lora flag (LoRA always on), add user prompt for 4-bit export

* Fix: PEFT target_modules expects module names without path prefix

* Fix linting: apply ruff formatting

* Add LoRA fallback and fix quantization_config handling

- Add try/except around LoRA initialization with fallback to direct weight modification
- Only pass quantization_config when not None (fixes gpt-oss loading)
- Use simple forward pass instead of generate() for model test (avoids chat template issues)
- Reset non-LoRA models by reloading in reload_model()
- Check self.use_lora before accessing LoRA adapters in abliterate()

* Add 8-bit quantization support via bitsandbytes

- Add BNB_8BIT option to QuantizationMethod enum
- Add --load-in-8bit CLI support (auto via pydantic-settings)
- Update documentation in config.py and config.default.toml
- Useful for mid-range VRAM (12-16 GB) as balance between memory and numeric stability

* Improve LoRA merge warning and fix linting

* Apply final ruff formatting

* Fix CI: apply ruff import sorting

* Use tiny model for CI efficiency

* Fix import sorting in test_lora.py

* Fix formatting in test_lora.py

* feat: Show merge warning for all models (requires high RAM)

* style: Apply ruff fixes

* Fix undefined Style import in main.py

* Fix(model): Support MoE/3D tensors and enforce dtype safety in abliterate

* Fix(ci): Format model.py with ruff

* Fix(main): Remove invalid style argument from prompt_select and unused import

* Fix logic errors, memory leak, and redundant merges in main.py

* Fix linting and formatting issues (isort, ruff)

* chore: Simplify .gitattributes as requested

* refactor: Remove defensive try-except around LoRA initialization

* chore: Update uv.lock with peft and bitsandbytes

* chore: Regenerate uv.lock to include missing peft dependency

* style: Fix import sorting (isort) for CI compliance

* style: Simplify .gitattributes to single line as requested

* Address PR #60 feedback: Remove caching, fix LoRA reload, global LoRA usage, style fixes

* Address PR review comments: clarify code, fix quantization, rename method

- Add explanatory comments for warning suppression and gc behavior
- Remove redundant gc.collect() calls (empty_cache handles it)
- Fix output message order (ask merge strategy before 'Uploading...')
- Add comment explaining 8-bit quantization doesn't need compute_dtype
- Remove extra newline after dtype comment
- Add future-proofing note for hybrid layer support (#43)
- Remove leftover comment in get_merged_model
- Delete test_lora.py (debug script, not a real test)
- Add comment explaining needs_reload flag purpose
- Extract quantization config into _get_quantization_config() helper
- Rename reload_model() to reset_model_for_trial() for clarity
- Fix reload_model to respect quantization config (fixes evaluate_model bug)
- Remove unused gc import

* Restore gc.collect() before empty_cache() for large models

* refactor: Remove LoRA fallback remnants, simplify code

- Remove use_lora flag (always true since LoRA is always applied)
- Remove isinstance(PeftModel) check in get_merged_model() (always true)
- Simplify reset_model_for_trial() by removing defensive try/except
- Remove redundant gc.collect() calls (empty_cache handles GC)
- Remove unused gc import from main.py

* Address p-e-w review feedback: rename reset_model, remove loaded_model_name, fix type hints, remove GPT-OSS MoE, update assertion

* Restore skip logic for non-LoRA modules and fix 4-bit base_layer.weight access

* Remove defensive lora_A check per review - get_layer_modules already filters

* Fix try_add: nest component init inside Module check, add assert for unexpected types

* Add note about module.weight assumption for type checking

* Change 'Reloading model' to 'Resetting model' in logging

---------

Co-authored-by: accemlcc <accemlcc@users.noreply.github.com>
Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
Co-authored-by: Hager <Michael.Hager@bruker.com>