39 commits
bff0bc7
Add freeze functionality and update training scripts
7radians Jan 26, 2025
172ab75
Merge branch 'ACEsuit:main' into mace-freeze
7radians Jan 26, 2025
a4f7a80
Update README.md
7radians Jan 26, 2025
691a304
Update README.md
7radians Jan 26, 2025
91f5cd1
Update README.md
7radians Jan 26, 2025
5355bd6
Update README.md
7radians Jan 26, 2025
7b64cde
Merge branch 'ACEsuit:main' into mace-freeze
7radians Feb 1, 2025
230c37e
Updated freeze logic
7radians Feb 1, 2025
cd1e55b
Tidy up
7radians Feb 1, 2025
b4000e6
Tidy up
7radians Feb 2, 2025
24711bf
Merge branch 'ACEsuit:main' into mace-freeze
7radians Feb 4, 2025
2547209
Fix gradient norms check in active layers
7radians Feb 25, 2025
f1e44a9
Tidy up
7radians Feb 26, 2025
7baab8f
Removed commented out lines that are redundant
7radians Feb 27, 2025
4ab67a2
Merge branch 'ACEsuit:main' into mace-freeze
7radians Mar 13, 2025
f32cc23
Merge branch 'ACEsuit:main' into mace-freeze
7radians Mar 16, 2025
471f1d1
Merge branch 'ACEsuit:main' into mace-freeze
7radians Mar 18, 2025
179cee6
fix freeze with cuequivariance 0.2.0
7radians Mar 19, 2025
1aadb3b
linting
7radians Mar 19, 2025
e96e521
fix linting mistake
7radians Mar 22, 2025
0c4b689
Merge upstream/main into mace-freeze
7radians Apr 3, 2025
3b4f27f
freeze test added
7radians Apr 8, 2025
8d5dd59
added more freeze tests
7radians Apr 18, 2025
a0c3254
Merge branch 'ACEsuit:main' into mace-freeze
7radians Apr 18, 2025
f8a583c
tidy up, fix freeze cueq test
7radians Apr 18, 2025
fee9a3e
added soft-freeze (lr rescaling)
7radians Apr 21, 2025
5b89c6f
soft-freeze instructions + freeze improvement
7radians Apr 21, 2025
71d8604
minor README tweaks
7radians Apr 21, 2025
15a0157
Update README.md
7radians Apr 21, 2025
97e7b01
soft-freeze tests added
7radians Apr 21, 2025
19faa2f
Merge branch 'mace-freeze' of github.com:7radians/mace-freeze into ma…
7radians Apr 21, 2025
879e5a9
test fix
7radians Apr 21, 2025
10ed205
fix test
7radians Apr 22, 2025
2fe7b30
fix test
7radians Apr 22, 2025
50f94a3
bug fix freeze=0
7radians Apr 29, 2025
d429d9e
update to 0.3.13
7radians May 8, 2025
c45f805
note added
7radians May 8, 2025
fb13251
note added
7radians Jun 4, 2025
933231d
update+refactor
7radians Sep 15, 2025
100 changes: 100 additions & 0 deletions README.md
@@ -29,6 +29,7 @@
- [Example usage in ASE](#example-usage-in-ase-1)
- [Finetuning foundation models](#finetuning-foundation-models)
- [Latest recommended foundation models](#latest-recommended-foundation-models)
- [MACE-freeze](#mace-freeze)
- [Caching](#caching)
- [Development](#development)
- [References](#references)
@@ -321,6 +322,105 @@ mace_run_train \
Other options are "medium" and "large", or the path to a foundation model.
If you want to finetune another model, it will be loaded from the path provided via `--foundation_model=$path_model`, and all the hyperparameters will be extracted automatically.

<a id="mace-freeze"></a>
## MACE-freeze

> [!NOTE]
> If using MACE 0.3.13 with freezing and cuEquivariance acceleration, please update all cuEquivariance dependencies to 0.4.0.

> [!NOTE]
> When using the latest foundation models (MPA, OMAT, MATPES), the recommended starting point is `--freeze=6`.


**Installation**

To install the MACE-freeze version of MACE, clone the mace-freeze branch into your working directory and install it:
```sh
git clone -b mace-freeze https://github.com/7radians/mace-freeze.git
pip install ./mace-freeze
```
### Full-freezing mode

This functionality allows neural network layers or parameters to be frozen for transfer learning or other applications.

**Usage**

Use the `--freeze=<N>` argument to freeze layers 1 through N inclusive. Freezing a layer prevents its parameters from being updated during training. For example, to freeze the first 5 layers, provide a positive integer:

```sh
--freeze=5
```

To freeze the last N layers, provide a negative integer:

```sh
--freeze=-5
```
For more fine-grained control, the `--freeze_par` argument allows freezing at the level of the individual parameter tensors and vectors that make up the model layers.
The log file contains a hierarchical list of the model's layers and the named parameters within them, which can be used to identify specific components to freeze.
For example, to freeze the first 6 parameters of the model, use:

```sh
--freeze_par=6
```
`--freeze_par` takes integer values and works in the same manner as `--freeze`, freezing parameters 1 to N (if positive) or the last N parameters (if negative).

**Note**

- By default, MACE-freeze assumes all layers and parameters are trainable. This is equivalent to `--freeze=0` or `--freeze_par=0`.
- The `--freeze` and `--freeze_par` options are mutually exclusive. If both are provided, only `--freeze` will take effect.
- If you intend to use freezing as the sole fine-tuning strategy, ensure that `--multiheads_finetuning=False` (see the example below).
- Conversely, if you are using the multihead fine-tuning method independently, either omit the `--freeze` argument or explicitly set `--freeze=0` in your training script.
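
For instance, a frozen-transfer fine-tuning run might combine the flags along the lines of the sketch below; the run name, training file, and hyperparameter values are placeholders to adapt to your own setup:

```sh
# Placeholder run name, dataset and hyperparameters; adjust to your setup
mace_run_train \
    --name="MACE_frozen_finetune" \
    --foundation_model="medium" \
    --train_file="train.xyz" \
    --freeze=6 \
    --multiheads_finetuning=False \
    --lr=0.01 \
    --max_num_epochs=50
```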

### Soft-freezing mode

Soft-freezing provides a flexible approach to fine-tuning models by assigning a reduced learning rate to selected layers, rather than fully freezing them. This can be useful when adapting pretrained models, enabling a smooth transition between frozen and actively trained parameters.

**Overview**

Soft-freezing scales down the learning rate in a subset of layers using a multiplicative factor. This allows certain layers to update their weights more slowly than others, preserving learned representations while still allowing gradual adaptation. It can be used:
- On its own.
- In conjunction with layer freezing.
- Alongside multi-head fine-tuning.

**Configuration**
- `--soft_freeze=<N>`: Specifies the number of layers to soft-freeze.
- `--soft_freeze_factor=<float>`: A scaling factor between 0 and 1 applied to the base learning rate (`--lr`) for soft-frozen parameters.
- `--soft_freeze_swa`: Also applies the soft-freeze logic to Stage Two training, scaling `--lr_swa` by `--soft_freeze_factor`. Omit this flag to keep the Stage Two learning rate uniform across parameters (see the example below).
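
For example, to soft-freeze two layers at one tenth of the base learning rate and carry the same rescaling into Stage Two, the flags could be combined as follows (the values are illustrative only):

```sh
# Illustrative values: soft-freeze two layers at 10% of --lr, also in Stage Two
--soft_freeze=2 \
--soft_freeze_factor=0.1 \
--soft_freeze_swa
```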

**Behavior with freezing**

When combined with standard layer freezing (`--freeze`):
- The first `--freeze` layers are fully frozen (i.e., excluded from optimization).
- Soft-freezing begins immediately after the frozen layers. For example:
  `--freeze=5` and `--soft_freeze=1` will freeze layers 1–5, soft-freeze layer 6, and leave all subsequent layers fully trainable (see the command sketch below).

When combined with parameter-level freezing (`--freeze_par`):
- If a layer contains a mix of frozen and active parameters, the soft-freezing logic begins at the first layer with any active parameters.
- A single layer can contain both soft-frozen and fully frozen parameters, depending on the granularity of the freeze.
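
On the command line, the `--freeze=5` and `--soft_freeze=1` scenario above would be written roughly as follows (the 0.5 factor is only an illustrative value):

```sh
# Freeze layers 1-5 fully, train layer 6 at half of --lr (illustrative factor)
--freeze=5 \
--soft_freeze=1 \
--soft_freeze_factor=0.5
```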

**Logging**

During training, the structure of the model, including its layers and parameters, is printed to the logs in a hierarchical format. Each parameter is annotated with its learning rate, showing:
- Fully frozen parameters (either showing the baseline `--lr` or n/a, as the learning rate is irrelevant for parameters that are not updated)
- Soft-frozen parameters (lr = `--lr` × `--soft_freeze_factor`)
- Actively trained parameters (lr = `--lr`)

This detailed breakdown allows for informed decision-making about the fine-tuning strategy applied to each part of the model.

### Referencing MACE-freeze
If you use the freezing functionality, please cite [this paper](https://arxiv.org/abs/2502.15582):
<details> <summary>BibTeX</summary>
@misc{radova2025freeze,
  title={Fine-tuning foundation models of materials interatomic potentials with frozen transfer learning},
  author={Mariia Radova and Wojciech G. Stark and Connor S. Allen and Reinhard J. Maurer and Albert P. Bartók},
  year={2025},
  eprint={2502.15582},
  archivePrefix={arXiv},
  primaryClass={cond-mat.mtrl-sci},
  url={https://arxiv.org/abs/2502.15582},
}
</details>

## Caching

By default, automatically downloaded models (such as mace_mp and mace_off) and data for fine-tuning end up in `~/.cache/mace`. The path can be changed by using
45 changes: 34 additions & 11 deletions mace/cli/run_train.py
@@ -31,6 +31,7 @@
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.tools import torch_geometric
from mace.tools.distributed_tools import init_distributed
from mace.tools.freeze import freeze_layers, freeze_param
from mace.tools.model_script_utils import configure_model
from mace.tools.multihead_tools import (
HeadConfig,
@@ -61,6 +62,7 @@
get_optimizer,
get_params_options,
get_swa,
log_soft_freeze,
print_git_commit,
remove_pt_head,
setup_wandb,
@@ -689,16 +691,6 @@ def run(args) -> None:
logging.debug(model)
logging.info(f"Total number of parameters: {tools.count_parameters(model)}")
logging.info("")
logging.info("===========OPTIMIZER INFORMATION===========")
logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
logging.info(f"Batch size: {args.batch_size}")
if args.ema:
logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}")
logging.info(
f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}"
)
logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
logging.info(loss_fn)

# Cueq and OEQ conversion
if args.enable_cueq and args.enable_oeq:
@@ -716,8 +708,39 @@ def run(args) -> None:
assert model.__class__.__name__ in ["MACE", "ScaleShiftMACE", "MACELES"]
model = run_e3nn_to_oeq(deepcopy(model), device=device)


# Freeze layers or parameter groups
freeze, freeze_par = args.freeze, args.freeze_par
if freeze:
if freeze_par:
logging.info("Both --freeze and --freeze_par detected, using --freeze")
freeze_layers(model, freeze)
elif freeze_par:
freeze_param(model, freeze_par)


logging.info("")
logging.info("===========OPTIMIZER INFORMATION===========")
logging.info(f"Using {args.optimizer.upper()} as parameter optimizer")
logging.info(f"Batch size: {args.batch_size}")
if args.ema:
logging.info(f"Using Exponential Moving Average with decay: {args.ema_decay}")
logging.info(
f"Number of gradient updates: {int(args.max_num_epochs*len(train_set)/args.batch_size)}"
)
logging.info(f"Learning rate: {args.lr}, weight decay: {args.weight_decay}")
logging.info(loss_fn)


# Optimizer
param_options = get_params_options(args, model)

def log_stage(stage, param_options, model):
logging.info(f"========== {stage} PARAMETERS ==========")
log_soft_freeze(param_options, model)

log_stage("STAGE ONE", param_options, model)

optimizer: torch.optim.Optimizer
optimizer = get_optimizer(args, param_options)
if args.device == "xpu":
@@ -733,7 +756,7 @@ def run(args) -> None:
swas = [False]
if args.swa:
swa, swas = get_swa(args, model, optimizer, swas, dipole_only)

log_stage("STAGE TWO", param_options, model)
checkpoint_handler = tools.CheckpointHandler(
directory=args.checkpoints_dir,
tag=tag,
42 changes: 34 additions & 8 deletions mace/cli/visualise_train.py
@@ -569,10 +569,16 @@ def update(self, batch, output): # pylint: disable=arguments-differ
self.pred_energies_per_atom.append(output["energy"] / atoms_per_config)

self.n_energy += filter_nonzero_weight(
batch, self.ref_energies, batch.weight, batch.energy_weight,
batch,
self.ref_energies,
batch.weight,
batch.energy_weight,
)
filter_nonzero_weight(
batch, self.pred_energies, batch.weight, batch.energy_weight,
batch,
self.pred_energies,
batch.weight,
batch.energy_weight,
)
filter_nonzero_weight(
batch,
@@ -632,10 +638,18 @@ def update(self, batch, output): # pylint: disable=arguments-differ
self.pred_forces.append(output["forces"])

self.n_forces += filter_nonzero_weight(
batch, self.ref_forces, batch.weight, batch.forces_weight, spread_atoms=True,
batch,
self.ref_forces,
batch.weight,
batch.forces_weight,
spread_atoms=True,
)
filter_nonzero_weight(
batch, self.pred_forces, batch.weight, batch.forces_weight, spread_atoms=True,
batch,
self.pred_forces,
batch.weight,
batch.forces_weight,
spread_atoms=True,
)

# Stress
@@ -644,10 +658,16 @@ def update(self, batch, output): # pylint: disable=arguments-differ
self.pred_stress.append(output["stress"])

self.n_stress += filter_nonzero_weight(
batch, self.ref_stress, batch.weight, batch.stress_weight,
batch,
self.ref_stress,
batch.weight,
batch.stress_weight,
)
filter_nonzero_weight(
batch, self.pred_stress, batch.weight, batch.stress_weight,
batch,
self.pred_stress,
batch.weight,
batch.stress_weight,
)

# Virials
@@ -660,10 +680,16 @@ def update(self, batch, output): # pylint: disable=arguments-differ
self.pred_virials_per_atom.append(output["virials"] / atoms_per_config_3d)

self.n_virials += filter_nonzero_weight(
batch, self.ref_virials, batch.weight, batch.virials_weight,
batch,
self.ref_virials,
batch.weight,
batch.virials_weight,
)
filter_nonzero_weight(
batch, self.pred_virials, batch.weight, batch.virials_weight,
batch,
self.pred_virials,
batch.weight,
batch.virials_weight,
)
filter_nonzero_weight(
batch,
2 changes: 1 addition & 1 deletion mace/data/atomic_data.py
@@ -77,7 +77,7 @@ def __init__(
elec_temp: Optional[torch.Tensor], # [,]
total_charge: Optional[torch.Tensor] = None, # [,]
total_spin: Optional[torch.Tensor] = None, # [,]
pbc: Optional[torch.Tensor] = None, # [, 3]
pbc: Optional[torch.Tensor] = None, # [, 3]
):
# Check shapes
num_nodes = node_attrs.shape[0]
4 changes: 1 addition & 3 deletions mace/modules/extensions.py
@@ -110,9 +110,7 @@ def forward(
no_pbc_mask_cfg = ~pbc_tensor.any(dim=-1)
no_pbc_mask_rows = no_pbc_mask_cfg.repeat_interleave(3)
cell_les[no_pbc_mask_rows] = torch.zeros(
(no_pbc_mask_rows.sum(), 3),
dtype=cell_les.dtype,
device=cell_les.device
(no_pbc_mask_rows.sum(), 3), dtype=cell_les.dtype, device=cell_les.device
)

# Atomic energies
32 changes: 32 additions & 0 deletions mace/tools/arg_parser.py
@@ -534,6 +534,38 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
type=str2bool,
default=False,
)
parser.add_argument(
"--freeze",
help="Freeze layers [1..N]. Negative indices count from the end (e.g., -1 = last layer). 0 or None = no freeze.",
type=int,
default=None,
)
parser.add_argument(
"--freeze_par",
help="Freeze parameters [1..N]. Negative indices count from the end (e.g., -1 = last parameter). 0 or None = no freeze.",
type=int,
default=None,
)
parser.add_argument(
"--soft_freeze",
help="Soft-freeze layers [1..N]. If combined with full freezing, the count starts after freezing",
type=int,
default=None,
)
parser.add_argument(
"--soft_freeze_factor",
help="A fraction of --lr for soft-freeze",
type=float,
default=None,
)
parser.add_argument(
"--soft_freeze_swa",
"--soft_freeze_stage_two",
help="Apply soft-freezing to Stage Two, using the same soft_freeze_factor",
action="store_true",
default=False,
dest="soft_freeze_swa",
)

# Keys
parser.add_argument(
34 changes: 34 additions & 0 deletions mace/tools/custom_swa_lr.py
@@ -0,0 +1,34 @@
# Patched SWALR scheduler for Stage Two, if soft-freezing (learning rate rescaling) requested
# Keeps the defaults of SWALR, just introduces non-uniform lr
import math

from torch.optim.swa_utils import SWALR


class CustomSWALR(SWALR):
    def __init__(self, optimizer, swa_lr, **kwargs):
        # Extract anneal settings early
        self.anneal_epochs = kwargs.get("anneal_epochs", 1)
        self.anneal_strategy = kwargs.get("anneal_strategy", "linear")
        self.swa_lr = swa_lr

        # Compute lr scaling ratios
        max_lr = max(group["lr"] for group in optimizer.param_groups)
        self.lr_ratios = [group["lr"] / max_lr for group in optimizer.param_groups]
        self.base_lrs = [swa_lr * r for r in self.lr_ratios]

        # Call parent constructor (sets up internal state)
        super().__init__(optimizer, swa_lr=swa_lr, **kwargs)

    def get_lr(self):
        anneal_step = getattr(self, "_anneal_step", 0)
        if self.anneal_strategy == "linear":
            alpha = 1.0 - anneal_step / self.anneal_epochs
        elif self.anneal_strategy == "cos":
            alpha = 0.5 * (1 + math.cos(math.pi * anneal_step / self.anneal_epochs))
        else:
            raise ValueError(f"Invalid annealing strategy: {self.anneal_strategy}")

        return [
            self.swa_lr + (base_lr - self.swa_lr) * alpha for base_lr in self.base_lrs
        ]