From 8e24229153ab7cb00a7c0c44da1b67f7bdd61967 Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Tue, 18 Nov 2025 21:08:36 +0000 Subject: [PATCH 01/11] cleanup lora code --- mace/tools/lora_tools.py | 40 ++++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/mace/tools/lora_tools.py b/mace/tools/lora_tools.py index 39d40a153..e65b23400 100644 --- a/mace/tools/lora_tools.py +++ b/mace/tools/lora_tools.py @@ -113,27 +113,23 @@ def __init__(self, base_layer: nn.Module, rank: int = 4, alpha: float = 1.0): torch.nn.init.zeros_(self.lora_B) def forward(self, x: torch.Tensor) -> torch.Tensor: - # Replicate e3nn _Layer normalization - W = self.base.weight # type: ignore[attr-defined] - h_in = getattr(self.base, "h_in") - var_in = getattr(self.base, "var_in") - var_out = getattr(self.base, "var_out") - act = getattr(self.base, "act", None) - + # Temporarily patch the weight of the base layer to include LoRA delta + # This avoids re-implementing the complex normalization/activation logic of e3nn _Layer + w_orig = self.base.weight delta = self.lora_A @ self.lora_B - W_sum = W + self.scaling * delta + weight_patched = w_orig + self.scaling * delta + + # Patch: self.base.weight is a Parameter. We replace it with a Tensor for the forward pass. + # To do this safely in PyTorch, we must temporarily remove it from _parameters. + del self.base._parameters["weight"] + self.base.weight = weight_patched - if act is not None: - denom = (h_in * var_in) ** 0.5 - w = W_sum / denom - x = x @ w - x = act(x) - x = x * (var_out**0.5) - else: - denom = (h_in * var_in / var_out) ** 0.5 - w = W_sum / denom - x = x @ w - return x + try: + return self.base(x) + finally: + # Restore the original Parameter + self.base.weight = w_orig + self.base._parameters["weight"] = w_orig def inject_lora( @@ -158,16 +154,16 @@ def inject_lora( wrapped = LoRAO3Linear(child, rank=rank, alpha=alpha) except ValueError: # If no shared irreps, skip continue - module._modules[child_name] = wrapped # pylint: disable=protected-access + setattr(module, child_name, wrapped) # Dense nn.Linear if wrap_dense and isinstance(child, nn.Linear): wrapped = LoRADenseLinear(child, rank=rank, alpha=alpha) - module._modules[child_name] = wrapped # pylint: disable=protected-access + setattr(module, child_name, wrapped) continue # e3nn FullyConnectedNet internal layer if wrap_dense and isinstance(child, E3NNFCLayer): wrapped = LoRAFCLayer(child, rank=rank, alpha=alpha) - module._modules[child_name] = wrapped # pylint: disable=protected-access + setattr(module, child_name, wrapped) continue # Recurse inject_lora(child, rank, alpha, wrap_equivariant, wrap_dense, _is_root=False) From 0dcae047e2124c02ea6302a8d5470dacbb3077cd Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Thu, 20 Nov 2025 23:00:35 +0000 Subject: [PATCH 02/11] add option to use estimated E0s for finetuning --- mace/cli/run_train.py | 28 ++++++++- mace/data/__init__.py | 2 + mace/data/utils.py | 140 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 1 deletion(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 47940c071..656d1ca48 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -434,7 +434,7 @@ def run(args) -> None: for head_config in head_configs: if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: assert head_config.E0s is not None, "Atomic energies must be provided" - if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() != "foundation": + if all(check_path_ase_read(f) for f in head_config.train_file) and head_config.E0s.lower() not in ["foundation", "estimated"]: atomic_energies_dict[head_config.head_name] = get_atomic_energies( head_config.E0s, head_config.collections.train, head_config.z_table ) @@ -455,6 +455,32 @@ def run(args) -> None: ].item() for z in z_table.zs } + elif head_config.E0s.lower() == "estimated": + assert args.foundation_model is not None, "Foundation model must be provided for E0s estimation" + assert all(check_path_ase_read(f) for f in head_config.train_file), "E0s estimation requires training data in .xyz format" + logging.info("Estimating E0s from foundation model predictions on training data") + z_table_foundation = AtomicNumberTable( + [int(z) for z in model_foundation.atomic_numbers] + ) + foundation_atomic_energies = model_foundation.atomic_energies_fn.atomic_energies + if foundation_atomic_energies.ndim > 1: + foundation_atomic_energies = foundation_atomic_energies.squeeze() + if foundation_atomic_energies.ndim == 2: + foundation_atomic_energies = foundation_atomic_energies[0] + logging.info("Foundation model has multiple heads, using the first head for E0 estimation.") + foundation_e0s = { + z: foundation_atomic_energies[ + z_table_foundation.z_to_index(z) + ].item() + for z in z_table_foundation.zs + } + atomic_energies_dict[head_config.head_name] = data.estimate_e0s_from_foundation( + foundation_model=model_foundation, + foundation_e0s=foundation_e0s, + collections_train=head_config.collections.train, + z_table=head_config.z_table, + device=device, + ) else: atomic_energies_dict[head_config.head_name] = get_atomic_energies(head_config.E0s, None, head_config.z_table) else: diff --git a/mace/data/__init__.py b/mace/data/__init__.py index 8629cf521..c37439bcd 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -9,6 +9,7 @@ compute_average_E0s, config_from_atoms, config_from_atoms_list, + estimate_e0s_from_foundation, load_from_xyz, random_train_valid_split, save_AtomicData_to_HDF5, @@ -29,6 +30,7 @@ "config_from_atoms_list", "AtomicData", "compute_average_E0s", + "estimate_e0s_from_foundation", "save_dataset_as_HDF5", "HDF5Dataset", "dataset_from_sharded_hdf5", diff --git a/mace/data/utils.py b/mace/data/utils.py index 049d42cdd..45b147da4 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -14,6 +14,9 @@ import h5py import numpy as np +import torch +from ase.atoms import Atoms + from mace.tools import AtomicNumberTable, DefaultKeys Positions = np.ndarray # [..., 3] @@ -388,6 +391,143 @@ def compute_average_E0s( return atomic_energies_dict +def estimate_e0s_from_foundation( + foundation_model, + foundation_e0s: Dict[int, float], + collections_train: Configurations, + z_table: AtomicNumberTable, + device: str = "cpu", +) -> Dict[int, float]: + """ + Estimate atomic reference energies (E0s) by solving a linear system + that optimally corrects foundation model predictions on training data. + + This function computes E0 corrections by: + 1. Running the foundation model on all training configurations + 2. Computing prediction errors (reference - predicted) + 3. Solving a least-squares system to find optimal E0 corrections + + Args: + foundation_model: The foundation MACE model + foundation_e0s: Dictionary mapping element atomic numbers to original E0 values + collections_train: List of training configurations + z_table: Atomic number table for the training dataset + device: Device to run predictions on (default: "cpu") + + Returns: + Dictionary with estimated E0 values for each element + """ + + # Filter configs with valid energy + valid_configs = [] + for config in collections_train: + if "energy" in config.properties and config.properties["energy"] is not None: + valid_configs.append(config) + + if not valid_configs: + logging.warning("No configurations with energy found for E0 estimation. Using foundation E0s.") + return foundation_e0s.copy() + + elements = z_table.zs + n_configs = len(valid_configs) + n_elements = len(elements) + + # A matrix: each row contains atom counts for each element + # b vector: each entry is the prediction error for a configuration + A = np.zeros((n_configs, n_elements)) + b = np.zeros(n_configs) + + logging.info(f"Estimating E0s using foundation model on {n_configs} configurations with {n_elements} elements") + + # Set model to eval mode + foundation_model.eval() + + with torch.no_grad(): + for i, config in enumerate(valid_configs): + # Convert to AtomicData for model prediction + from mace.data import AtomicData + atomic_data = AtomicData.from_config( + config, + z_table=AtomicNumberTable([int(z) for z in foundation_model.atomic_numbers]), + cutoff=foundation_model.r_max, + ) + atomic_data = atomic_data.to(device) + + # Get model prediction + output = foundation_model(atomic_data.to_dict()) + predicted_energy = output["energy"] + + # Handle different tensor shapes (batched or unbatched) + if predicted_energy.dim() == 0: + predicted_energy = predicted_energy.item() + else: + predicted_energy = predicted_energy.item() if predicted_energy.numel() == 1 else predicted_energy[0].item() + + # Get reference energy + ref_energy = config.properties["energy"] + + # Compute error + error = ref_energy - predicted_energy + b[i] = error + + # Store atom counts for each element + for j, element in enumerate(elements): + A[i, j] = np.sum(config.atomic_numbers == element) + + # Solve least squares system: A @ corrections = b + try: + corrections, residuals, rank, s = np.linalg.lstsq(A, b, rcond=None) + + logging.info("=" * 80) + logging.info("E0 ESTIMATION FROM FOUNDATION MODEL") + logging.info("=" * 80) + logging.info(f"Rank of system: {rank}/{n_elements}") + logging.info(f"Residuals: {residuals}") + + # Compute new E0s + new_e0s = {} + for i, element in enumerate(elements): + correction = corrections[i] + foundation_e0 = foundation_e0s.get(element, 0.0) + new_e0s[element] = foundation_e0 + correction + logging.info( + f"Element {element}: foundation E0 = {foundation_e0:.6f} eV, " + f"correction = {correction:.6f} eV, new E0 = {new_e0s[element]:.6f} eV" + ) + + # Compute statistics + mse_before = np.mean(b**2) + b_after = b - A @ corrections + mse_after = np.mean(b_after**2) + rmse_before = np.sqrt(mse_before) + rmse_after = np.sqrt(mse_after) + mae_before = np.mean(np.abs(b)) + mae_after = np.mean(np.abs(b_after)) + + logging.info("=" * 80) + logging.info("FIT STATISTICS") + logging.info("=" * 80) + logging.info(f"RMSE before E0 correction: {rmse_before:.6f} eV") + logging.info(f"RMSE after E0 correction: {rmse_after:.6f} eV") + logging.info(f"MAE before E0 correction: {mae_before:.6f} eV") + logging.info(f"MAE after E0 correction: {mae_after:.6f} eV") + + if rank < n_elements: + logging.warning( + f"System is rank deficient (rank {rank}/{n_elements}). " + "Some elements may not be sufficiently represented in the dataset." + ) + + logging.info("=" * 80) + + return new_e0s + + except np.linalg.LinAlgError as e: + logging.error(f"Error solving linear system for E0 estimation: {e}") + logging.warning("Falling back to foundation model E0s") + return foundation_e0s.copy() + + def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: with h5py.File(out_name, "w") as f: for i, data in enumerate(dataset): From 82588bcc43d9964cffa4ef0b5d2237c4e09ee365 Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Thu, 20 Nov 2025 23:09:15 +0000 Subject: [PATCH 03/11] fix cutoff dtype bug --- mace/data/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index 45b147da4..4df4017e6 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -16,6 +16,7 @@ import torch from ase.atoms import Atoms +from mace.data import AtomicData from mace.tools import AtomicNumberTable, DefaultKeys @@ -442,14 +443,20 @@ def estimate_e0s_from_foundation( # Set model to eval mode foundation_model.eval() + # Get r_max as a float + r_max = foundation_model.r_max + if hasattr(r_max, 'item'): + r_max = r_max.item() + elif isinstance(r_max, torch.Tensor): + r_max = float(r_max) + with torch.no_grad(): for i, config in enumerate(valid_configs): # Convert to AtomicData for model prediction - from mace.data import AtomicData atomic_data = AtomicData.from_config( config, z_table=AtomicNumberTable([int(z) for z in foundation_model.atomic_numbers]), - cutoff=foundation_model.r_max, + cutoff=r_max, ) atomic_data = atomic_data.to(device) From cc8f3e488ab77dcd31e6957044d6b181ce64ed8d Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Thu, 20 Nov 2025 23:12:10 +0000 Subject: [PATCH 04/11] fix circular dependency --- mace/data/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index 4df4017e6..4a3c6678c 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -14,10 +14,6 @@ import h5py import numpy as np -import torch -from ase.atoms import Atoms -from mace.data import AtomicData - from mace.tools import AtomicNumberTable, DefaultKeys Positions = np.ndarray # [..., 3] @@ -418,6 +414,7 @@ def estimate_e0s_from_foundation( Returns: Dictionary with estimated E0 values for each element """ + import torch # Filter configs with valid energy valid_configs = [] @@ -453,6 +450,8 @@ def estimate_e0s_from_foundation( with torch.no_grad(): for i, config in enumerate(valid_configs): # Convert to AtomicData for model prediction + # Import here to avoid circular dependency + from mace.data import AtomicData atomic_data = AtomicData.from_config( config, z_table=AtomicNumberTable([int(z) for z in foundation_model.atomic_numbers]), From 4c1d6857a4dda0aad2b58c2f6611df935a3d5333 Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Thu, 20 Nov 2025 23:16:06 +0000 Subject: [PATCH 05/11] create batches to do forward pass with for E0 estimation --- mace/data/utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index 4a3c6678c..d049c3591 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -452,15 +452,25 @@ def estimate_e0s_from_foundation( # Convert to AtomicData for model prediction # Import here to avoid circular dependency from mace.data import AtomicData + from mace.tools import torch_geometric + atomic_data = AtomicData.from_config( config, z_table=AtomicNumberTable([int(z) for z in foundation_model.atomic_numbers]), cutoff=r_max, ) - atomic_data = atomic_data.to(device) + + # Create a proper batch using DataLoader + data_loader = torch_geometric.dataloader.DataLoader( + dataset=[atomic_data], + batch_size=1, + shuffle=False, + drop_last=False, + ) + batch = next(iter(data_loader)).to(device) # Get model prediction - output = foundation_model(atomic_data.to_dict()) + output = foundation_model(batch.to_dict()) predicted_energy = output["energy"] # Handle different tensor shapes (batched or unbatched) From 08b807cf4c177d9f978673b26cd8dad0b6eccb97 Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Thu, 20 Nov 2025 23:24:43 +0000 Subject: [PATCH 06/11] debug --- mace/data/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mace/data/utils.py b/mace/data/utils.py index d049c3591..084f4cb62 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -439,6 +439,7 @@ def estimate_e0s_from_foundation( # Set model to eval mode foundation_model.eval() + foundation_model = foundation_model.to(device) # Get r_max as a float r_max = foundation_model.r_max From bd75f726b80223a28b14738cabbfc9e9d6d2157b Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Thu, 20 Nov 2025 23:27:03 +0000 Subject: [PATCH 07/11] disable force prediciton for E0 estimation --- mace/data/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mace/data/utils.py b/mace/data/utils.py index 084f4cb62..fab52a965 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -470,8 +470,14 @@ def estimate_e0s_from_foundation( ) batch = next(iter(data_loader)).to(device) - # Get model prediction - output = foundation_model(batch.to_dict()) + # Get model prediction (only energy, no forces/stress to avoid gradient computation) + output = foundation_model( + batch.to_dict(), + training=False, + compute_force=False, + compute_virials=False, + compute_stress=False, + ) predicted_energy = output["energy"] # Handle different tensor shapes (batched or unbatched) From 424cddef5995648c808953c55764b3edeeb11380 Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Thu, 20 Nov 2025 23:39:32 +0000 Subject: [PATCH 08/11] update readme with the estimated E0s option --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 4928aa3ee..1aa4fef36 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,11 @@ To give a specific validation set, use the argument `--valid_file`. To set a lar To control the model's size, you need to change `--hidden_irreps`. For most applications, the recommended default model size is `--hidden_irreps='256x0e'` (meaning 256 invariant messages) or `--hidden_irreps='128x0e + 128x1o'`. If the model is not accurate enough, you can include higher order features, e.g., `128x0e + 128x1o + 128x2e`, or increase the number of channels to `256`. It is also possible to specify the model using the `--num_channels=128` and `--max_L=1`keys. -It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. If you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression. Note that using fitted E0s corresponds to fitting the deviations of the atomic energies from the average, rather than fitting the atomization energy (which is the case when using isolated-atom E0s), and this will most likely result in less stable potentials for molecular dynamics applications. +It is usually preferred to add the isolated atoms to the training set, rather than reading in their energies through the command line like in the example above. To label them in the training set, set `config_type=IsolatedAtom` in their info fields. + +When training a model from scratch, if you prefer not to use or do not know the energies of the isolated atoms, you can use the option `--E0s="average"` which estimates the atomic energies using least squares regression. Note that using fitted E0s corresponds to fitting the deviations of the atomic energies from the average, rather than fitting the atomization energy (which is the case when using isolated-atom E0s), and this will most likely result in less stable potentials for molecular dynamics applications. + +When finetuning foundation models, you can use `--E0s="estimated"`, which estimates the atomic reference energies by solving a linear system that optimally corrects the foundation model's predictions on the training data. This approach computes E0 corrections by first running the foundation model on all training configurations, computing the prediction errors (reference energies minus predicted energies), and then solving a least-squares system to find optimal E0 corrections for each element. This is preferable in general over the 'average' option. If the keyword `--stage_two` (previously called swa) is enabled, the energy weight of the loss is increased for the last ~20% of the training epochs (from `--start_stage_two` epochs). This setting usually helps lower the energy errors. From d9eaf0c07619acf639ccbb800a7d0d1898ec14b5 Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Fri, 28 Nov 2025 14:52:52 +0000 Subject: [PATCH 09/11] merge LoRA weights back into the model after training --- mace/cli/run_train.py | 5 +- mace/tools/lora_tools.py | 170 +++++++++++++++++++++++++++++++++++++++ tests/test_lora.py | 98 +++++++++++++++++++++- 3 files changed, 271 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 656d1ca48..123e9103e 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -36,7 +36,7 @@ from mace.data import KeySpecification, update_keyspec_from_kwargs from mace.tools import torch_geometric from mace.tools.distributed_tools import init_distributed -from mace.tools.lora_tools import inject_LoRAs +from mace.tools.lora_tools import inject_LoRAs, merge_lora_weights from mace.tools.model_script_utils import configure_model from mace.tools.multihead_tools import ( HeadConfig, @@ -1028,6 +1028,9 @@ def run(args) -> None: model_path = Path(args.checkpoints_dir) / (tag + ".model") logging.info(f"Saving model to {model_path}") model_to_save = deepcopy(model) + if args.lora: + logging.info("Merging LoRA weights into base model") + merge_lora_weights(model_to_save) if args.enable_cueq and not args.only_cueq: logging.info("RUNING CUEQ TO E3NN") model_to_save = run_cueq_to_e3nn(deepcopy(model), device=device) diff --git a/mace/tools/lora_tools.py b/mace/tools/lora_tools.py index e65b23400..9ce750b0f 100644 --- a/mace/tools/lora_tools.py +++ b/mace/tools/lora_tools.py @@ -179,3 +179,173 @@ def inject_lora( def inject_LoRAs(model: nn.Module, rank: int = 4, alpha: int = 1): inject_lora(model, rank=rank, alpha=alpha, wrap_equivariant=True, wrap_dense=True) return model + + +def _merge_dense_lora(lora_module: LoRADenseLinear) -> nn.Linear: + """ + Merge LoRA weights for nn.Linear. + + For nn.Linear, the forward is: base(x) + scaling * lora_B(lora_A(x)) + Which equals: x @ W_base.T + b + scaling * x @ W_A.T @ W_B.T + So: W_merged = W_base + scaling * (W_B @ W_A) + """ + with torch.no_grad(): + # lora_A.weight: (rank, in_features) + # lora_B.weight: (out_features, rank) + # delta: (out_features, in_features) + delta = lora_module.lora_B.weight @ lora_module.lora_A.weight + lora_module.base.weight.add_(lora_module.scaling * delta) + return lora_module.base + + +def _merge_fc_lora(lora_module: LoRAFCLayer) -> nn.Module: + """ + Merge LoRA weights for e3nn _Layer. + + The forward computes: base(x) with weight = w_orig + scaling * (lora_A @ lora_B) + So we simply add the delta to the base weight permanently. + """ + with torch.no_grad(): + # lora_A: (in_f, rank) + # lora_B: (rank, out_f) + # delta: (in_f, out_f) - matches e3nn weight layout + delta = lora_module.lora_A @ lora_module.lora_B + lora_module.base.weight.add_(lora_module.scaling * delta) + return lora_module.base + + +def _merge_o3_lora(lora_module: LoRAO3Linear) -> o3.Linear: + """ + Merge LoRA weights for o3.Linear by direct weight composition. + + For o3.Linear, each instruction connects an input irrep index to an output + irrep index. The LoRA composition B(A(x)) goes through an intermediate + bottleneck representation. We match instructions by their (i_in, i_out) + indices and compose the weight blocks. + + Formula: W_merged = W_base + scaling * (pw_A * pw_B / pw_base) * (W_A @ W_B) + """ + base = lora_module.base + lora_A = lora_module.lora_A + lora_B = lora_module.lora_B + scaling = lora_module.scaling + + with torch.no_grad(): + # Extract weight blocks indexed by instruction + def extract_weight_blocks(linear): + blocks = {} + offset = 0 + for idx, instr in enumerate(linear.instructions): + size = instr.path_shape[0] * instr.path_shape[1] + block = linear.weight[offset : offset + size].reshape(instr.path_shape) + blocks[idx] = block + offset += size + return blocks + + base_blocks = extract_weight_blocks(base) + A_blocks = extract_weight_blocks(lora_A) + B_blocks = extract_weight_blocks(lora_B) + + # Build lookup tables for lora_A and lora_B instructions + # lora_A: maps i_in -> (instruction_idx, i_out) + A_by_i_in = {} + for idx, instr in enumerate(lora_A.instructions): + A_by_i_in[instr.i_in] = (idx, instr.i_out) + + # lora_B: maps (i_in, i_out) -> instruction_idx + B_by_in_out = {} + for idx, instr in enumerate(lora_B.instructions): + B_by_in_out[(instr.i_in, instr.i_out)] = idx + + # Compute merged weight blocks + merged_blocks = [] + for base_idx, base_instr in enumerate(base.instructions): + i_in_base = base_instr.i_in + i_out_base = base_instr.i_out + pw_base = base_instr.path_weight + + # Find corresponding lora_A instruction (input -> bottleneck) + if i_in_base not in A_by_i_in: + # No LoRA for this path, keep base unchanged + merged_blocks.append(base_blocks[base_idx]) + continue + + A_idx, i_mid = A_by_i_in[i_in_base] + pw_A = lora_A.instructions[A_idx].path_weight + + # Find corresponding lora_B instruction (bottleneck -> output) + B_key = (i_mid, i_out_base) + if B_key not in B_by_in_out: + # No LoRA for this path, keep base unchanged + merged_blocks.append(base_blocks[base_idx]) + continue + + B_idx = B_by_in_out[B_key] + pw_B = lora_B.instructions[B_idx].path_weight + + # Compose: W_delta = (pw_A * pw_B / pw_base) * (W_A @ W_B) + ratio = (pw_A * pw_B) / pw_base + delta = A_blocks[A_idx] @ B_blocks[B_idx] + merged = base_blocks[base_idx] + scaling * ratio * delta + merged_blocks.append(merged) + + # Flatten merged blocks back into weight tensor + merged_weight = torch.cat([b.flatten() for b in merged_blocks]) + base.weight.copy_(merged_weight) + + return base + + +def merge_lora_weights(model: nn.Module, inplace: bool = True) -> nn.Module: + """ + Merge LoRA weights into base weights and replace LoRA wrappers with merged base modules. + + This eliminates the inference overhead from LoRA by folding the low-rank + adaptations directly into the original weight matrices. After merging: + - LoRADenseLinear -> nn.Linear (with merged weights) + - LoRAFCLayer -> e3nn _Layer (with merged weights) + - LoRAO3Linear -> o3.Linear (with merged weights) + + Args: + model: Model containing LoRA layers to merge. + inplace: If True, modifies the model in place. If False, works on a deep copy. + + Returns: + Model with LoRA weights merged into base layers. All parameters will have + requires_grad=True after merging. + + Example: + >>> model = load_model(...) + >>> inject_lora(model, rank=4) + >>> train(model) # Train with LoRA + >>> merge_lora_weights(model) # Merge for fast inference + >>> save_model(model) + """ + if not inplace: + import copy + + model = copy.deepcopy(model) + + _merge_lora_recursive(model) + + # Re-enable gradients for all parameters (they were frozen during LoRA training) + for param in model.parameters(): + param.requires_grad = True + + return model + + +def _merge_lora_recursive(module: nn.Module) -> None: + """Recursively merge LoRA layers in a module.""" + for name, child in list(module.named_children()): + if isinstance(child, LoRADenseLinear): + merged = _merge_dense_lora(child) + setattr(module, name, merged) + elif isinstance(child, LoRAFCLayer): + merged = _merge_fc_lora(child) + setattr(module, name, merged) + elif isinstance(child, LoRAO3Linear): + merged = _merge_o3_lora(child) + setattr(module, name, merged) + else: + _merge_lora_recursive(child) diff --git a/tests/test_lora.py b/tests/test_lora.py index ce90b7154..e75561e1b 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -11,7 +11,7 @@ from mace import data, modules, tools from mace.data import Configuration from mace.tools import torch_geometric -from mace.tools.lora_tools import inject_lora +from mace.tools.lora_tools import inject_lora, merge_lora_weights def _random_config() -> Configuration: @@ -228,3 +228,99 @@ def test_lora_symmetry_equivariance(build_lora_model, random_configs) -> None: rtol=1e-5, atol=1e-5, ) + + +def test_lora_merge_preserves_outputs(build_lora_model, random_configs) -> None: + """Test that merging LoRA weights produces identical outputs.""" + model, table = build_lora_model(rank=2, alpha=0.5, randomize=True) + model.eval() + + # Get outputs before merging + configs = list(random_configs) + energy_before, forces_before = _forward_energy_forces(model, configs, table) + + # Merge LoRA weights + merge_lora_weights(model) + model.eval() + + # Get outputs after merging + energy_after, forces_after = _forward_energy_forces(model, configs, table) + + # Outputs should be identical (within numerical precision) + assert torch.allclose(energy_before, energy_after, rtol=1e-5, atol=1e-6), ( + f"Energy mismatch after merge: {energy_before} vs {energy_after}" + ) + assert torch.allclose(forces_before, forces_after, rtol=1e-5, atol=1e-6), ( + f"Forces mismatch after merge: max diff = {(forces_before - forces_after).abs().max()}" + ) + + +def test_lora_merge_removes_wrappers(build_lora_model) -> None: + """Test that merging removes LoRA wrapper modules.""" + from mace.tools.lora_tools import LoRADenseLinear, LoRAFCLayer, LoRAO3Linear + + model, _ = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Count LoRA wrappers before merge + def count_lora_wrappers(module): + count = 0 + for child in module.modules(): + if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear)): + count += 1 + return count + + wrappers_before = count_lora_wrappers(model) + assert wrappers_before > 0, "Model should have LoRA wrappers before merge" + + # Merge + merge_lora_weights(model) + + # Count LoRA wrappers after merge + wrappers_after = count_lora_wrappers(model) + assert wrappers_after == 0, f"Model still has {wrappers_after} LoRA wrappers after merge" + + +def test_lora_merge_enables_gradients(build_lora_model) -> None: + """Test that merging re-enables gradients for all parameters.""" + model, _ = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Before merge, only LoRA params have gradients + non_lora_grads_before = [ + name + for name, p in model.named_parameters() + if "lora_" not in name and p.requires_grad + ] + assert not non_lora_grads_before, "Non-LoRA params should be frozen before merge" + + # Merge + merge_lora_weights(model) + + # After merge, all params should have gradients + frozen_after = [name for name, p in model.named_parameters() if not p.requires_grad] + assert not frozen_after, f"Some parameters frozen after merge: {frozen_after}" + + +def test_lora_merge_preserves_equivariance(build_lora_model, random_configs) -> None: + """Test that merged model preserves rotational equivariance.""" + model, table = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Merge LoRA weights + merge_lora_weights(model) + model.eval() + + base_cfg = random_configs[0] + energy, forces = _forward_energy_forces(model, [base_cfg], table) + energy_val = energy.item() + forces_val = forces.squeeze(0).detach().numpy() + + # Test rotation equivariance after merge + R = _rotation_matrix() + rotated_cfg = _rotate_config(base_cfg, R) + energy_rot, forces_rot = _forward_energy_forces(model, [rotated_cfg], table) + + assert np.allclose(energy_rot.item(), energy_val, rtol=1e-6, atol=1e-6), ( + "Energy not invariant under rotation after merge" + ) + assert np.allclose( + forces_val @ R.T, forces_rot.squeeze(0).detach().numpy(), rtol=1e-5, atol=1e-5 + ), "Forces not equivariant under rotation after merge" From 8310e978d24c8902e5d11a6dd08b5bb9de71610c Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Fri, 28 Nov 2025 15:37:01 +0000 Subject: [PATCH 10/11] cache LoRA deltas in eval mode to speed up training --- mace/tools/lora_tools.py | 361 +++++++++++++++++++-------------------- 1 file changed, 180 insertions(+), 181 deletions(-) diff --git a/mace/tools/lora_tools.py b/mace/tools/lora_tools.py index 9ce750b0f..b6ec0f486 100644 --- a/mace/tools/lora_tools.py +++ b/mace/tools/lora_tools.py @@ -1,4 +1,5 @@ import torch +import torch.nn.functional as F from e3nn import o3 from e3nn.nn._fc import _Layer as E3NNFCLayer from torch import nn @@ -23,7 +24,11 @@ def build_lora_irreps( class LoRAO3Linear(nn.Module): - """LoRA for equivariant o3.Linear-like layers (preserves O(3) equivariance).""" + """LoRA for equivariant o3.Linear-like layers (preserves O(3) equivariance). + + Uses fused weight computation: W_merged = W_base + scaling * (W_A @ W_B) + with automatic caching during inference (when grad is disabled). + """ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): super().__init__() @@ -32,6 +37,7 @@ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): self.irreps_out = self.base.irreps_out self.scaling = float(alpha) / float(rank) self.lora_irreps = build_lora_irreps(self.irreps_in, self.irreps_out, rank) + # Use the same class as base to avoid layout mismatches if possible layer_type = type(self.base) self.lora_A = layer_type( @@ -40,11 +46,18 @@ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): self.lora_B = layer_type( self.lora_irreps, self.irreps_out, internal_weights=True, biases=False ) + # Match dtype/device to base base_param = next(self.base.parameters()) self.lora_A.to(dtype=base_param.dtype, device=base_param.device) self.lora_B.to(dtype=base_param.dtype, device=base_param.device) + # Cache for merged weight (used during inference) + self._cached_merged_weight: torch.Tensor | None = None + + # Build instruction mapping for weight composition + self._build_instruction_mapping() + with torch.no_grad(): for p in self.lora_B.parameters(): p.zero_() @@ -52,42 +65,152 @@ def __init__(self, base_linear: o3.Linear, rank: int = 4, alpha: float = 1.0): if p.dim() >= 2: p.normal_(mean=0.0, std=1e-3) + def _build_instruction_mapping(self) -> None: + """Build lookup tables for matching instructions between base, A, and B.""" + # lora_A: maps i_in -> (instruction_idx, i_out, path_weight) + self._A_by_i_in = {} + for idx, instr in enumerate(self.lora_A.instructions): + self._A_by_i_in[instr.i_in] = (idx, instr.i_out, instr.path_weight) + + # lora_B: maps (i_in, i_out) -> (instruction_idx, path_weight) + self._B_by_in_out = {} + for idx, instr in enumerate(self.lora_B.instructions): + self._B_by_in_out[(instr.i_in, instr.i_out)] = (idx, instr.path_weight) + + @staticmethod + def _extract_weight_blocks(linear: o3.Linear) -> dict[int, torch.Tensor]: + """Extract weight blocks indexed by instruction.""" + blocks = {} + offset = 0 + for idx, instr in enumerate(linear.instructions): + size = instr.path_shape[0] * instr.path_shape[1] + block = linear.weight[offset : offset + size].reshape(instr.path_shape) + blocks[idx] = block + offset += size + return blocks + + def compute_merged_weight(self) -> torch.Tensor: + """Compute W_base + scaling * composed(W_A, W_B) in weight space.""" + base_blocks = self._extract_weight_blocks(self.base) + A_blocks = self._extract_weight_blocks(self.lora_A) + B_blocks = self._extract_weight_blocks(self.lora_B) + + merged_blocks = [] + for base_idx, base_instr in enumerate(self.base.instructions): + i_in_base = base_instr.i_in + i_out_base = base_instr.i_out + pw_base = base_instr.path_weight + + # Find corresponding lora_A instruction + if i_in_base not in self._A_by_i_in: + merged_blocks.append(base_blocks[base_idx]) + continue + + A_idx, i_mid, pw_A = self._A_by_i_in[i_in_base] + + # Find corresponding lora_B instruction + B_key = (i_mid, i_out_base) + if B_key not in self._B_by_in_out: + merged_blocks.append(base_blocks[base_idx]) + continue + + B_idx, pw_B = self._B_by_in_out[B_key] + + # Compose: W_delta = (pw_A * pw_B / pw_base) * (W_A @ W_B) + ratio = (pw_A * pw_B) / pw_base + delta = A_blocks[A_idx] @ B_blocks[B_idx] + merged = base_blocks[base_idx] + self.scaling * ratio * delta + merged_blocks.append(merged) + + return torch.cat([b.flatten() for b in merged_blocks]) + def forward(self, x: torch.Tensor) -> torch.Tensor: - base = self.base(x) - delta = self.lora_B(self.lora_A(x)) - return base + self.scaling * delta + if torch.is_grad_enabled(): + # Training: use activation-space computation for correct gradient flow + self._cached_merged_weight = None + return self.base(x) + self.scaling * self.lora_B(self.lora_A(x)) + + # Inference: use fused weight-space computation with caching + if self._cached_merged_weight is None: + self._cached_merged_weight = self.compute_merged_weight() + + original_weight = self.base.weight.data + self.base.weight.data = self._cached_merged_weight + try: + return self.base(x) + finally: + self.base.weight.data = original_weight + + def merge_into_base(self) -> o3.Linear: + """Permanently merge LoRA weights into base and return the base layer.""" + with torch.no_grad(): + self.base.weight.copy_(self.compute_merged_weight()) + return self.base class LoRADenseLinear(nn.Module): - """LoRA for torch.nn.Linear""" + """LoRA for torch.nn.Linear. + + Uses fused weight computation: W_merged = W_base + scaling * (W_B @ W_A) + with automatic caching during inference (when grad is disabled). + """ def __init__(self, base_linear: nn.Linear, rank: int = 4, alpha: float = 1.0): super().__init__() self.base = base_linear - in_f = self.base.in_features - out_f = self.base.out_features + self.in_features = base_linear.in_features + self.out_features = base_linear.out_features self.scaling = float(alpha) / float(rank) - self.lora_A = nn.Linear(in_f, rank, bias=False) - self.lora_B = nn.Linear(rank, out_f, bias=False) - # match dtype/device to base + # LoRA matrices: W_delta = W_B @ W_A + # W_A: (rank, in_features), W_B: (out_features, rank) + self.lora_A = nn.Linear(self.in_features, rank, bias=False) + self.lora_B = nn.Linear(rank, self.out_features, bias=False) + + # Match dtype/device to base base_param = next(self.base.parameters()) self.lora_A.to(dtype=base_param.dtype, device=base_param.device) self.lora_B.to(dtype=base_param.dtype, device=base_param.device) + # Cache for weight delta (used during inference) + self._cached_delta: torch.Tensor | None = None + with torch.no_grad(): nn.init.zeros_(self.lora_B.weight) nn.init.normal_(self.lora_A.weight, mean=0.0, std=1e-3) + def compute_delta(self) -> torch.Tensor: + """Compute the LoRA weight delta: W_B @ W_A.""" + return self.lora_B.weight @ self.lora_A.weight + def forward(self, x: torch.Tensor) -> torch.Tensor: - base = self.base(x) - delta = self.lora_B(self.lora_A(x)) - return base + self.scaling * delta + if torch.is_grad_enabled(): + # Training: compute fresh delta (gradients flow through B @ A) + self._cached_delta = None + delta = self.compute_delta() + else: + # Inference: use cached delta + if self._cached_delta is None: + self._cached_delta = self.compute_delta() + delta = self._cached_delta + + merged_weight = self.base.weight + self.scaling * delta + return F.linear(x, merged_weight, self.base.bias) + + def merge_into_base(self) -> nn.Linear: + """Permanently merge LoRA weights into base and return the base layer.""" + with torch.no_grad(): + self.base.weight.add_(self.scaling * self.compute_delta()) + return self.base class LoRAFCLayer(nn.Module): """LoRA for e3nn.nn._fc._Layer used by FullyConnectedNet (scalar MLP). - Adds a low-rank delta on the weight matrix. + + Uses fused weight computation: W_merged = W_base + scaling * (A @ B) + with automatic caching during inference (when grad is disabled). + + Note: e3nn uses (in, out) weight layout, so delta = A @ B (not B @ A). """ def __init__(self, base_layer: nn.Module, rank: int = 4, alpha: float = 1.0): @@ -100,37 +223,50 @@ def __init__(self, base_layer: nn.Module, rank: int = 4, alpha: float = 1.0): in_f, out_f = int(w.shape[0]), int(w.shape[1]) self.scaling = float(alpha) / float(rank) - # Use explicit parameters to match e3nn layout [in, out] - self.lora_A = nn.Parameter( - torch.empty(in_f, rank, device=w.device, dtype=w.dtype) - ) - self.lora_B = nn.Parameter( - torch.empty(rank, out_f, device=w.device, dtype=w.dtype) - ) + # LoRA matrices: delta = A @ B (e3nn layout: in_f x out_f) + self.lora_A = nn.Parameter(torch.empty(in_f, rank, device=w.device, dtype=w.dtype)) + self.lora_B = nn.Parameter(torch.empty(rank, out_f, device=w.device, dtype=w.dtype)) + + # Cache for weight delta (used during inference) + self._cached_delta: torch.Tensor | None = None with torch.no_grad(): - torch.nn.init.normal_(self.lora_A, mean=0.0, std=1e-3) - torch.nn.init.zeros_(self.lora_B) + nn.init.normal_(self.lora_A, mean=0.0, std=1e-3) + nn.init.zeros_(self.lora_B) + + def compute_delta(self) -> torch.Tensor: + """Compute the LoRA weight delta: A @ B.""" + return self.lora_A @ self.lora_B def forward(self, x: torch.Tensor) -> torch.Tensor: - # Temporarily patch the weight of the base layer to include LoRA delta - # This avoids re-implementing the complex normalization/activation logic of e3nn _Layer - w_orig = self.base.weight - delta = self.lora_A @ self.lora_B - weight_patched = w_orig + self.scaling * delta + if torch.is_grad_enabled(): + # Training: compute fresh delta (gradients flow through A @ B) + self._cached_delta = None + delta = self.compute_delta() + else: + # Inference: use cached delta + if self._cached_delta is None: + self._cached_delta = self.compute_delta() + delta = self._cached_delta - # Patch: self.base.weight is a Parameter. We replace it with a Tensor for the forward pass. - # To do this safely in PyTorch, we must temporarily remove it from _parameters. - del self.base._parameters["weight"] - self.base.weight = weight_patched + merged_weight = self.base.weight + self.scaling * delta + # Temporarily patch weight for forward (dict manipulation preserves gradient flow) + w_orig = self.base.weight + del self.base._parameters["weight"] + self.base.weight = merged_weight try: return self.base(x) finally: - # Restore the original Parameter self.base.weight = w_orig self.base._parameters["weight"] = w_orig + def merge_into_base(self) -> nn.Module: + """Permanently merge LoRA weights into base and return the base layer.""" + with torch.no_grad(): + self.base.weight.add_(self.scaling * self.compute_delta()) + return self.base + def inject_lora( module: nn.Module, @@ -140,10 +276,7 @@ def inject_lora( wrap_dense: bool = True, _is_root: bool = True, ) -> None: - """ - Recursively replace eligible linears with LoRA-wrapped versions. - """ - + """Recursively replace eligible linears with LoRA-wrapped versions.""" for child_name, child in list(module.named_children()): # Skip already wrapped if isinstance(child, (LoRAO3Linear, LoRADenseLinear, LoRAFCLayer)): @@ -170,10 +303,7 @@ def inject_lora( if _is_root: for name, p in module.named_parameters(): - if ("lora_A" in name) or ("lora_B" in name): - p.requires_grad = True - else: - p.requires_grad = False + p.requires_grad = ("lora_A" in name) or ("lora_B" in name) def inject_LoRAs(model: nn.Module, rank: int = 4, alpha: int = 1): @@ -181,121 +311,6 @@ def inject_LoRAs(model: nn.Module, rank: int = 4, alpha: int = 1): return model -def _merge_dense_lora(lora_module: LoRADenseLinear) -> nn.Linear: - """ - Merge LoRA weights for nn.Linear. - - For nn.Linear, the forward is: base(x) + scaling * lora_B(lora_A(x)) - Which equals: x @ W_base.T + b + scaling * x @ W_A.T @ W_B.T - So: W_merged = W_base + scaling * (W_B @ W_A) - """ - with torch.no_grad(): - # lora_A.weight: (rank, in_features) - # lora_B.weight: (out_features, rank) - # delta: (out_features, in_features) - delta = lora_module.lora_B.weight @ lora_module.lora_A.weight - lora_module.base.weight.add_(lora_module.scaling * delta) - return lora_module.base - - -def _merge_fc_lora(lora_module: LoRAFCLayer) -> nn.Module: - """ - Merge LoRA weights for e3nn _Layer. - - The forward computes: base(x) with weight = w_orig + scaling * (lora_A @ lora_B) - So we simply add the delta to the base weight permanently. - """ - with torch.no_grad(): - # lora_A: (in_f, rank) - # lora_B: (rank, out_f) - # delta: (in_f, out_f) - matches e3nn weight layout - delta = lora_module.lora_A @ lora_module.lora_B - lora_module.base.weight.add_(lora_module.scaling * delta) - return lora_module.base - - -def _merge_o3_lora(lora_module: LoRAO3Linear) -> o3.Linear: - """ - Merge LoRA weights for o3.Linear by direct weight composition. - - For o3.Linear, each instruction connects an input irrep index to an output - irrep index. The LoRA composition B(A(x)) goes through an intermediate - bottleneck representation. We match instructions by their (i_in, i_out) - indices and compose the weight blocks. - - Formula: W_merged = W_base + scaling * (pw_A * pw_B / pw_base) * (W_A @ W_B) - """ - base = lora_module.base - lora_A = lora_module.lora_A - lora_B = lora_module.lora_B - scaling = lora_module.scaling - - with torch.no_grad(): - # Extract weight blocks indexed by instruction - def extract_weight_blocks(linear): - blocks = {} - offset = 0 - for idx, instr in enumerate(linear.instructions): - size = instr.path_shape[0] * instr.path_shape[1] - block = linear.weight[offset : offset + size].reshape(instr.path_shape) - blocks[idx] = block - offset += size - return blocks - - base_blocks = extract_weight_blocks(base) - A_blocks = extract_weight_blocks(lora_A) - B_blocks = extract_weight_blocks(lora_B) - - # Build lookup tables for lora_A and lora_B instructions - # lora_A: maps i_in -> (instruction_idx, i_out) - A_by_i_in = {} - for idx, instr in enumerate(lora_A.instructions): - A_by_i_in[instr.i_in] = (idx, instr.i_out) - - # lora_B: maps (i_in, i_out) -> instruction_idx - B_by_in_out = {} - for idx, instr in enumerate(lora_B.instructions): - B_by_in_out[(instr.i_in, instr.i_out)] = idx - - # Compute merged weight blocks - merged_blocks = [] - for base_idx, base_instr in enumerate(base.instructions): - i_in_base = base_instr.i_in - i_out_base = base_instr.i_out - pw_base = base_instr.path_weight - - # Find corresponding lora_A instruction (input -> bottleneck) - if i_in_base not in A_by_i_in: - # No LoRA for this path, keep base unchanged - merged_blocks.append(base_blocks[base_idx]) - continue - - A_idx, i_mid = A_by_i_in[i_in_base] - pw_A = lora_A.instructions[A_idx].path_weight - - # Find corresponding lora_B instruction (bottleneck -> output) - B_key = (i_mid, i_out_base) - if B_key not in B_by_in_out: - # No LoRA for this path, keep base unchanged - merged_blocks.append(base_blocks[base_idx]) - continue - - B_idx = B_by_in_out[B_key] - pw_B = lora_B.instructions[B_idx].path_weight - - # Compose: W_delta = (pw_A * pw_B / pw_base) * (W_A @ W_B) - ratio = (pw_A * pw_B) / pw_base - delta = A_blocks[A_idx] @ B_blocks[B_idx] - merged = base_blocks[base_idx] + scaling * ratio * delta - merged_blocks.append(merged) - - # Flatten merged blocks back into weight tensor - merged_weight = torch.cat([b.flatten() for b in merged_blocks]) - base.weight.copy_(merged_weight) - - return base - - def merge_lora_weights(model: nn.Module, inplace: bool = True) -> nn.Module: """ Merge LoRA weights into base weights and replace LoRA wrappers with merged base modules. @@ -313,39 +328,23 @@ def merge_lora_weights(model: nn.Module, inplace: bool = True) -> nn.Module: Returns: Model with LoRA weights merged into base layers. All parameters will have requires_grad=True after merging. - - Example: - >>> model = load_model(...) - >>> inject_lora(model, rank=4) - >>> train(model) # Train with LoRA - >>> merge_lora_weights(model) # Merge for fast inference - >>> save_model(model) """ if not inplace: import copy model = copy.deepcopy(model) - _merge_lora_recursive(model) + def merge_recursive(module: nn.Module) -> None: + for name, child in list(module.named_children()): + if isinstance(child, (LoRADenseLinear, LoRAFCLayer, LoRAO3Linear)): + setattr(module, name, child.merge_into_base()) + else: + merge_recursive(child) + + merge_recursive(model) - # Re-enable gradients for all parameters (they were frozen during LoRA training) + # Re-enable gradients for all parameters for param in model.parameters(): param.requires_grad = True return model - - -def _merge_lora_recursive(module: nn.Module) -> None: - """Recursively merge LoRA layers in a module.""" - for name, child in list(module.named_children()): - if isinstance(child, LoRADenseLinear): - merged = _merge_dense_lora(child) - setattr(module, name, merged) - elif isinstance(child, LoRAFCLayer): - merged = _merge_fc_lora(child) - setattr(module, name, merged) - elif isinstance(child, LoRAO3Linear): - merged = _merge_o3_lora(child) - setattr(module, name, merged) - else: - _merge_lora_recursive(child) From 84d4f5362181f46f914b48a71a56c92ebd873b0d Mon Sep 17 00:00:00 2001 From: ttompa <01_buck_jubilee@icloud.com> Date: Fri, 28 Nov 2025 15:52:39 +0000 Subject: [PATCH 11/11] add tests to check required_grad behaves correctly when using LoRA finetuning --- tests/test_lora.py | 52 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tests/test_lora.py b/tests/test_lora.py index e75561e1b..42213224e 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -324,3 +324,55 @@ def test_lora_merge_preserves_equivariance(build_lora_model, random_configs) -> assert np.allclose( forces_val @ R.T, forces_rot.squeeze(0).detach().numpy(), rtol=1e-5, atol=1e-5 ), "Forces not equivariant under rotation after merge" + + +def test_lora_evaluate_preserves_frozen_state(build_lora_model, random_configs) -> None: + """Test that evaluate() preserves requires_grad states for LoRA models. + """ + from mace.tools import evaluate + from mace.modules.loss import WeightedEnergyForcesLoss + + model, table = build_lora_model(rank=2, alpha=0.5, randomize=True) + + # Record which parameters should be trainable (only LoRA params) + lora_params_before = { + name: p.requires_grad for name, p in model.named_parameters() + } + trainable_before = [name for name, grad in lora_params_before.items() if grad] + frozen_before = [name for name, grad in lora_params_before.items() if not grad] + + # Verify initial state: only LoRA params are trainable + assert all("lora_" in name for name in trainable_before), ( + "Only LoRA params should be trainable initially" + ) + assert len(frozen_before) > 0, "Some base params should be frozen" + + # Create a minimal data loader for evaluation + configs = list(random_configs) + dataset = [_atomic_data_from_config(cfg, table) for cfg in configs] + loader = torch_geometric.dataloader.DataLoader( + dataset=dataset, + batch_size=len(dataset), + shuffle=False, + drop_last=False, + ) + + # Run evaluate + loss_fn = WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) + output_args = {"forces": True, "virials": False, "stress": False} + evaluate(model, loss_fn, loader, output_args, device=torch.device("cpu")) + + # Check that requires_grad states are preserved + lora_params_after = { + name: p.requires_grad for name, p in model.named_parameters() + } + + for name in trainable_before: + assert lora_params_after[name], ( + f"LoRA param {name} should still be trainable after evaluate()" + ) + + for name in frozen_before: + assert not lora_params_after[name], ( + f"Base param {name} should still be frozen after evaluate()" + )