diff --git a/README.md b/README.md index e0749e60c..4928aa3ee 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,8 @@ Foundation models are a rapidly evolving field. Please look at the [MACE-MP GitH | MACE-MATPES-PBE-0 | 89 | MATPES-PBE | DFT (PBE) | Materials | [medium](https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-pbe-omat-ft.model) | >=v0.3.10 | No +U correction. | ASL | | MACE-MATPES-r2SCAN-0 | 89 | MATPES-r2SCAN | DFT (r2SCAN) | Materials | [medium](https://github.com/ACEsuit/mace-foundations/releases/download/mace_matpes_0/MACE-matpes-r2scan-omat-ft.model) | >=v0.3.10 | Better functional for materials. | ASL | | MACE-OMOL-0 | 89 | OMOL | DFT (wB97M-VV10) | Molecules/Transition metals/Cations | [large](https://github.com/ACEsuit/mace-foundations/releases/download/mace_omol_0/MACE-omol-0-extra-large-1024.model) | >=v0.3.14 | Charge/Spin embedding, very good molecular accuracy. | ASL | +| MACE-MH-0/1 | 89 | OMAT/OMOL/OC20/MATPES | DFT (PBE/R2SCAN/wB97M-VV10) | Inorganic crystals, molecules and surfaces. [More info.](https://huggingface.co/mace-foundations/mace-mh-1) | [mh-0](https://github.com/ACEsuit/mace-foundations/releases/download/mace_mh_1/mace-mh-0.model) [mh-1](https://github.com/ACEsuit/mace-foundations/releases/download/mace_mh_1/mace-mh-1.model) | >=v0.3.14 | Very good cross domain performance on surfaces/bulk/molecules. | ASL | + ### MACE-MP: Materials Project Force Fields diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 9fe83a38d..45b070da5 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -332,8 +332,20 @@ def check_state(self, atoms, tol: float = 1e-15) -> list: Returns: list: A list of changes detected in the system. """ + + def _infos_equal(a: dict, b: dict) -> bool: + if a.keys() != b.keys(): + return False + for k in a: + va, vb = a[k], b[k] + if isinstance(va, np.ndarray) or isinstance(vb, np.ndarray): + continue + if va != vb: + return False + return True + state = super().check_state(atoms, tol=tol) - if (not state) and (self.atoms.info != atoms.info): + if (not state) and (not _infos_equal(self.atoms.info, atoms.info)): state.append("info") return state diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 919a30b85..e730a021f 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -36,6 +36,7 @@ from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.tools import torch_geometric from mace.tools.distributed_tools import init_distributed +from mace.tools.lora_tools import inject_LoRAs from mace.tools.model_script_utils import configure_model from mace.tools.multihead_tools import ( HeadConfig, @@ -201,8 +202,8 @@ def run(args) -> None: ) if hasattr(model_foundation, "heads"): if len(model_foundation.heads) > 1: - logging.warning( - f"Mutlihead finetuning with models with more than one head is not supported, using the head {args.foundation_head} as foundation head." + logging.info( + f"Selecting the head {args.foundation_head} as foundation head." 
) model_foundation = remove_pt_head( model_foundation, args.foundation_head @@ -548,7 +549,8 @@ def run(args) -> None: pt_head_config=head_config, r_max=args.r_max, device=device, - batch_size=args.batch_size + batch_size=args.batch_size, + force_stress=args.pseudolabel_replay_compute_stress, ): logging.info("Successfully applied pseudolabels to pt_head configurations") else: @@ -704,9 +706,28 @@ def run(args) -> None: model, output_args = configure_model(args, train_loader, atomic_energies, model_foundation, heads, z_table, head_configs) model.to(device) - logging.debug(model) - logging.info(f"Total number of parameters: {tools.count_parameters(model)}") - logging.info("") + if args.lora: + lora_rank = args.lora_rank + lora_alpha = args.lora_alpha + + logging.info( + "Injecting LoRA layers with rank=%s and alpha=%s", + lora_rank, + lora_alpha, + ) + + logging.info( + "Original model has %s trainable parameters.", + tools.count_parameters(model), + ) + + model = inject_LoRAs(model, rank=lora_rank, alpha=lora_alpha) + + logging.info( + "Model with LoRA has %s trainable parameters.", + tools.count_parameters(model), + ) + logging.info("===========OPTIMIZER INFORMATION===========") logging.info(f"Using {args.optimizer.upper()} as parameter optimizer") logging.info(f"Batch size: {args.batch_size}") @@ -736,8 +757,18 @@ def run(args) -> None: # Optimizer param_options = get_params_options(args, model) + optimizer: torch.optim.Optimizer optimizer = get_optimizer(args, param_options) + logging.info("=== Layer's learning rates ===") + for name, p in model.named_parameters(): + st = optimizer.state.get(p, {}) + if st: + logging.info(f"Param: {name}: {list(st.keys())}") + + for i, param_group in enumerate(optimizer.param_groups): + logging.info(f"Param group {i}: lr = {param_group['lr']}") + if args.device == "xpu": logging.info("Optimzing model and optimzier for XPU") model, optimizer = ipex.optimize(model, optimizer=optimizer) @@ -784,9 +815,6 @@ def run(args) -> None: ema: Optional[ExponentialMovingAverage] = None if args.ema: ema = ExponentialMovingAverage(model.parameters(), decay=args.ema_decay) - else: - for group in optimizer.param_groups: - group["lr"] = args.lr if args.lbfgs: logging.info("Switching optimizer to LBFGS") diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index ce965a74f..4e8371051 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -441,6 +441,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str2bool, default=False, ) + parser.add_argument( + "--pseudolabel_replay_compute_stress", + help="When replay pseudolabels are generated, always generate stress labels even if the original replay data lacked stress", + type=str2bool, + default=False, + ) parser.add_argument( "--foundation_filter_elements", help="Filter element during fine-tuning", @@ -534,6 +540,24 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str2bool, default=False, ) + parser.add_argument( + "--lora", + help="Use Low-Rank Adaptation for the fine-tuning", + type=str2bool, + default=False, + ) + parser.add_argument( + "--lora_rank", + help="Rank of the LoRA matrices", + type=int, + default=4, + ) + parser.add_argument( + "--lora_alpha", + help="Scaling factor for LoRA", + type=float, + default=1.0, + ) # Keys parser.add_argument( @@ -745,6 +769,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=float, default=0.9, ) + parser.add_argument( + "--beta1_schedulefree", + help="Beta1 parameter for the ScheduleFree optimizer", + 
type=float, + default=0.9, + ) + parser.add_argument( + "--beta2_schedulefree", + help="Beta2 parameter for the ScheduleFree optimizer", + type=float, + default=0.98, + ) parser.add_argument("--batch_size", help="batch size", type=int, default=10) parser.add_argument( "--valid_batch_size", help="Validation batch size", type=int, default=10 ) @@ -763,6 +799,18 @@ def build_default_arg_parser() -> argparse.ArgumentParser: parser.add_argument( "--weight_decay", help="weight decay (L2 penalty)", type=float, default=5e-7 ) + parser.add_argument( + "--lr_params_factors", + help="JSON dictionary of per-block learning rate factors (embedding/interactions/products/readouts) multiplied with the base lr", + type=str, + default='{"embedding_lr_factor": 1.0, "interactions_lr_factor": 1.0, "products_lr_factor": 1.0, "readouts_lr_factor": 1.0}', + ) + parser.add_argument( + "--freeze", + help="Freeze parameter blocks up to level N: 1 freezes the node embedding, 5 also freezes the interactions, 6 the products, and 7 the readouts. 0 or None (the default) leaves all layers trainable", + type=int, + default=None, + ) parser.add_argument( "--amsgrad", help="use amsgrad variant of optimizer", diff --git a/mace/tools/lora_tools.py b/mace/tools/lora_tools.py new file mode 100644 index 000000000..39d40a153 --- /dev/null +++ b/mace/tools/lora_tools.py @@ -0,0 +1,186 @@ +import torch +from e3nn import o3 +from e3nn.nn._fc import _Layer as E3NNFCLayer +from torch import nn + + +def build_lora_irreps( + irreps_in: o3.Irreps, irreps_out: o3.Irreps, rank: int +) -> o3.Irreps: + """ + Choose an equivariant bottleneck irreps that preserves symmetry: for every irrep + present in BOTH input and output, allocate `rank` copies. + """ + in_set = {ir for _, ir in o3.Irreps(irreps_in)} + out_set = {ir for _, ir in o3.Irreps(irreps_out)} + shared = sorted(in_set & out_set, key=lambda ir: (ir.l, ir.p)) + if not shared: + raise ValueError( + f"No shared irreps between input ({irreps_in}) and output ({irreps_out}); cannot build equivariant LoRA."
+ ) + parts = [f"{rank}x{ir}" for ir in shared] + return o3.Irreps(" + ".join(parts)) + + +class LoRAO3Linear(nn.Module): + """LoRA for equivariant o3.Linear-like layers (preserves O(3) equivariance).""" + + def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.base = base_linear + self.irreps_in = self.base.irreps_in + self.irreps_out = self.base.irreps_out + self.scaling = float(alpha) / float(rank) + self.lora_irreps = build_lora_irreps(self.irreps_in, self.irreps_out, rank) + # Use the same class as base to avoid layout mismatches if possible + layer_type = type(self.base) + self.lora_A = layer_type( + self.irreps_in, self.lora_irreps, internal_weights=True, biases=False + ) + self.lora_B = layer_type( + self.lora_irreps, self.irreps_out, internal_weights=True, biases=False + ) + # Match dtype/device to base + base_param = next(self.base.parameters()) + self.lora_A.to(dtype=base_param.dtype, device=base_param.device) + self.lora_B.to(dtype=base_param.dtype, device=base_param.device) + + with torch.no_grad(): + for p in self.lora_B.parameters(): + p.zero_() + for p in self.lora_A.parameters(): + if p.dim() >= 2: + p.normal_(mean=0.0, std=1e-3) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + base = self.base(x) + delta = self.lora_B(self.lora_A(x)) + return base + self.scaling * delta + + +class LoRADenseLinear(nn.Module): + """LoRA for torch.nn.Linear""" + + def __init__(self, base_linear: nn.Linear, rank: int = 4, alpha: float = 1.0): + super().__init__() + self.base = base_linear + in_f = self.base.in_features + out_f = self.base.out_features + self.scaling = float(alpha) / float(rank) + self.lora_A = nn.Linear(in_f, rank, bias=False) + self.lora_B = nn.Linear(rank, out_f, bias=False) + + # match dtype/device to base + base_param = next(self.base.parameters()) + self.lora_A.to(dtype=base_param.dtype, device=base_param.device) + self.lora_B.to(dtype=base_param.dtype, device=base_param.device) + + with torch.no_grad(): + nn.init.zeros_(self.lora_B.weight) + nn.init.normal_(self.lora_A.weight, mean=0.0, std=1e-3) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + base = self.base(x) + delta = self.lora_B(self.lora_A(x)) + return base + self.scaling * delta + + +class LoRAFCLayer(nn.Module): + """LoRA for e3nn.nn._fc._Layer used by FullyConnectedNet (scalar MLP). + Adds a low-rank delta on the weight matrix. 
+ """ + + def __init__(self, base_layer: nn.Module, rank: int = 4, alpha: float = 1.0): + super().__init__() + if not hasattr(base_layer, "weight"): + raise TypeError("LoRAFCLayer requires a layer with a 'weight' parameter") + self.base = base_layer + + w = self.base.weight # type: ignore[attr-defined] + in_f, out_f = int(w.shape[0]), int(w.shape[1]) + self.scaling = float(alpha) / float(rank) + + # Use explicit parameters to match e3nn layout [in, out] + self.lora_A = nn.Parameter( + torch.empty(in_f, rank, device=w.device, dtype=w.dtype) + ) + self.lora_B = nn.Parameter( + torch.empty(rank, out_f, device=w.device, dtype=w.dtype) + ) + + with torch.no_grad(): + torch.nn.init.normal_(self.lora_A, mean=0.0, std=1e-3) + torch.nn.init.zeros_(self.lora_B) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Replicate e3nn _Layer normalization + W = self.base.weight # type: ignore[attr-defined] + h_in = getattr(self.base, "h_in") + var_in = getattr(self.base, "var_in") + var_out = getattr(self.base, "var_out") + act = getattr(self.base, "act", None) + + delta = self.lora_A @ self.lora_B + W_sum = W + self.scaling * delta + + if act is not None: + denom = (h_in * var_in) ** 0.5 + w = W_sum / denom + x = x @ w + x = act(x) + x = x * (var_out**0.5) + else: + denom = (h_in * var_in / var_out) ** 0.5 + w = W_sum / denom + x = x @ w + return x + + +def inject_lora( + module: nn.Module, + rank: int = 4, + alpha: float = 1.0, + wrap_equivariant: bool = True, + wrap_dense: bool = True, + _is_root: bool = True, +) -> None: + """ + Recursively replace eligible linears with LoRA-wrapped versions. + """ + + for child_name, child in list(module.named_children()): + # Skip already wrapped + if isinstance(child, (LoRAO3Linear, LoRADenseLinear, LoRAFCLayer)): + continue + # Equivariant o3.Linear + if wrap_equivariant and isinstance(child, o3.Linear): + try: + wrapped = LoRAO3Linear(child, rank=rank, alpha=alpha) + except ValueError: # If no shared irreps, skip + continue + module._modules[child_name] = wrapped # pylint: disable=protected-access + # Dense nn.Linear + if wrap_dense and isinstance(child, nn.Linear): + wrapped = LoRADenseLinear(child, rank=rank, alpha=alpha) + module._modules[child_name] = wrapped # pylint: disable=protected-access + continue + # e3nn FullyConnectedNet internal layer + if wrap_dense and isinstance(child, E3NNFCLayer): + wrapped = LoRAFCLayer(child, rank=rank, alpha=alpha) + module._modules[child_name] = wrapped # pylint: disable=protected-access + continue + # Recurse + inject_lora(child, rank, alpha, wrap_equivariant, wrap_dense, _is_root=False) + + if _is_root: + for name, p in module.named_parameters(): + if ("lora_A" in name) or ("lora_B" in name): + p.requires_grad = True + else: + p.requires_grad = False + + +def inject_LoRAs(model: nn.Module, rank: int = 4, alpha: int = 1): + inject_lora(model, rank=rank, alpha=alpha, wrap_equivariant=True, wrap_dense=True) + return model diff --git a/mace/tools/multihead_tools.py b/mace/tools/multihead_tools.py index a43581974..e501b4066 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -224,6 +224,7 @@ def generate_pseudolabels_for_configs( r_max: float, device: torch.device, batch_size: int, + force_stress: bool = False, ) -> List[Configuration]: """ Generate pseudolabels for a list of Configuration objects. 
@@ -282,6 +283,15 @@ def generate_pseudolabels_for_configs( if not hasattr(config_copy, "properties"): config_copy.properties = {} + if not hasattr(config_copy, "property_weights"): + config_copy.property_weights = {} + + original_stress_weight = config.property_weights.get("stress", 0.0) + had_stress = ( + config.properties.get("stress") is not None + and original_stress_weight > 0.0 + ) + # Update config properties with pseudolabels if "energy" in out and out["energy"] is not None: config_copy.properties["energy"] = ( @@ -296,9 +306,13 @@ def generate_pseudolabels_for_configs( out["forces"][node_start:node_end].detach().cpu().numpy() ) if "stress" in out and out["stress"] is not None: - config_copy.properties["stress"] = ( - out["stress"][j].detach().cpu().numpy() - ) + if had_stress or force_stress: + config_copy.properties["stress"] = ( + out["stress"][j].detach().cpu().numpy() + ) + config_copy.property_weights["stress"] = ( + original_stress_weight if had_stress else 1.0 + ) if "virials" in out and out["virials"] is not None: config_copy.properties["virials"] = ( out["virials"][j].detach().cpu().numpy() @@ -339,6 +353,7 @@ def apply_pseudolabels_to_pt_head_configs( r_max: float, device: torch.device, batch_size: int, + force_stress: bool = False, ) -> bool: """ Apply pseudolabels to pt_head configurations using the foundation model. @@ -391,6 +406,7 @@ def apply_pseudolabels_to_pt_head_configs( r_max=r_max, device=device, batch_size=batch_size, + force_stress=force_stress, ) # Replace the original configurations with updated ones @@ -414,6 +430,7 @@ def apply_pseudolabels_to_pt_head_configs( r_max=r_max, device=device, batch_size=batch_size, + force_stress=force_stress, ) # Replace the original configurations with updated ones diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index 218679666..63babc441 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -778,6 +778,11 @@ def get_swa( return swa, swas +def freeze_module(module: torch.nn.Module, freeze: bool = True): + for p in module.parameters(): + p.requires_grad = not freeze + + def get_params_options( args: argparse.Namespace, model: torch.nn.Module ) -> Dict[str, Any]: @@ -789,32 +794,57 @@ def get_params_options( else: no_decay_interactions[name] = param + lr_params_factors = json.loads(args.lr_params_factors) + + if args.freeze: + if args.freeze >= 7: + logging.info("Freezing readout weights") + lr_params_factors["readouts_lr_factor"] = 0.0 + freeze_module(model.readouts, True) + if args.freeze >= 6: + logging.info("Freezing product weights") + lr_params_factors["products_lr_factor"] = 0.0 + freeze_module(model.products, True) + if args.freeze >= 5: + logging.info("Freezing interaction linear weights") + lr_params_factors["interactions_lr_factor"] = 0.0 + freeze_module(model.interactions, True) + if args.freeze >= 1: + logging.info("Freezing embedding weights") + lr_params_factors["embedding_lr_factor"] = 0.0 + freeze_module(model.node_embedding, True) + param_options = dict( params=[ { "name": "embedding", "params": model.node_embedding.parameters(), "weight_decay": 0.0, + "lr": lr_params_factors.get("embedding_lr_factor", 1.0) * args.lr, }, { "name": "interactions_decay", "params": list(decay_interactions.values()), "weight_decay": args.weight_decay, + "lr": lr_params_factors.get("interactions_lr_factor", 1.0) * args.lr, }, { "name": "interactions_no_decay", "params": list(no_decay_interactions.values()), "weight_decay": 0.0, + "lr": 
lr_params_factors.get("interactions_lr_factor", 1.0) * args.lr, }, { "name": "products", "params": model.products.parameters(), "weight_decay": args.weight_decay, + "lr": lr_params_factors.get("products_lr_factor", 1.0) * args.lr, }, { "name": "readouts", "params": model.readouts.parameters(), "weight_decay": 0.0, + "lr": lr_params_factors.get("readouts_lr_factor", 1.0) * args.lr, }, ], lr=args.lr, @@ -861,7 +891,10 @@ def get_optimizer( "`schedulefree` is not installed. Please install it via `pip install schedulefree` or `pip install mace-torch[schedulefree]`" ) from exc _param_options = {k: v for k, v in param_options.items() if k != "amsgrad"} - optimizer = adamw_schedulefree.AdamWScheduleFree(**_param_options) + _param_options.pop("betas", None) + optimizer = adamw_schedulefree.AdamWScheduleFree( + **_param_options, betas=(args.beta1_schedulefree, args.beta2_schedulefree) + ) else: optimizer = torch.optim.Adam(**param_options) return optimizer diff --git a/mace/tools/train.py b/mace/tools/train.py index 2110ac81e..d8b97c64a 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -8,7 +8,7 @@ import logging import time from collections import defaultdict -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -538,6 +538,20 @@ def closure(): return loss, loss_dict +# Keep parameters frozen/active after evaluation +@contextmanager +def preserve_grad_state(model): + # save the original requires_grad state for all parameters + requires_grad_backup = {param: param.requires_grad for param in model.parameters()} + try: + # temporarily disable gradients for all parameters + for param in model.parameters(): + param.requires_grad = False + yield # perform evaluation here + finally: + # restore the original requires_grad states + for param, requires_grad in requires_grad_backup.items(): + param.requires_grad = requires_grad def evaluate( model: torch.nn.Module, @@ -546,31 +560,28 @@ def evaluate( output_args: Dict[str, bool], device: torch.device, ) -> Tuple[float, Dict[str, Any]]: - for param in model.parameters(): - param.requires_grad = False + metrics = MACELoss(loss_fn=loss_fn).to(device) start_time = time.time() - for batch in data_loader: - batch = batch.to(device) - batch_dict = batch.to_dict() - output = model( - batch_dict, - training=False, - compute_force=output_args["forces"], - compute_virials=output_args["virials"], - compute_stress=output_args["stress"], - ) - avg_loss, aux = metrics(batch, output) + with preserve_grad_state(model): + for batch in data_loader: + batch = batch.to(device) + batch_dict = batch.to_dict() + output = model( + batch_dict, + training=False, + compute_force=output_args["forces"], + compute_virials=output_args["virials"], + compute_stress=output_args["stress"], + ) + avg_loss, aux = metrics(batch, output) avg_loss, aux = metrics.compute() aux["time"] = time.time() - start_time metrics.reset() - for param in model.parameters(): - param.requires_grad = True - return avg_loss, aux diff --git a/tests/test_finetuning_pseudolabels.py b/tests/test_finetuning_pseudolabels.py index b0b4d9f20..3dc2dda7a 100644 --- a/tests/test_finetuning_pseudolabels.py +++ b/tests/test_finetuning_pseudolabels.py @@ -9,7 +9,6 @@ import pytest from ase.atoms import Atoms - run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" diff --git a/tests/test_freeze.py b/tests/test_freeze.py new file mode 100644 index 000000000..5d2c277a4 --- /dev/null +++ 
b/tests/test_freeze.py @@ -0,0 +1,257 @@ +import os +import subprocess +import sys +from pathlib import Path + +import ase.io +import numpy as np +import pytest +import torch +from ase.atoms import Atoms + +from mace.calculators import MACECalculator + +try: + import cuequivariance as cue # pylint: disable=unused-import + + CUET_AVAILABLE = True +except ImportError: + CUET_AVAILABLE = False + +device = "cuda" if torch.cuda.is_available() else "cpu" +# device = "cpu" + +run_train = Path(__file__).parent.parent / "mace" / "cli" / "run_train.py" + + +@pytest.fixture(name="fitting_configs") +def fixture_fitting_configs(): + water = Atoms( + numbers=[8, 1, 1], + positions=[[0, -2.0, 0], [1, 0, 0], [0, 1, 0]], + cell=[4] * 3, + pbc=[True] * 3, + ) + fit_configs = [ + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3), + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3), + ] + fit_configs[0].info["REF_energy"] = 0.0 + fit_configs[0].info["config_type"] = "IsolatedAtom" + fit_configs[1].info["REF_energy"] = 0.0 + fit_configs[1].info["config_type"] = "IsolatedAtom" + + np.random.seed(5) + for _ in range(20): + c = water.copy() + c.positions += np.random.normal(0.1, size=c.positions.shape) + c.info["REF_energy"] = np.random.normal(0.1) + print(c.info["REF_energy"]) + c.new_array("REF_forces", np.random.normal(0.1, size=c.positions.shape)) + c.info["REF_stress"] = np.random.normal(0.1, size=6) + fit_configs.append(c) + + return fit_configs + + +@pytest.fixture(name="pretraining_configs") +def fixture_pretraining_configs(): + configs = [] + for _ in range(10): + atoms = Atoms( + numbers=[8, 1, 1], + positions=np.random.rand(3, 3) * 3, + cell=[5, 5, 5], + pbc=[True] * 3, + ) + atoms.info["REF_energy"] = np.random.normal(0, 1) + atoms.arrays["REF_forces"] = np.random.normal(0, 1, size=(3, 3)) + atoms.info["REF_stress"] = np.random.normal(0, 1, size=6) + configs.append(atoms) + configs.append( + Atoms(numbers=[8], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3), + ) + configs.append( + Atoms(numbers=[1], positions=[[0, 0, 0]], cell=[6] * 3, pbc=[True] * 3) + ) + configs[-2].info["REF_energy"] = -2.0 + configs[-2].info["config_type"] = "IsolatedAtom" + configs[-1].info["REF_energy"] = -4.0 + configs[-1].info["config_type"] = "IsolatedAtom" + return configs + + +_mace_params = { + "name": "MACE", + "valid_fraction": 0.05, + "energy_weight": 1.0, + "forces_weight": 10.0, + "stress_weight": 1.0, + "model": "MACE", + "hidden_irreps": "128x0e", + "r_max": 3.5, + "batch_size": 5, + "max_num_epochs": 2, + "swa": None, + "start_swa": 5, + "ema": None, + "ema_decay": 0.99, + "amsgrad": None, + "restart_latest": None, + "device": device, + "seed": 5, + "loss": "stress", + "energy_key": "REF_energy", + "forces_key": "REF_forces", + "stress_key": "REF_stress", + "eval_interval": 2, +} + + +def test_run_train_freeze(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["multiheads_finetuning"] = False + mace_params["freeze"] = 6 + + run_env = os.environ.copy() + sys.path.insert(0, 
str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + cmd = ( + sys.executable + + " " + + str(run_train) + + " " + + " ".join( + [ + (f"--{k}={v}" if v is not None else f"--{k}") + for k, v in mace_params.items() + ] + ) + ) + + print(f"Running command: {cmd}") + p = subprocess.run(cmd.split(), env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device=device, default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + + ref_Es = [ + 5.348334089807952, + 2.4128907878403982, + 8.5566950528953, + 7.743803832228654, + 5.788643738738498, + 9.103127501095454, + 8.719323994063377, + 8.169843256425096, + 8.077166786336269, + 8.679676296893602, + 12.189297325152948, + 6.911712148654615, + 8.290506707079263, + 5.303821445834231, + 7.296761518032694, + 5.946962420990914, + 9.043336244248948, + 7.446979685692335, + 5.764245581904601, + 6.975111618768769, + 6.931624082425803, + 6.72206658924676, + ] + + assert np.allclose(Es, ref_Es) + + +def test_run_train_soft_freeze(tmp_path, fitting_configs): + ase.io.write(tmp_path / "fit.xyz", fitting_configs) + + mace_params = _mace_params.copy() + mace_params["checkpoints_dir"] = str(tmp_path) + mace_params["model_dir"] = str(tmp_path) + mace_params["train_file"] = tmp_path / "fit.xyz" + mace_params["loss"] = "weighted" + mace_params["foundation_model"] = "small" + mace_params["hidden_irreps"] = "128x0e" + mace_params["r_max"] = 6.0 + mace_params["default_dtype"] = "float64" + mace_params["num_radial_basis"] = 10 + mace_params["interaction_first"] = "RealAgnosticResidualInteractionBlock" + mace_params["multiheads_finetuning"] = False + mace_params["lr_params_factors"] = '{"embedding_lr_factor": 0.0, "interactions_lr_factor": 1.0, "products_lr_factor": 1.0, "readouts_lr_factor": 1.0}' + + run_env = os.environ.copy() + sys.path.insert(0, str(Path(__file__).parent.parent)) + run_env["PYTHONPATH"] = ":".join(sys.path) + + cmd = [sys.executable, str(run_train)] + for k, v in mace_params.items(): + if v is not None: + cmd.append(f"--{k}={v}") + else: + cmd.append(f"--{k}") + + print(f"Running command: {cmd}") + p = subprocess.run(cmd, env=run_env, check=True) + assert p.returncode == 0 + + calc = MACECalculator( + model_paths=tmp_path / "MACE.model", device=device, default_dtype="float64" + ) + + Es = [] + for at in fitting_configs: + at.calc = calc + Es.append(at.get_potential_energy()) + + print("Es", Es) + + ref_Es = [ + 4.077101520328611, + 1.9125514950167353, + 4.6390361860381795, + 4.6415570296531214, + 3.9153698530138845, + 4.487578378535444, + 4.439674506695098, + 4.906251552572849, + 4.6943771636613985, + 4.443480673870315, + 12.392544826986759, + 4.8014551746345475, + 4.6380462142293455, + 4.126315015844008, + 4.923222049125721, + 4.442558518514199, + 4.556565520687697, + 4.935513763430022, + 4.077869607943539, + 4.4407761603911124, + 5.10253699303561, + 4.537672050884654, + ] + + assert np.allclose(Es, ref_Es) + diff --git a/tests/test_lora.py b/tests/test_lora.py new file mode 100644 index 000000000..ce90b7154 --- /dev/null +++ b/tests/test_lora.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +import math +from typing import Callable, List, Tuple + +import numpy as np +import pytest +import torch +from e3nn import o3 + +from mace import data, modules, tools +from mace.data import Configuration +from mace.tools import torch_geometric +from mace.tools.lora_tools 
import inject_lora + + +def _random_config() -> Configuration: + atomic_numbers = np.array([6, 1, 1], dtype=int) + positions = np.random.normal(scale=0.5, size=(3, 3)) + properties = { + "energy": np.random.normal(scale=0.1), + "forces": np.random.normal(scale=0.1, size=(3, 3)), + } + prop_weights = {"energy": 1.0, "forces": 1.0} + return Configuration( + atomic_numbers=atomic_numbers, + positions=positions, + properties=properties, + property_weights=prop_weights, + cell=np.eye(3) * 8.0, + pbc=(True, True, True), + ) + + +def _build_model() -> Tuple[modules.MACE, tools.AtomicNumberTable]: + table = tools.AtomicNumberTable([1, 6]) + model = modules.MACE( + r_max=4.5, + num_bessel=4, + num_polynomial_cutoff=3, + max_ell=1, + interaction_cls=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + interaction_cls_first=modules.interaction_classes[ + "RealAgnosticResidualInteractionBlock" + ], + num_interactions=2, + num_elements=2, + hidden_irreps=o3.Irreps("16x0e + 16x1o"), + MLP_irreps=o3.Irreps("16x0e"), + gate=torch.nn.functional.silu, + atomic_energies=np.random.normal(scale=0.1, size=len(table.zs)), + avg_num_neighbors=6.0, + atomic_numbers=table.zs, + correlation=2, + radial_type="bessel", + ) + return model, table + + +def _atomic_data_from_config( + config: Configuration, + table: tools.AtomicNumberTable, + cutoff: float = 4.5, +) -> data.AtomicData: + return data.AtomicData.from_config(config, z_table=table, cutoff=cutoff) + + +def _forward_energy_forces( + model: torch.nn.Module, + configs: List[Configuration], + table: tools.AtomicNumberTable, +) -> Tuple[torch.Tensor, torch.Tensor]: + dataset = [_atomic_data_from_config(cfg, table) for cfg in configs] + loader = torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=len(dataset), + shuffle=False, + drop_last=False, + ) + batch = next(iter(loader)) + outputs = model(batch.to_dict()) + energies = outputs["energy"].detach() + forces = outputs["forces"].detach() + return energies, forces + + +def _randomize_lora_parameters(model: torch.nn.Module) -> None: + with torch.no_grad(): + for name, param in model.named_parameters(): + if "lora_A" in name or "lora_B" in name: + torch.nn.init.normal_(param, mean=0.0, std=0.05) + + +def _rotation_matrix() -> np.ndarray: + axis = np.random.normal(size=3) + axis /= np.linalg.norm(axis) + theta = np.random.uniform(0, 2 * math.pi) + K = np.array( + [ + [0.0, -axis[2], axis[1]], + [axis[2], 0.0, -axis[0]], + [-axis[1], axis[0], 0.0], + ] + ) + R = np.eye(3) + math.sin(theta) * K + (1.0 - math.cos(theta)) * (K @ K) + return R + + +def _rotate_config(config: Configuration, R: np.ndarray) -> Configuration: + return Configuration( + atomic_numbers=config.atomic_numbers.copy(), + positions=config.positions @ R.T, + properties=config.properties.copy(), + property_weights=config.property_weights.copy(), + cell=config.cell @ R.T if config.cell is not None else None, + pbc=config.pbc, + weight=config.weight, + config_type=config.config_type, + head=config.head, + ) + + +def _translate_config(config: Configuration, shift: np.ndarray) -> Configuration: + return Configuration( + atomic_numbers=config.atomic_numbers.copy(), + positions=config.positions + shift.reshape(1, 3), + properties=config.properties.copy(), + property_weights=config.property_weights.copy(), + cell=config.cell, + pbc=config.pbc, + weight=config.weight, + config_type=config.config_type, + head=config.head, + ) + + +def _reflect_config( + config: Configuration, normal: np.ndarray +) -> Tuple[Configuration, 
np.ndarray]: + normal = normal / np.linalg.norm(normal) + R = np.eye(3) - 2.0 * np.outer(normal, normal) + reflected = _rotate_config(config, R) + return reflected, R + + +@pytest.fixture(name="random_configs") +def _random_configs() -> Tuple[Configuration, Configuration]: + return _random_config(), _random_config() + + +@pytest.fixture(name="build_lora_model") +def _build_lora_model_fixture() -> ( + Callable[[int, float, bool], Tuple[modules.MACE, tools.AtomicNumberTable]] +): + def _builder( + rank: int = 2, + alpha: float = 0.5, + randomize: bool = True, + ) -> Tuple[modules.MACE, tools.AtomicNumberTable]: + model, table = _build_model() + inject_lora(model, rank=rank, alpha=alpha) + if randomize: + _randomize_lora_parameters(model) + return model, table + + return _builder + + +def test_lora_trainable_parameter_count(build_lora_model) -> None: + model, _ = build_lora_model(rank=2, alpha=0.5, randomize=True) + trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + expected = sum(p.numel() for name, p in model.named_parameters() if "lora_" in name) + assert trainable == expected + + non_lora_trainable = [ + name + for name, p in model.named_parameters() + if "lora_" not in name and p.requires_grad + ] + assert ( + not non_lora_trainable + ), f"Non-LoRA parameters trainable: {non_lora_trainable}" + + # Ensure LoRA parameters were randomized away from zero + for name, param in model.named_parameters(): + if "lora_B" in name: + assert torch.any( + torch.abs(param) > 0 + ), f"LoRA parameter {name} incorrectly zero" + + +def test_lora_symmetry_equivariance(build_lora_model, random_configs) -> None: + model, table = build_lora_model(rank=2, alpha=0.5, randomize=True) + model.eval() + base_cfg = random_configs[0] + + energy, forces = _forward_energy_forces(model, [base_cfg], table) + energy_val = energy.item() + forces_val = forces.squeeze(0).detach().numpy() + + # Rotation invariance / covariance + R = _rotation_matrix() + rotated_cfg = _rotate_config(base_cfg, R) + energy_rot, forces_rot = _forward_energy_forces(model, [rotated_cfg], table) + assert np.allclose(energy_rot.item(), energy_val, rtol=1e-6, atol=1e-6) + assert np.allclose( + forces_val @ R.T, forces_rot.squeeze(0).detach().numpy(), rtol=1e-5, atol=1e-5 + ) + + # Translation invariance + shift = np.array([0.17, -0.05, 0.08]) + translated_cfg = _translate_config(base_cfg, shift) + energy_trans, forces_trans = _forward_energy_forces(model, [translated_cfg], table) + assert np.allclose(energy_trans.item(), energy_val, rtol=1e-6, atol=1e-6) + assert np.allclose( + forces_trans.squeeze(0).detach().numpy(), forces_val, rtol=1e-6, atol=1e-6 + ) + + # Reflection invariance / covariance + reflected_cfg, R_reflect = _reflect_config(base_cfg, np.array([1.0, -2.0, 3.0])) + energy_ref, forces_ref = _forward_energy_forces(model, [reflected_cfg], table) + assert np.allclose(energy_ref.item(), energy_val, rtol=1e-6, atol=1e-6) + assert np.allclose( + forces_val @ R_reflect.T, + forces_ref.squeeze(0).detach().numpy(), + rtol=1e-5, + atol=1e-5, + ) diff --git a/tests/test_maceles.py b/tests/test_maceles.py index eca490066..e3465a2fc 100644 --- a/tests/test_maceles.py +++ b/tests/test_maceles.py @@ -279,7 +279,6 @@ def mace_model_path_fixture(tmp_path: Path) -> Path: return path -@pytest.mark.skipif(not LES_AVAILABLE, reason="LES library is not available") @pytest.fixture(name="maceles_model_path") def maceles_model_path_fixture(tmp_path: Path) -> Path: """Create and save a MACELES model."""
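A minimal usage sketch of the new `inject_lora` helper from `mace/tools/lora_tools.py`. The `ModuleDict` below is a hypothetical stand-in for a MACE sub-block (not part of this patch); after injection only the `lora_A`/`lora_B` adapter parameters should remain trainable, which is the same property `test_lora_trainable_parameter_count` checks on a full model:

```python
import torch
from e3nn import o3

from mace.tools.lora_tools import inject_lora

# Hypothetical stand-in for a MACE sub-block: one equivariant and one dense linear.
block = torch.nn.ModuleDict(
    {
        "equivariant": o3.Linear("8x0e + 8x1o", "8x0e + 8x1o"),
        "dense": torch.nn.Linear(16, 16),
    }
)

# Wrap eligible linears in place with rank-2 LoRA adapters (scaling = alpha / rank = 0.25).
inject_lora(block, rank=2, alpha=0.5)

# Only the injected adapter weights are left trainable; the base weights are frozen.
trainable = [n for n, p in block.named_parameters() if p.requires_grad]
assert trainable and all("lora_A" in n or "lora_B" in n for n in trainable)
print(trainable)
```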