From 3d5fe580abda148abd88d7aed96aae1238a78079 Mon Sep 17 00:00:00 2001
From: Noam Bernstein
Date: Thu, 18 Dec 2025 10:46:54 -0600
Subject: [PATCH 1/3] Write files from root process only when distributed

In particular, fixes an issue where distributed mace_run_train processes
overwrite each other's pretrained and combined data files
---
 mace/cli/fine_tuning_select.py | 24 +++++++++++++++---------
 mace/data/utils.py             |  9 ++++++---
 mace/tools/utils.py            |  8 ++++++++
 3 files changed, 29 insertions(+), 12 deletions(-)

diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py
index 60dd4b85b..79e1e9804 100644
--- a/mace/cli/fine_tuning_select.py
+++ b/mace/cli/fine_tuning_select.py
@@ -17,6 +17,7 @@
 import torch
 
 from mace.calculators import MACECalculator, mace_mp
+from mace.tools.utils import distributed_on_root, distributed_barrier
 
 try:
     import fpsample  # type: ignore
@@ -454,7 +455,6 @@ def _write_metadata(
         if head is not None:
             a.info["head"] = head
 
-
 def select_samples(
     settings: SelectionSettings,
 ) -> None:
@@ -491,7 +491,9 @@ def select_samples(
             f"filename '{settings.output}' does no have "
             "suffix compatible with extxyz format"
         )
-    _maybe_save_descriptors(subsampled_atoms, settings.output)
+    if distributed_on_root():
+        _maybe_save_descriptors(subsampled_atoms, settings.output)
+    distributed_barrier()
 
     _write_metadata(
         subsampled_atoms,
@@ -506,17 +508,21 @@ def select_samples(
         head=settings.head_ft,
     )
 
-    logging.info("Saving the selected configurations")
-    ase.io.write(settings.output, subsampled_atoms)
+    if distributed_on_root():
+        logging.info("Saving the selected configurations")
+        ase.io.write(settings.output, subsampled_atoms)
+    distributed_barrier()
 
     logging.info("Saving a combined XYZ file")
     atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft
-    output = Path(settings.output)
-    ase.io.write(
-        output.parent / (output.stem + "_combined" + output.suffix),
-        atoms_fps_pt_ft,
-    )
+    if distributed_on_root():
+        output = Path(settings.output)
+        ase.io.write(
+            output.parent / (output.stem + "_combined" + output.suffix),
+            atoms_fps_pt_ft,
+        )
+    distributed_barrier()
 
 
 def main():
diff --git a/mace/data/utils.py b/mace/data/utils.py
index 049d42cdd..4bb833f37 100644
--- a/mace/data/utils.py
+++ b/mace/data/utils.py
@@ -15,6 +15,7 @@
 import numpy as np
 
 from mace.tools import AtomicNumberTable, DefaultKeys
+from mace.tools.utils import distributed_on_root, distributed_barrier
 
 Positions = np.ndarray  # [..., 3]
 Cell = np.ndarray  # [3,3]
@@ -132,9 +133,11 @@ def random_train_valid_split(
     if prefix is not None and len(prefix) > 0:
         filename = f"{prefix}_" + filename
     path = os.path.join(work_dir, filename)
-    with open(path, "w", encoding="utf-8") as f:
-        for index in indices[train_size:]:
-            f.write(f"{index}\n")
+    if distributed_on_root():
+        with open(path, "w", encoding="utf-8") as f:
+            for index in indices[train_size:]:
+                f.write(f"{index}\n")
+    distributed_barrier()
 
     logging.info(
         f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {path}"
diff --git a/mace/tools/utils.py b/mace/tools/utils.py
index ae2c9bf3f..24aef8e1c 100644
--- a/mace/tools/utils.py
+++ b/mace/tools/utils.py
@@ -205,3 +205,11 @@ def filter_nonzero_weight(
             quantity_l[-1] = filtered_q
 
     return 1.0
+
+def distributed_on_root():
+    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
+    return (not distributed) or (torch.distributed.get_rank() == 0)
+
+def distributed_barrier():
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        torch.distributed.barrier()

From 5b4809d5f6791c75b20caa87bf5fc13a657b46e1 Mon Sep 17 00:00:00 2001
From: Noam Bernstein
Date: Thu, 18 Dec 2025 14:08:03 -0600
Subject: [PATCH 2/3] When copying readouts from foundation models, don't copy
 if a tensor has size 0, to avoid an unused-parameters error

---
 mace/tools/finetuning_utils.py | 40 +++++++++++++++++++-----------------------
 1 file changed, 17 insertions(+), 23 deletions(-)

diff --git a/mace/tools/finetuning_utils.py b/mace/tools/finetuning_utils.py
index 3e81492bd..8f0509dcf 100644
--- a/mace/tools/finetuning_utils.py
+++ b/mace/tools/finetuning_utils.py
@@ -191,11 +191,10 @@ def load_foundations_elements(
 
     if load_readout:
         # Transferring readouts
-        for i, readout in enumerate(model.readouts):
+        for readout, foundation_readout in zip(model.readouts, model_foundations.readouts):
             if readout.__class__.__name__ == "LinearReadoutBlock":
-                model_readouts_zero_linear_weight = readout.linear.weight.clone()
                 model_readouts_zero_linear_weight = (
-                    model_foundations.readouts[i]
+                    foundation_readout
                     .linear.weight.view(num_channels_foundation, -1)
                     .repeat(1, len(model_heads))
                     .flatten()
@@ -214,7 +213,7 @@
                 # Determine shapes once to avoid uninitialized use
                 if hasattr(readout, "linear_1"):
                     shape_input_1 = (
-                        model_foundations.readouts[i]
+                        foundation_readout
                         .linear_1.__dict__["irreps_out"]
                         .num_irreps
                     )
@@ -222,9 +221,8 @@
                 else:
                     raise ValueError("Readout block must have linear_1")
                 if hasattr(readout, "linear_1"):
-                    model_readouts_one_linear_1_weight = readout.linear_1.weight.clone()
                     model_readouts_one_linear_1_weight = (
-                        model_foundations.readouts[i]
+                        foundation_readout
                         .linear_1.weight.view(num_channels_foundation, -1)
                         .repeat(1, len(model_heads))
                         .flatten()
@@ -233,10 +231,9 @@
                     readout.linear_1.weight = torch.nn.Parameter(
                         model_readouts_one_linear_1_weight
                     )
-                    if readout.linear_1.bias is not None:
-                        model_readouts_one_linear_1_bias = readout.linear_1.bias.clone()
+                    if readout.linear_1.bias is not None and readout.linear_1.bias.nelement() != 0:
                         model_readouts_one_linear_1_bias = (
-                            model_foundations.readouts[i]
+                            foundation_readout
                             .linear_1.bias.view(-1)
                             .repeat(len(model_heads))
                             .clone()
@@ -246,7 +243,7 @@
                         )
                 if hasattr(readout, "linear_mid"):
                     readout.linear_mid.weight = torch.nn.Parameter(
-                        model_foundations.readouts[i]
+                        foundation_readout
                         .linear_mid.weight.view(
                             shape_input_1,
                             shape_input_1,
@@ -257,28 +254,25 @@
                         / ((shape_input_1) / (shape_output_1)) ** 0.5
                     )
                     # if it has biases transfer them too
-                    if readout.linear_mid.bias is not None:
+                    if readout.linear_mid.bias is not None and readout.linear_mid.bias.nelement() != 0:
                         readout.linear_mid.bias = torch.nn.Parameter(
-                            model_foundations.readouts[i]
+                            foundation_readout
                             .linear_mid.bias.repeat(len(model_heads))
                             .clone()
                         )
                 if hasattr(readout, "linear_2"):
-                    model_readouts_one_linear_2_weight = readout.linear_2.weight.clone()
-                    model_readouts_one_linear_2_weight = model_foundations.readouts[
-                        i
-                    ].linear_2.weight.view(shape_input_1, -1).repeat(
-                        len(model_heads), len(model_heads)
-                    ).flatten().clone() / (
-                        ((shape_input_1) / (shape_output_1)) ** 0.5
-                    )
+                    model_readouts_one_linear_2_weight = (foundation_readout
+                        .linear_2.weight.view(shape_input_1, -1).repeat(
+                        len(model_heads), len(model_heads)
+                    ).flatten().clone() / (
+                        ((shape_input_1) / (shape_output_1)) ** 0.5
+                    ))
                     readout.linear_2.weight = torch.nn.Parameter(
                        model_readouts_one_linear_2_weight
                    )
-                    if readout.linear_2.bias is not None:
-                        model_readouts_one_linear_2_bias = readout.linear_2.bias.clone()
+                    if readout.linear_2.bias is not None and readout.linear_2.bias.nelement() != 0:
                         model_readouts_one_linear_2_bias = (
-                            model_foundations.readouts[i]
+                            foundation_readout
                             .linear_2.bias.view(-1)
                             .repeat(len(model_heads))
                             .flatten()

From 52d94106d90ac7af8b262a688d31214d228a1af7 Mon Sep 17 00:00:00 2001
From: Noam Bernstein
Date: Thu, 18 Dec 2025 14:11:49 -0600
Subject: [PATCH 3/3] Wait after writing files from task 0 before barrier, to
 increase the chance that other tasks see the completed write

---
 mace/cli/fine_tuning_select.py | 4 ++++
 mace/data/utils.py             | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/mace/cli/fine_tuning_select.py b/mace/cli/fine_tuning_select.py
index 79e1e9804..01c919cf0 100644
--- a/mace/cli/fine_tuning_select.py
+++ b/mace/cli/fine_tuning_select.py
@@ -6,6 +6,7 @@
 import argparse
 import ast
 import logging
+import time
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
@@ -493,6 +494,7 @@ def select_samples(
         )
     if distributed_on_root():
         _maybe_save_descriptors(subsampled_atoms, settings.output)
+        time.sleep(1) # hope this is enough for filesystem to re-synchronize
     distributed_barrier()
 
     _write_metadata(
@@ -511,6 +513,7 @@ def select_samples(
     if distributed_on_root():
         logging.info("Saving the selected configurations")
         ase.io.write(settings.output, subsampled_atoms)
+        time.sleep(1) # hope this is enough for filesystem to re-synchronize
     distributed_barrier()
 
     logging.info("Saving a combined XYZ file")
@@ -522,6 +525,7 @@ def select_samples(
             output.parent / (output.stem + "_combined" + output.suffix),
             atoms_fps_pt_ft,
         )
+        time.sleep(1) # hope this is enough for filesystem to re-synchronize
     distributed_barrier()
 
 
diff --git a/mace/data/utils.py b/mace/data/utils.py
index 4bb833f37..1af5603c2 100644
--- a/mace/data/utils.py
+++ b/mace/data/utils.py
@@ -6,6 +6,7 @@
 
 import logging
 import os
+import time
 from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Sequence, Tuple
 
@@ -137,6 +138,7 @@ def random_train_valid_split(
         with open(path, "w", encoding="utf-8") as f:
             for index in indices[train_size:]:
                 f.write(f"{index}\n")
+        time.sleep(1) # hope this is enough for filesystem to re-synchronize
     distributed_barrier()
 
     logging.info(
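
Note for reviewers, not part of the patches: a minimal, self-contained sketch of the write-from-root pattern that PATCH 1/3 and PATCH 3/3 introduce, assuming torch.distributed has already been initialized by the training launcher. write_shared_file is a hypothetical helper used only for illustration; it is not a function added by the patches.

import time

import torch


def distributed_on_root() -> bool:
    # True when not running distributed, or when this process is rank 0.
    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    return (not distributed) or (torch.distributed.get_rank() == 0)


def distributed_barrier() -> None:
    # No-op unless torch.distributed is up, so serial runs are unaffected.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


def write_shared_file(path: str, text: str) -> None:
    # Only rank 0 writes; the sleep (as in PATCH 3/3) gives a shared
    # filesystem a moment to expose the finished file before the other
    # ranks pass the barrier and go on to read it.
    if distributed_on_root():
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
        time.sleep(1)
    distributed_barrier()

Every rank calls write_shared_file, but only rank 0 touches the file; the trailing barrier keeps non-root ranks from reading a missing or half-written file.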
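Note for reviewers, not part of the patches: the nelement() != 0 guards added in PATCH 2/3 skip bias tensors that exist but hold zero elements; per the commit message, copying such tensors leads to an unused-parameters error (the kind DistributedDataParallel raises when a parameter never receives a gradient). A hypothetical minimal illustration of the guard, using plain tensors rather than MACE readout blocks:

import torch

# A readout bias may be None, may exist with zero elements, or may hold
# real values; only the last case should be transferred.
empty_bias = torch.nn.Parameter(torch.empty(0))
full_bias = torch.nn.Parameter(torch.ones(3))

for bias in (None, empty_bias, full_bias):
    if bias is not None and bias.nelement() != 0:
        copied = torch.nn.Parameter(bias.detach().clone())  # only full_bias is copied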