
MACELES fine-tuning doesn't load LES parameters from keyword '--foundation_model' #1305

@yixian929

Description

Describe the bug
When fine-tuning a MACELES model using --foundation_model, the LES-specific parameters (les_readouts) are not loaded from the foundation model. The load_foundations_elements function in mace/tools/finetuning_utils.py only loads standard MACE parameters (node_embedding, interactions, products, readouts, scale_shift) but ignores the les_readouts parameters that are specific to MACELES models. This causes the fine-tuned model to lose its pre-trained LES knowledge and start with randomly initialized LES parameters (extensions.py:62-70).
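
For context, the parameter transfer in load_foundations_elements follows roughly this pattern for the modules it does handle (a hedged sketch with a hypothetical helper name, not the verbatim source):

import torch

def copy_named_parameters(target: torch.nn.Module, source: torch.nn.Module) -> None:
    # Hypothetical helper paraphrasing the copy pattern used throughout
    # load_foundations_elements: walk matching parameters and copy the
    # foundation tensors in place.
    for (_, p_target), (_, p_source) in zip(
        target.named_parameters(), source.named_parameters()
    ):
        p_target.data.copy_(p_source.data)

# This is applied to node_embedding, interactions, products, readouts and
# scale_shift, but never to les_readouts, so the LES weights stay at their
# random initialization.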

To Reproduce
Steps to reproduce the behavior:

1. Train a MACELES model and save it.
2. Attempt to fine-tune this model using mace_run_train --model="MACELES" --foundation_model="path/to/saved/MACELES.model" ...
3. The fine-tuning will start with randomly initialized les_readouts parameters instead of loading them from the foundation model.
4. Check the model parameters to verify that les_readouts are not copied from the foundation model.
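
A minimal way to run that last check, assuming both .model files are complete modules saved with torch.save (the paths below are hypothetical):

import torch

foundation = torch.load("path/to/saved/MACELES.model", map_location="cpu", weights_only=False)
finetuned = torch.load("path/to/finetuned/MACELES.model", map_location="cpu", weights_only=False)

# Compare the LES readout weights parameter by parameter; with the current
# behaviour they differ, because the fine-tuned model keeps its random
# initialization instead of the foundation values.
for (name, p_found), (_, p_fine) in zip(
    foundation.les_readouts.named_parameters(),
    finetuned.les_readouts.named_parameters(),
):
    status = "copied" if torch.allclose(p_found, p_fine) else "NOT copied"
    print(f"les_readouts.{name}: {status}")
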
Expected behavior
When fine-tuning a MACELES model with a foundation model, all parameters including the LES-specific les_readouts should be loaded from the foundation model. This ensures that the fine-tuned model retains the pre-trained LES knowledge and continues from the same parameter state.

Screenshots
Not applicable - this is a parameter loading issue that can be verified by checking model parameters before and after loading.

Additional context
The MACELES model extends ScaleShiftMACE with additional LES functionality (extensions.py:47-71). The les_readouts are created during model initialization as a copy of the standard readouts, but with independent parameters (extensions.py:66-70).
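
A rough sketch of that construction (hedged, not the verbatim extensions.py code; the helper name is hypothetical):

import copy
import torch

def make_les_readouts(readouts: torch.nn.ModuleList) -> torch.nn.ModuleList:
    # The LES branch gets its own deep copy of the readout blocks, so its
    # parameters are independent of (not shared with) the standard readouts.
    return copy.deepcopy(readouts)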

The fix is straightforward: add LES parameter loading to the load_foundations_elements function, after line 298 and before the return statement at line 299:

Load LES readouts if present (for MACELES models):

if hasattr(model, "les_readouts") and hasattr(model_foundations, "les_readouts"):
    for (_, param_1), (_, param_2) in zip(
        model.les_readouts.named_parameters(),
        model_foundations.les_readouts.named_parameters(),
    ):
        param_1.data.copy_(param_2.data)

This fix follows the same parameter copying pattern used elsewhere in the function and maintains backward compatibility with non-MACELES models through the hasattr checks.
