28 changes: 19 additions & 9 deletions mace/cli/fine_tuning_select.py
@@ -6,6 +6,7 @@
import argparse
import ast
import logging
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
@@ -17,6 +18,7 @@
import torch

from mace.calculators import MACECalculator, mace_mp
from mace.tools.utils import distributed_on_root, distributed_barrier

try:
import fpsample # type: ignore
@@ -454,7 +456,6 @@ def _write_metadata(
if head is not None:
a.info["head"] = head


def select_samples(
settings: SelectionSettings,
) -> None:
Expand Down Expand Up @@ -491,7 +492,10 @@ def select_samples(
f"filename '{settings.output}' does no have "
"suffix compatible with extxyz format"
)
_maybe_save_descriptors(subsampled_atoms, settings.output)
if distributed_on_root():
_maybe_save_descriptors(subsampled_atoms, settings.output)
time.sleep(1) # hope this is enough for filesystem to re-synchronize
distributed_barrier()

_write_metadata(
subsampled_atoms,
@@ -506,17 +510,23 @@
head=settings.head_ft,
)

logging.info("Saving the selected configurations")
ase.io.write(settings.output, subsampled_atoms)
if distributed_on_root():
logging.info("Saving the selected configurations")
ase.io.write(settings.output, subsampled_atoms)
time.sleep(1) # hope this is enough for filesystem to re-synchronize
distributed_barrier()

logging.info("Saving a combined XYZ file")
atoms_fps_pt_ft = subsampled_atoms + atoms_list_ft

output = Path(settings.output)
ase.io.write(
output.parent / (output.stem + "_combined" + output.suffix),
atoms_fps_pt_ft,
)
if distributed_on_root():
output = Path(settings.output)
ase.io.write(
output.parent / (output.stem + "_combined" + output.suffix),
atoms_fps_pt_ft,
)
time.sleep(1) # hope this is enough for filesystem to re-synchronize
distributed_barrier()


def main():
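Every output write in this file now follows the same idiom: only the root rank writes, then all ranks meet at a barrier so nobody reads a file that is still being flushed. A minimal sketch of that idiom, using the helpers added in mace/tools/utils.py further down; the write_on_root wrapper itself is only illustrative and not part of this PR:

import time

import ase.io

from mace.tools.utils import distributed_barrier, distributed_on_root


def write_on_root(path, atoms):
    # Only rank 0 (or the single process in a non-distributed run) writes.
    if distributed_on_root():
        ase.io.write(path, atoms)
        time.sleep(1)  # crude grace period for the shared filesystem
    # Every rank waits here, so none of them reads a half-written file.
    distributed_barrier()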
11 changes: 8 additions & 3 deletions mace/data/utils.py
@@ -6,6 +6,7 @@

import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple

@@ -15,6 +16,7 @@
import numpy as np

from mace.tools import AtomicNumberTable, DefaultKeys
from mace.tools.utils import distributed_on_root, distributed_barrier

Positions = np.ndarray # [..., 3]
Cell = np.ndarray # [3,3]
@@ -132,9 +134,12 @@ def random_train_valid_split(
if prefix is not None and len(prefix) > 0:
filename = f"{prefix}_" + filename
path = os.path.join(work_dir, filename)
with open(path, "w", encoding="utf-8") as f:
for index in indices[train_size:]:
f.write(f"{index}\n")
if distributed_on_root():
with open(path, "w", encoding="utf-8") as f:
for index in indices[train_size:]:
f.write(f"{index}\n")
time.sleep(1) # hope this is enough for filesystem to re-synchronize
distributed_barrier()

logging.info(
f"Using random {100 * valid_fraction:.0f}% of training set for validation with indices saved in: {path}"
40 changes: 17 additions & 23 deletions mace/tools/finetuning_utils.py
@@ -191,11 +191,10 @@ def load_foundations_elements(

if load_readout:
# Transferring readouts
for i, readout in enumerate(model.readouts):
for readout, foundation_readout in zip(model.readouts, model_foundations.readouts):
if readout.__class__.__name__ == "LinearReadoutBlock":
model_readouts_zero_linear_weight = readout.linear.weight.clone()
model_readouts_zero_linear_weight = (
model_foundations.readouts[i]
foundation_readout
.linear.weight.view(num_channels_foundation, -1)
.repeat(1, len(model_heads))
.flatten()
@@ -214,17 +213,16 @@
# Determine shapes once to avoid uninitialized use
if hasattr(readout, "linear_1"):
shape_input_1 = (
model_foundations.readouts[i]
foundation_readout
.linear_1.__dict__["irreps_out"]
.num_irreps
)
shape_output_1 = readout.linear_1.__dict__["irreps_out"].num_irreps
else:
raise ValueError("Readout block must have linear_1")
if hasattr(readout, "linear_1"):
model_readouts_one_linear_1_weight = readout.linear_1.weight.clone()
model_readouts_one_linear_1_weight = (
model_foundations.readouts[i]
foundation_readout
.linear_1.weight.view(num_channels_foundation, -1)
.repeat(1, len(model_heads))
.flatten()
@@ -233,10 +231,9 @@
readout.linear_1.weight = torch.nn.Parameter(
model_readouts_one_linear_1_weight
)
if readout.linear_1.bias is not None:
model_readouts_one_linear_1_bias = readout.linear_1.bias.clone()
if readout.linear_1.bias is not None and readout.linear_1.bias.nelement() != 0:
model_readouts_one_linear_1_bias = (
model_foundations.readouts[i]
foundation_readout
.linear_1.bias.view(-1)
.repeat(len(model_heads))
.clone()
@@ -246,7 +243,7 @@
)
if hasattr(readout, "linear_mid"):
readout.linear_mid.weight = torch.nn.Parameter(
model_foundations.readouts[i]
foundation_readout
.linear_mid.weight.view(
shape_input_1,
shape_input_1,
@@ -257,28 +254,25 @@
/ ((shape_input_1) / (shape_output_1)) ** 0.5
)
# if it has biases transfer them too
if readout.linear_mid.bias is not None:
if readout.linear_mid.bias is not None and readout.linear_mid.bias.nelement() != 0:
readout.linear_mid.bias = torch.nn.Parameter(
model_foundations.readouts[i]
foundation_readout
.linear_mid.bias.repeat(len(model_heads))
.clone()
)
if hasattr(readout, "linear_2"):
model_readouts_one_linear_2_weight = readout.linear_2.weight.clone()
model_readouts_one_linear_2_weight = model_foundations.readouts[
i
].linear_2.weight.view(shape_input_1, -1).repeat(
len(model_heads), len(model_heads)
).flatten().clone() / (
((shape_input_1) / (shape_output_1)) ** 0.5
)
model_readouts_one_linear_2_weight = (foundation_readout
.linear_2.weight.view(shape_input_1, -1).repeat(
len(model_heads), len(model_heads)
).flatten().clone() / (
((shape_input_1) / (shape_output_1)) ** 0.5
))
readout.linear_2.weight = torch.nn.Parameter(
model_readouts_one_linear_2_weight
)
if readout.linear_2.bias is not None:
model_readouts_one_linear_2_bias = readout.linear_2.bias.clone()
if readout.linear_2.bias is not None and readout.linear_2.bias.nelement() != 0:
model_readouts_one_linear_2_bias = (
model_foundations.readouts[i]
foundation_readout
.linear_2.bias.view(-1)
.repeat(len(model_heads))
.flatten()
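The repeated view/repeat/flatten chain is what maps a single-head foundation readout onto a multi-head fine-tuning model: each channel's weight is duplicated once per head. A self-contained sketch with made-up sizes (C and n_heads here are hypothetical, not values from the PR):

import torch

C, n_heads = 4, 3  # hypothetical channel and head counts
foundation_weight = torch.arange(C, dtype=torch.float32)  # single-head readout weight, shape [C]

# Mirrors: foundation_readout.linear.weight.view(C, -1).repeat(1, n_heads).flatten()
tiled = foundation_weight.view(C, -1).repeat(1, n_heads).flatten()
print(tiled.shape)  # torch.Size([12]) == C * n_heads
# Every head starts fine-tuning from an identical copy of the foundation readout weights.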
8 changes: 8 additions & 0 deletions mace/tools/utils.py
@@ -205,3 +205,11 @@ def filter_nonzero_weight(

quantity_l[-1] = filtered_q
return 1.0

def distributed_on_root():
distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
return (not distributed) or (torch.distributed.get_rank() == 0)

def distributed_barrier():
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.barrier()
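As a usage note (a sketch, not code from this PR): when torch.distributed has not been initialized, distributed_on_root() simply returns True and distributed_barrier() is a no-op, so the guarded writes above behave exactly like the old single-process code path.

import torch

from mace.tools.utils import distributed_barrier, distributed_on_root

# Single-process run: no process group, so this process is treated as root.
assert distributed_on_root()
distributed_barrier()  # no-op without an initialized process group

# Distributed run (after torch.distributed.init_process_group): only rank 0 is root.
if torch.distributed.is_available() and torch.distributed.is_initialized():
    assert distributed_on_root() == (torch.distributed.get_rank() == 0)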