basic knot merging #140
Conversation
Pull Request Overview
This PR introduces a new Knot merging transform for LoRA experts, computing SVD components, storing them locally, and merging expert weights uniformly via TIES.
- Added `KnotMerge` and `KnotMergeConfig` to perform SVD-based merges.
- Updated `TiesMerge` to factor out parameter-merging logic into `merge_param`.
- Added a `test_knot_merge` unit test to validate the new transform.
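The overview above can be sketched numerically. The following is a minimal, hypothetical illustration of an SVD-based merge in the spirit of KnoT (stack each expert's LoRA delta, extract a shared left basis via SVD, merge uniformly in that aligned space, then lift back); all names and the uniform-mean choice are illustrative assumptions, not the PR's actual `KnotMerge` implementation:

```python
import numpy as np

# Hypothetical sketch of an SVD-based expert merge (illustrative only):
# 1) form each expert's LoRA delta W_i = B_i @ A_i,
# 2) SVD the column-concatenated deltas to get a shared left basis U,
# 3) project every expert into that basis and average the projections,
# 4) map the merged component back out of the basis.
rng = np.random.default_rng(0)
d_out, d_in, rank, n_experts = 8, 6, 2, 3

deltas = []
for _ in range(n_experts):
    B = rng.standard_normal((d_out, rank))
    A = rng.standard_normal((rank, d_in))
    deltas.append(B @ A)  # low-rank delta of one expert

# shared left basis spanning the experts' combined column space
U, _, _ = np.linalg.svd(np.concatenate(deltas, axis=1), full_matrices=False)
U = U[:, : rank * n_experts]

# uniform merge in the aligned space, then lift back to weight space
projected = [U.T @ W for W in deltas]
merged = U @ np.mean(projected, axis=0)
assert merged.shape == (d_out, d_in)
```

Because `U` here spans the full combined column space, this uniform variant reduces to the plain mean of the deltas; a TIES-style merge (as in the PR) would instead resolve sign conflicts in the projected space before averaging.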
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/test_library_transforms.py | Added test_knot_merge to verify the KnotMerge flow. |
| mttl/models/library/library_transforms.py | Implemented KnotMerge transform and refactored TiesMerge. |
Comments suppressed due to low confidence (3)
mttl/models/library/library_transforms.py:11
- Add `import torch` at the top of this file so that all references to `torch.save`, `torch.load`, and other torch APIs resolve correctly.
from typing import Dict, List, Union
mttl/models/library/library_transforms.py:361
- [nitpick] Rename the variable `ties_mergert` to `ties_merger` to fix the typo and clarify its purpose.
ties_mergert = TiesMerge()
tests/test_library_transforms.py:96
- [nitpick] Consider adding assertions to verify that the merged weights themselves match expected values (e.g., compare against a manual U @ final_param calculation) to improve test coverage.
assert len(merged_layers) == len(exp.expert_weights.keys()) == 1
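As a hedged illustration of the reviewer's point: asserting on lengths alone cannot catch a merge that returns the right keys with wrong values. The snippet below uses a stand-in uniform mean in place of the real `KnotMerge` output (the transform's internals may differ) to show the shape of a value-level assertion:

```python
import numpy as np

# Illustrative stand-in for two experts' weights for one layer.
expert_a = np.array([[1.0, 2.0], [3.0, 4.0]])
expert_b = np.array([[3.0, 2.0], [1.0, 0.0]])

# Stand-in for the transform's merged output (here: a uniform mean);
# a real test would call KnotMerge and compare against the manual
# U @ final_param reconstruction the comment describes.
merged = (expert_a + expert_b) / 2

# Value-level assertion: stronger than checking len() alone.
expected = np.array([[2.0, 2.0], [2.0, 2.0]])
assert np.allclose(merged, expected)
```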
```python
@LibraryTransform.register("weighted_knot_merge", KnotMergeConfig)
class KnotMerge(LibraryTransform):
    """
    Computes a weighted KnoT merge for LoRA ezperts as in https://arxiv.org/pdf/2410.19735
```
Copilot AI · Jun 6, 2025
Correct the typo `ezperts` to `experts` in the docstring.

```diff
- Computes a weighted KnoT merge for LoRA ezperts as in https://arxiv.org/pdf/2410.19735
+ Computes a weighted KnoT merge for LoRA experts as in https://arxiv.org/pdf/2410.19735
```
```python
    used += keep_mask.sum().item()
else:
    # sign majority vote
    sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
```
Copilot AI · Jun 6, 2025
The second assignment to `sign_per_dim` overrides the first; remove the redundant line or clarify which operation is intended.

```diff
- sign_per_dim = expert_weights.sign().sum(0, keepdim=True).sign()
```
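For context on what the flagged line computes, here is a small self-contained sketch of the TIES sign-majority-vote step (variable names mirror the excerpt; the toy values and the averaging step are illustrative assumptions):

```python
import numpy as np

# Three experts (rows), two parameters (columns); values are illustrative.
W = np.array([[ 1.0, -2.0],
              [ 3.0, -1.0],
              [-0.5,  4.0]])

# Elect a sign per parameter by majority vote across experts
# (numpy equivalent of: expert_weights.sign().sum(0, keepdim=True).sign()).
sign_per_dim = np.sign(np.sign(W).sum(axis=0, keepdims=True))

# Keep only entries agreeing with the elected sign, then average them.
keep_mask = np.sign(W) == sign_per_dim
merged = (W * keep_mask).sum(axis=0) / keep_mask.sum(axis=0)
```

With these values the vote elects `+` for the first parameter and `-` for the second, so the merge averages `{1.0, 3.0}` and `{-2.0, -1.0}` respectively.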