Description
Describe the bug
When fine-tuning a MACELES model using `--foundation_model`, the LES-specific parameters (`les_readouts`) are not loaded from the foundation model. The `load_foundations_elements` function in `mace/tools/finetuning_utils.py` only loads the standard MACE parameters (`node_embedding`, `interactions`, `products`, `readouts`, `scale_shift`) and ignores the `les_readouts` parameters that are specific to MACELES models. As a result, the fine-tuned model loses its pre-trained LES knowledge and starts with randomly initialized LES parameters (extensions.py:62-70).
To Reproduce
Steps to reproduce the behavior:
1. Train a MACELES model and save it.
2. Attempt to fine-tune this model using `mace_run_train --model="MACELES" --foundation_model="path/to/saved/MACELES.model" ...`
3. Fine-tuning starts with randomly initialized `les_readouts` parameters instead of loading them from the foundation model.
4. Inspect the model parameters to verify that `les_readouts` were not copied from the foundation model.
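To carry out the last step, a small comparison helper can flag diverging parameters before and after loading. This is an illustrative sketch, not part of the MACE codebase; `params_differ` is a hypothetical name:

```python
import torch


def params_differ(module_a, module_b):
    """Return True if any pair of corresponding parameters differs.

    Hypothetical helper: compare e.g. model.les_readouts against
    foundation.les_readouts before and after load_foundations_elements.
    """
    return any(
        not torch.equal(p_a.data, p_b.data)
        for (_, p_a), (_, p_b) in zip(
            module_a.named_parameters(), module_b.named_parameters()
        )
    )
```

With the current code, this returns True for the `les_readouts` submodules even after loading the foundation model, which is the bug being reported.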
Expected behavior
When fine-tuning a MACELES model with a foundation model, all parameters including the LES-specific les_readouts should be loaded from the foundation model. This ensures that the fine-tuned model retains the pre-trained LES knowledge and continues from the same parameter state.
Screenshots
Not applicable - this is a parameter loading issue that can be verified by checking model parameters before and after loading.
Additional context
The MACELES model extends ScaleShiftMACE with additional LES functionality (extensions.py:47-71). The `les_readouts` are created during model initialization as a copy of the standard `readouts`, but with independent parameters (extensions.py:66-70).
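For context, the independent-copy construction can be mimicked with `copy.deepcopy`; this is a minimal sketch of the pattern using a stand-in `nn.Linear`, not the actual extensions.py code:

```python
import copy

import torch.nn as nn

# Stand-in for the standard MACE readout modules
readouts = nn.ModuleList([nn.Linear(4, 1)])

# les_readouts start structurally identical but hold independent tensors,
# so updating one set never changes the other
les_readouts = copy.deepcopy(readouts)
```

Because the tensors are independent, a foundation model's trained `les_readouts` values can only reach a fine-tuned model through an explicit parameter copy, which is exactly what `load_foundations_elements` currently omits.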
The fix is straightforward: add LES parameter loading to the `load_foundations_elements` function after line 298, just before the return statement at line 299:

```python
# Load LES readouts if present (for MACELES models)
if hasattr(model, "les_readouts") and hasattr(model_foundations, "les_readouts"):
    for (_, param_1), (_, param_2) in zip(
        model.les_readouts.named_parameters(),
        model_foundations.les_readouts.named_parameters(),
    ):
        param_1.data.copy_(param_2.data)
```
This fix follows the same parameter-copying pattern used elsewhere in the function and maintains backward compatibility with non-MACELES models through the `hasattr` checks.
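Wrapped as a standalone helper for illustration (the function name `copy_les_readouts` and the `TinyLES` stand-in model are hypothetical, not MACE code), the guard makes the copy a silent no-op for models without `les_readouts`:

```python
import torch.nn as nn


def copy_les_readouts(model: nn.Module, model_foundations: nn.Module) -> None:
    """Copy LES readout parameters from the foundation model, if both define them."""
    if hasattr(model, "les_readouts") and hasattr(model_foundations, "les_readouts"):
        for (_, param_1), (_, param_2) in zip(
            model.les_readouts.named_parameters(),
            model_foundations.les_readouts.named_parameters(),
        ):
            param_1.data.copy_(param_2.data)


class TinyLES(nn.Module):
    """Minimal stand-in for a MACELES-style model with a les_readouts submodule."""

    def __init__(self):
        super().__init__()
        self.les_readouts = nn.ModuleList([nn.Linear(4, 1)])
```

Calling the helper on a plain model without the attribute simply does nothing, which mirrors how the proposed in-function fix leaves standard MACE fine-tuning untouched.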