
Develop #1272

Changes from all commits
37 commits
3cf7cd0  first commit (SunZichen-2004, Sep 12, 2025)
35502b0  remove unused variables to fix pylint errors (vue1999, Sep 15, 2025)
96d9965  superficial changes (vue1999, Sep 15, 2025)
2ba89df  small changes (vue1999, Sep 15, 2025)
692efa6  Adding new test, lr_params_factors input changed to str (vue1999, Sep 16, 2025)
d112e56  Merge pull request #1 from vue1999/fix-checks (SunZichen-2004, Sep 16, 2025)
3226dde  Starting Linear LoRAs (vue1999, Sep 18, 2025)
2488e1c  LinearLoRA fixes (vue1999, Sep 18, 2025)
b371f56  Attempt at generalised equivariant loras (vue1999, Sep 18, 2025)
0eae9e5  linear lora fixes (ttompa, Sep 19, 2025)
f6b3f72  fully connected layer LORA - initial commit (ttompa, Sep 19, 2025)
ef7eacf  cleanup (ttompa, Sep 19, 2025)
83cbb11  code cleanup (ttompa, Sep 27, 2025)
75d276b  add rank and alpha args (ttompa, Sep 28, 2025)
cc9999b  equivariance tests (ttompa, Sep 28, 2025)
19eeffe  add option to force pseudolabel stress (ttompa, Sep 28, 2025)
9477ecc  fix pylint (vue1999, Nov 3, 2025)
29fe80d  pylint fixes (vue1999, Nov 3, 2025)
9f357f7  Merge pull request #9 from ACEsuit/develop (vue1999, Nov 3, 2025)
87a82ad  pre-commit fixes (ttompa, Nov 4, 2025)
524a5f0  fix failing test on python 3.8 due to mixed dtype (ttompa, Nov 4, 2025)
779530c  Remove default dtype setting for lora tests (vue1999, Nov 5, 2025)
985c7fe  Revert "fix failing test on python 3.8 due to mixed dtype" (vue1999, Nov 5, 2025)
0b7a0a1  fix lora tests (ttompa, Nov 6, 2025)
e27027d  Merge pull request #12 from vue1999/develop (ttompa, Nov 7, 2025)
e0f4ef6  Merge pull request #1259 from vue1999/lora-finetuning (ilyes319, Nov 11, 2025)
ddd9feb  change selection head ft to info (ilyes319, Nov 11, 2025)
96f94b7  Merge branch 'develop' of https://github.com/ACEsuit/mace into develop (ilyes319, Nov 11, 2025)
36ee765  add mace-mh-0/1 to readme (ilyes319, Nov 11, 2025)
f90a2dc  add More info (ilyes319, Nov 11, 2025)
271d7bf  Merge branch 'develop' into freeze-weights (ttompa, Nov 15, 2025)
9ff5b8a  remove skip on fixtures (ilyes319, Nov 20, 2025)
9498340  fix the array equality in info dict (ilyes319, Nov 20, 2025)
25276d1  skip checking for equality for arrays (ilyes319, Nov 20, 2025)
501a321  fix the beta for the schedulefree (ilyes319, Nov 20, 2025)
075c544  fix the linting (ilyes319, Nov 20, 2025)
0212de8  Merge pull request #1258 from vue1999/freeze-weights (ilyes319, Dec 1, 2025)
2 changes: 2 additions & 0 deletions README.md
@@ -252,6 +252,8 @@ Foundation models are a rapidly evolving field. Please look at the [MACE-MP GitH
| MACE-MATPES-PBE-0 | 89 | MATPES-PBE | DFT (PBE) | Materials | [medium](https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model) | >=v0.3.10 | No +U correction. | ASL |
| MACE-MATPES-r2SCAN-0 | 89 | MATPES-r2SCAN | DFT (r2SCAN) | Materials | [medium](https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model) | >=v0.3.10 | Better functional for materials. | ASL |
| MACE-OMOL-0 | 89 | OMOL | DFT (wB97M-VV10) | Molecules/Transition metals/Cations | [large](https://github.com/ACEsuit/mace-foundations/releases/download/mace_omol_0/MACE-omol-0-extra-large-1024.model) | >=v0.3.14 | Charge/Spin embedding, very good molecular accuracy. | ASL |
| MACE-MH-0/1 | 89 | OMAT/OMOL/OC20/MATPES | DFT (PBE/R2SCAN/wB97M-VV10) | Inorganic crystals, molecules and surfaces. [More info.](https://huggingface.co/mace-foundations/mace-mh-1) | [mh-0](https://github.com/ACEsuit/mace-foundations/releases/download/mace_mh_1/mace-mh-0.model) [mh-1](https://github.com/ACEsuit/mace-foundations/releases/download/mace_mh_1/mace-mh-1.model) | >=v0.3.14 | Very good cross domain performance on surfaces/bulk/molecules. | ASL |


### MACE-MP: Materials Project Force Fields

14 changes: 13 additions & 1 deletion mace/calculators/mace.py
@@ -332,8 +332,20 @@ def check_state(self, atoms, tol: float = 1e-15) -> list:
Returns:
list: A list of changes detected in the system.
"""

def _infos_equal(a: dict, b: dict) -> bool:
if a.keys() != b.keys():
return False
for k in a:
va, vb = a[k], b[k]
if isinstance(va, np.ndarray) or isinstance(vb, np.ndarray):
continue
if va != vb:
return False
return True

state = super().check_state(atoms, tol=tol)
if (not state) and (self.atoms.info != atoms.info):
if (not state) and (not _infos_equal(self.atoms.info, atoms.info)):
state.append("info")
return state

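The `_infos_equal` helper above exists because comparing two `atoms.info` dicts with `!=` raises once any value is a numpy array: the elementwise comparison result has no single truth value. A minimal standalone sketch of the failure and the skip-arrays workaround, using illustrative dict contents rather than real ASE objects:

import numpy as np

info_a = {"energy": -1.23, "charges": np.array([0.1, -0.1])}
info_b = {"energy": -1.23, "charges": np.array([0.1, -0.1])}

try:
    info_a != info_b  # dict comparison bool-tests the elementwise array result
except ValueError as exc:
    print("plain dict comparison failed:", exc)

def infos_equal(a: dict, b: dict) -> bool:
    # Same idea as the PR's _infos_equal: skip array-valued entries, compare the rest.
    if a.keys() != b.keys():
        return False
    for key in a:
        va, vb = a[key], b[key]
        if isinstance(va, np.ndarray) or isinstance(vb, np.ndarray):
            continue
        if va != vb:
            return False
    return True

print(infos_equal(info_a, info_b))  # True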
46 changes: 37 additions & 9 deletions mace/cli/run_train.py
@@ -36,6 +36,7 @@
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.tools import torch_geometric
from mace.tools.distributed_tools import init_distributed
from mace.tools.lora_tools import inject_LoRAs
from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import (
HeadConfig,
@@ -201,8 +202,8 @@ def run(args) -> None:
)
if hasattr(model_foundation, "heads"):
if len(model_foundation.heads) > 1:
logging.warning(
f"Mutlihead finetuning with models with more than one head is not supported, using the head {args.foundation_head} as foundation head."
logging.info(
f"Selecting the head {args.foundation_head} as foundation head."
)
model_foundation = remove_pt_head(
model_foundation, args.foundation_head
@@ -548,7 +549,8 @@ def run(args) -> None:
pt_head_config=head_config,
r_max=args.r_max,
device=device,
batch_size=args.batch_size
batch_size=args.batch_size,
force_stress=args.pseudolabel_replay_compute_stress,
):
logging.info("Successfully applied pseudolabels to pt_head configurations")
else:
@@ -704,9 +706,28 @@ def run(args) -> None:
model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs)
model.to(device)

logging.debug(model)
logging.info(f"Total number of parameters: {tools.count_parameters(model)}")
logging.info("")
if args.lora:
lora_rank = args.lora_rank
lora_alpha = args.lora_alpha

logging.info(
"Injecting LoRA layers with rank=%s and alpha=%s",
lora_rank,
lora_alpha,
)

logging.info(
"Original model has %s trainable parameters.",
tools.count_parameters(model),
)

model = inject_LoRAs(model, rank=lora_rank, alpha=lora_alpha)

logging.info(
"Model with LoRA has %s trainable parameters.",
tools.count_parameters(model),
)

logging.info("===========OPTIMIZER INFORMATION===========")
logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
logging.info(f"Batch size: {args.batch_size}")
@@ -736,8 +757,18 @@ def run(args) -> None:

# Optimizer
param_options = get_params_options(args, model)

optimizer: torch.optim.Optimizer
optimizer = get_optimizer(args, param_options)
logging.info("=== Layer's learning rates ===")
for name, p in model.named_parameters():
st = optimizer.state.get(p, {})
if st:
logging.info(f"Param: {name}: {list(st.keys())}")

for i, param_group in enumerate(optimizer.param_groups):
logging.info(f"Param group {i}: lr = {param_group['lr']}")

if args.device == "xpu":
logging.info("Optimzing model and optimzier for XPU")
model, optimizer = ipex.optimize(model, optimizer=optimizer)
@@ -784,9 +815,6 @@ def run(args) -> None:
ema: Optional[ExponentialMovingAverage] = None
if args.ema:
ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay)
else:
for group in optimizer.param_groups:
group["lr"] = args.lr

if args.lbfgs:
logging.info("Switching optimizer to LBFGS")
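The new optimizer report in run_train.py loops over `optimizer.param_groups` and logs each group's learning rate. A standalone sketch of that pattern with a plain torch model; the two-group split and the factors below are hypothetical, not the actual output of MACE's `get_params_options`:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
base_lr = 1e-2
# Hypothetical per-block factors, in the spirit of --lr_params_factors.
param_groups = [
    {"params": model[0].parameters(), "lr": base_lr * 1.0},
    {"params": model[1].parameters(), "lr": base_lr * 0.1},
]
optimizer = torch.optim.Adam(param_groups)

for i, param_group in enumerate(optimizer.param_groups):
    print(f"Param group {i}: lr = {param_group['lr']}")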
48 changes: 48 additions & 0 deletions mace/tools/arg_parser.py
@@ -441,6 +441,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str2bool,
default=False,
)
parser.add_argument(
"--pseudolabel_replay_compute_stress",
help="When replay pseudolabels are generated, always generate stress labels even if the original replay data lacked stress",
type=str2bool,
default=False,
)
parser.add_argument(
"--foundation_filter_elements",
help="Filter element during fine-tuning",
@@ -534,6 +540,24 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str2bool,
default=False,
)
parser.add_argument(
"--lora",
help="Use Low-Rank Adaptation for the fine-tuning",
type=str2bool,
default=False,
)
parser.add_argument(
"--lora_rank",
help="Rank of the LoRA matrices",
type=int,
default=4,
)
parser.add_argument(
"--lora_alpha",
help="Scaling factor for LoRA",
type=float,
default=1.0,
)

# Keys
parser.add_argument(
@@ -745,6 +769,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=float,
default=0.9,
)
parser.add_argument(
"--beta1_schedulefree",
help="Beta1 parameter for the ScheduleFree optimizer",
type=float,
default=0.9,
)
parser.add_argument(
"--beta2_schedulefree",
help="Beta2 parameter for the ScheduleFree optimizer",
type=float,
default=0.98,
)
parser.add_argument("--batch_size", help="batch size", type=int, default=10)
parser.add_argument(
"--valid_batch_size", help="Validation batch size", type=int, default=10
@@ -763,6 +799,18 @@
parser.add_argument(
"--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7
)
parser.add_argument(
"--lr_params_factors",
help="Learning rate factors to multiply on the original lr",
type=str,
default='{"embedding_lr_factor": 1.0, "interactions_lr_factor": 1.0, "products_lr_factor": 1.0, "readouts_lr_factor": 1.0}',
)
parser.add_argument(
"--freeze",
help="Freeze layers from 1 to N. Can be positive or negative, e.g. -1 means the last layer is frozen. 0 or None means all layers are active and is a default setting",
type=int,
default=None,
)
parser.add_argument(
"--amsgrad",
help="use amsgrad variant of optimizer",
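A standalone argparse sketch that mirrors the flags added above (it does not import `mace`; the flag names and defaults are copied from the diff, the str2bool stand-in and the example values are illustrative):

import argparse

def str2bool(value: str) -> bool:
    # Simplified stand-in for mace's str2bool helper.
    return str(value).lower() in ("yes", "true", "t", "1")

parser = argparse.ArgumentParser()
parser.add_argument("--lora", type=str2bool, default=False)
parser.add_argument("--lora_rank", type=int, default=4)
parser.add_argument("--lora_alpha", type=float, default=1.0)
parser.add_argument("--pseudolabel_replay_compute_stress", type=str2bool, default=False)
parser.add_argument("--beta1_schedulefree", type=float, default=0.9)
parser.add_argument("--beta2_schedulefree", type=float, default=0.98)
parser.add_argument("--freeze", type=int, default=None)

args = parser.parse_args(["--lora", "true", "--lora_rank", "8", "--lora_alpha", "16"])
print(args.lora, args.lora_rank, args.lora_alpha, args.freeze)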
185 changes: 185 additions & 0 deletions mace/tools/lora_tools.py
@@ -0,0 +1,185 @@
import torch
from e3nn import o3
from e3nn.nn._fc import _Layer as E3NNFCLayer
from torch import nn


def build_lora_irreps(
irreps_in: o3.Irreps, irreps_out: o3.Irreps, rank: int
) -> o3.Irreps:
"""
Choose an equivariant bottleneck irreps that preserves symmetry: for every irrep
present in BOTH input and output, allocate `rank` copies.
"""
in_set = {ir for _, ir in o3.Irreps(irreps_in)}
out_set = {ir for _, ir in o3.Irreps(irreps_out)}
shared = sorted(in_set & out_set, key=lambda ir: (ir.l, ir.p))
if not shared:
raise ValueError(
f"No shared irreps between input ({irreps_in}) and output ({irreps_out}); cannot build equivariant LoRA."
)
parts = [f"{rank}x{ir}" for ir in shared]
return o3.Irreps(" + ".join(parts))


class LoRAO3Linear(nn.Module):
"""LoRA for equivariant o3.Linear-like layers (preserves O(3) equivariance)."""

def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0):
super().__init__()
self.base = base_linear
self.irreps_in = self.base.irreps_in
self.irreps_out = self.base.irreps_out
self.scaling = float(alpha) / float(rank)
self.lora_irreps = build_lora_irreps(self.irreps_in, self.irreps_out, rank)
# Use the same class as base to avoid layout mismatches if possible
layer_type = type(self.base)
self.lora_A = layer_type(
self.irreps_in, self.lora_irreps, internal_weights=True, biases=False
)
self.lora_B = layer_type(
self.lora_irreps, self.irreps_out, internal_weights=True, biases=False
)
# Match dtype/device to base
base_param = next(self.base.parameters())
self.lora_A.to(dtype=base_param.dtype, device=base_param.device)
self.lora_B.to(dtype=base_param.dtype, device=base_param.device)

with torch.no_grad():
for p in self.lora_B.parameters():
p.zero_()
for p in self.lora_A.parameters():
if p.dim() >= 2:
p.normal_(mean=0.0, std=1e-3)

def forward(self, x: torch.Tensor) -> torch.Tensor:
base = self.base(x)
delta = self.lora_B(self.lora_A(x))
return base + self.scaling * delta


class LoRADenseLinear(nn.Module):
"""LoRA for torch.nn.Linear"""

def __init__(self, base_linear: nn.Linear, rank: int = 4, alpha: float = 1.0):
super().__init__()
self.base = base_linear
in_f = self.base.in_features
out_f = self.base.out_features
self.scaling = float(alpha) / float(rank)
self.lora_A = nn.Linear(in_f, rank, bias=False)
self.lora_B = nn.Linear(rank, out_f, bias=False)

# match dtype/device to base
base_param = next(self.base.parameters())
self.lora_A.to(dtype=base_param.dtype, device=base_param.device)
self.lora_B.to(dtype=base_param.dtype, device=base_param.device)

with torch.no_grad():
nn.init.zeros_(self.lora_B.weight)
nn.init.normal_(self.lora_A.weight, mean=0.0, std=1e-3)

def forward(self, x: torch.Tensor) -> torch.Tensor:
base = self.base(x)
delta = self.lora_B(self.lora_A(x))
return base + self.scaling * delta


class LoRAFCLayer(nn.Module):
"""LoRA for e3nn.nn._fc._Layer used by FullyConnectedNet (scalar MLP).
Adds a low-rank delta on the weight matrix.
"""

def __init__(self, base_layer: nn.Module, rank: int = 4, alpha: float = 1.0):
super().__init__()
if not hasattr(base_layer, "weight"):
raise TypeError("LoRAFCLayer requires a layer with a 'weight' parameter")
self.base = base_layer

w = self.base.weight # type: ignore[attr-defined]
in_f, out_f = int(w.shape[0]), int(w.shape[1])
self.scaling = float(alpha) / float(rank)

# Use explicit parameters to match e3nn layout [in, out]
self.lora_A = nn.Parameter(
torch.empty(in_f, rank, device=w.device, dtype=w.dtype)
)
self.lora_B = nn.Parameter(
torch.empty(rank, out_f, device=w.device, dtype=w.dtype)
)

with torch.no_grad():
torch.nn.init.normal_(self.lora_A, mean=0.0, std=1e-3)
torch.nn.init.zeros_(self.lora_B)

def forward(self, x: torch.Tensor) -> torch.Tensor:
# Replicate e3nn _Layer normalization
W = self.base.weight # type: ignore[attr-defined]
h_in = getattr(self.base, "h_in")
var_in = getattr(self.base, "var_in")
var_out = getattr(self.base, "var_out")
act = getattr(self.base, "act", None)

delta = self.lora_A @ self.lora_B
W_sum = W + self.scaling * delta

if act is not None:
denom = (h_in * var_in) ** 0.5
w = W_sum / denom
x = x @ w
x = act(x)
x = x * (var_out**0.5)
else:
denom = (h_in * var_in / var_out) ** 0.5
w = W_sum / denom
x = x @ w
return x


def inject_lora(
module: nn.Module,
rank: int = 4,
alpha: float = 1.0,
wrap_equivariant: bool = True,
wrap_dense: bool = True,
_is_root: bool = True,
) -> None:
"""
Recursively replace eligible linears with LoRA-wrapped versions.
"""

for child_name, child in list(module.named_children()):
# Skip already wrapped
if isinstance(child, (LoRAO3Linear, LoRADenseLinear, LoRAFCLayer)):
continue
# Equivariant o3.Linear
if wrap_equivariant and isinstance(child, o3.Linear):
try:
wrapped = LoRAO3Linear(child, rank=rank, alpha=alpha)
except ValueError: # If no shared irreps, skip
continue
module._modules[child_name] = wrapped # pylint: disable=protected-access
# Dense nn.Linear
if wrap_dense and isinstance(child, nn.Linear):
wrapped = LoRADenseLinear(child, rank=rank, alpha=alpha)
module._modules[child_name] = wrapped # pylint: disable=protected-access
continue
# e3nn FullyConnectedNet internal layer
if wrap_dense and isinstance(child, E3NNFCLayer):
wrapped = LoRAFCLayer(child, rank=rank, alpha=alpha)
module._modules[child_name] = wrapped # pylint: disable=protected-access
continue
# Recurse
inject_lora(child, rank, alpha, wrap_equivariant, wrap_dense, _is_root=False)

if _is_root:
for name, p in module.named_parameters():
if ("lora_A" in name) or ("lora_B" in name):
p.requires_grad = True
else:
p.requires_grad = False


def inject_LoRAs(model: nn.Module, rank: int = 4, alpha: int = 1):
inject_lora(model, rank=rank, alpha=alpha, wrap_equivariant=True, wrap_dense=True)
return model
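`build_lora_irreps` keeps the adapter equivariant by restricting the bottleneck to irreps that appear in both the input and the output, allocating `rank` copies of each. A short e3nn sketch of that construction; the irreps strings here are illustrative, not taken from a MACE model:

from e3nn import o3

irreps_in = o3.Irreps("8x0e + 4x1o")
irreps_out = o3.Irreps("16x0e + 4x1o + 2x2e")
rank = 4

shared = sorted(
    {ir for _, ir in irreps_in} & {ir for _, ir in irreps_out},
    key=lambda ir: (ir.l, ir.p),
)
lora_irreps = o3.Irreps(" + ".join(f"{rank}x{ir}" for ir in shared))
print(lora_irreps)  # 4x0e+4x1o: the 2e channel is dropped because the input has none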
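A hedged usage sketch of `inject_LoRAs`, assuming this branch of `mace` is installed so that `mace.tools.lora_tools` is importable. The toy model is a plain torch MLP rather than a MACE model, so only the dense-`nn.Linear` path is exercised:

import torch
from torch import nn

from mace.tools.lora_tools import inject_LoRAs

toy = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 4))
x = torch.randn(8, 16)
before = toy(x)

toy = inject_LoRAs(toy, rank=4, alpha=1.0)

# lora_B starts at zero, so the wrapped model initially reproduces the base output.
print(torch.allclose(before, toy(x)))  # True

# Base weights are frozen; only the injected low-rank matrices stay trainable.
print([name for name, p in toy.named_parameters() if p.requires_grad])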