Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __init__(
"stress",
"dipole",
]
elif model_type == "AtomicTargetMACE":
self.implemented_properties = ["atomic_targets"]
else:
raise ValueError(
f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported"
Expand Down Expand Up @@ -150,6 +152,8 @@ def __init__(
)
elif model_type == "DipoleMACE":
self.implemented_properties.extend(["dipole_var"])
elif mode_type == "AtomicTargetMACE":
self.implemented_properties.extend(["atomic_targets_var"])

if compile_mode is not None:
print(f"Torch compile is enabled with mode: {compile_mode}")
Expand Down Expand Up @@ -232,6 +236,9 @@ def _create_result_tensors(
if model_type in ["EnergyDipoleMACE", "DipoleMACE"]:
dipole = torch.zeros(num_models, 3, device=self.device)
dict_of_tensors.update({"dipole": dipole})
if model_type == "AtomicTargetMACE":
atomic_targets = torch.zeros(num_models, num_atoms, device=self.device)
dict_of_tensors.update({"atomic_targets": atomic_targets})
return dict_of_tensors

def _atoms_to_batch(self, atoms):
Expand Down Expand Up @@ -299,6 +306,8 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
ret_tensors["stress"][i] = out["stress"].detach()
if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]:
ret_tensors["dipole"][i] = out["dipole"].detach()
if self.model_type == "AtomicTargetMACE":
ret_tensors["atomic_targets"][i] = out["atomic_targets"].detach()

self.results = {}
if self.model_type in ["MACE", "EnergyDipoleMACE"]:
Expand Down Expand Up @@ -354,6 +363,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes):
.cpu()
.numpy()
)
if self.model_type == "AtomicTargetMACE":
self.results["atomic_targets"] = (
torch.mean(ret_tensors["atomic_targets"], dim=0).cpu().numpy()
)
if self.num_models > 1:
self.results["atomic_targets_var"] = (
torch.var(ret_tensors["atomic_targets"], dim=0, unbiased=False)
.cpu()
.numpy()
)

def get_hessian(self, atoms=None):
if atoms is None and self.atoms is None:
Expand Down
34 changes: 32 additions & 2 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
dict_to_array,
extract_config_mace_model,
get_atomic_energies,
get_atomic_targets,
get_avg_num_neighbors,
get_config_type_weights,
get_dataset_from_xyz,
Expand Down Expand Up @@ -211,6 +212,7 @@ def run(args: argparse.Namespace) -> None:
virials_key=head_config.virials_key,
dipole_key=head_config.dipole_key,
charges_key=head_config.charges_key,
atomic_targets_key=head_config.atomic_targets_key,
head_name=head_config.head_name,
keep_isolated_atoms=head_config.keep_isolated_atoms,
)
Expand Down Expand Up @@ -281,6 +283,7 @@ def run(args: argparse.Namespace) -> None:
virials_key=args.virials_key,
dipole_key=args.dipole_key,
charges_key=args.charges_key,
atomic_targets_key=args.atomic_targets_key,
head_name="pt_head",
keep_isolated_atoms=args.keep_isolated_atoms,
)
Expand All @@ -298,6 +301,7 @@ def run(args: argparse.Namespace) -> None:
virials_key=args.virials_key,
dipole_key=args.dipole_key,
charges_key=args.charges_key,
atomic_targets_key=args.atomic_targets_key,
keep_isolated_atoms=args.keep_isolated_atoms,
collections=collections,
avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors,
Expand Down Expand Up @@ -339,10 +343,21 @@ def run(args: argparse.Namespace) -> None:
z_table = AtomicNumberTable(sorted(list(all_atomic_numbers)))
logging.info(f"Atomic Numbers used: {z_table.zs}")

# Atomic targets
#atomic_targets_dict = {}
#for head_config in head_configs:
# if head_config.atomic_targets_dict is None or len(head_config.atomic_targets_dict) == 0:
# if check_path_ase_read(head_config.train_file):
# atomic_targets_dict[head_config.head_name] = get_atomic_targets(
# head_config.collections.train, head_config.z_table
# )
# else:
# atomic_targets_dict[head_config.head_name] = head_config.atomic_targets_dict

# Atomic energies
atomic_energies_dict = {}
for head_config in head_configs:
if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0:
if (head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0) and args.model != "AtomicTargetsMACE":
assert head_config.E0s is not None, "Atomic energies must be provided"
if check_path_ase_read(head_config.train_file) and head_config.E0s.lower() != "foundation":
atomic_energies_dict[head_config.head_name] = get_atomic_energies(
Expand Down Expand Up @@ -381,14 +396,28 @@ def run(args: argparse.Namespace) -> None:

if args.model == "AtomicDipolesMACE":
atomic_energies = None
#atomic_targets = None
dipole_only = True
targets_only = False
args.compute_dipole = True
args.compute_energy = False
args.compute_forces = False
args.compute_virials = False
args.compute_stress = False
elif args.model == "AtomicTargetsMACE":
args.scaling = "atomic_targets_std_scaling"
atomic_energies = None
#atomic_targets = dict_to_array(atomic_targets_dict, heads)
dipole_only = False
targets_only = True
args.compute_dipole = False
args.compute_energy = True
args.compute_forces = False
args.compute_virials = False
args.compute_stress = False
else:
dipole_only = False
targets_only = False
if args.model == "EnergyDipolesMACE":
args.compute_dipole = True
args.compute_energy = True
Expand All @@ -402,6 +431,7 @@ def run(args: argparse.Namespace) -> None:
# [atomic_energies_dict[z] for z in z_table.zs]
# )
atomic_energies = dict_to_array(atomic_energies_dict, heads)
#atomic_targets = None
for head_config in head_configs:
try:
logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}")
Expand Down Expand Up @@ -498,7 +528,7 @@ def run(args: argparse.Namespace) -> None:
generator=torch.Generator().manual_seed(args.seed),
)

loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole)
loss_fn = get_loss_fn(args, dipole_only, targets_only, args.compute_dipole)
args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device)

# Model
Expand Down
2 changes: 2 additions & 0 deletions mace/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
save_configurations_as_HDF5,
save_dataset_as_HDF5,
test_config_types,
compute_average_atomic_targets,
)

__all__ = [
Expand All @@ -31,4 +32,5 @@
"dataset_from_sharded_hdf5",
"save_AtomicData_to_HDF5",
"save_configurations_as_HDF5",
"compute_average_atomic_targets",
]
10 changes: 10 additions & 0 deletions mace/data/atomic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class AtomicData(torch_geometric.data.Data):
virials: torch.Tensor
dipole: torch.Tensor
charges: torch.Tensor
atomic_targets: torch.Tensor
weight: torch.Tensor
energy_weight: torch.Tensor
forces_weight: torch.Tensor
Expand All @@ -63,6 +64,7 @@ def __init__(
virials: Optional[torch.Tensor], # [1,3,3]
dipole: Optional[torch.Tensor], # [, 3]
charges: Optional[torch.Tensor], # [n_nodes, ]
atomic_targets: Optional[torch.Tensor], # [n_nodes, ]
):
# Check shapes
num_nodes = node_attrs.shape[0]
Expand All @@ -85,6 +87,7 @@ def __init__(
assert virials is None or virials.shape == (1, 3, 3)
assert dipole is None or dipole.shape[-1] == 3
assert charges is None or charges.shape == (num_nodes,)
assert atomic_targets is None or atomic_targets.shape == (num_nodes,)
# Aggregate data
data = {
"num_nodes": num_nodes,
Expand All @@ -106,6 +109,7 @@ def __init__(
"virials": virials,
"dipole": dipole,
"charges": charges,
"atomic_targets": atomic_targets,
}
super().__init__(**data)

Expand Down Expand Up @@ -204,6 +208,11 @@ def from_config(
if config.charges is not None
else None
)
atomic_targets = (
torch.tensor(config.atomic_targets, dtype=torch.get_default_dtype())
if config.atomic_targets is not None
else None
)

return cls(
edge_index=torch.tensor(edge_index, dtype=torch.long),
Expand All @@ -224,6 +233,7 @@ def from_config(
virials=virials,
dipole=dipole,
charges=charges,
atomic_targets=atomic_targets,
)


Expand Down
36 changes: 36 additions & 0 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Stress = np.ndarray # [6, ], [3,3], [9, ]
Virials = np.ndarray # [6, ], [3,3], [9, ]
Charges = np.ndarray # [..., 1]
AtomicTargets = np.ndarray # [..., 1]
Cell = np.ndarray # [3,3]
Pbc = tuple # (3,)

Expand All @@ -38,6 +39,7 @@ class Configuration:
virials: Optional[Virials] = None # eV
dipole: Optional[Vector] = None # Debye
charges: Optional[Charges] = None # atomic unit
atomic_targets: Optional[AtomicTargets] = None
cell: Optional[Cell] = None
pbc: Optional[Pbc] = None

Expand Down Expand Up @@ -92,6 +94,7 @@ def config_from_atoms_list(
virials_key="REF_virials",
dipole_key="REF_dipole",
charges_key="REF_charges",
atomic_targets_key="REF_atomic_targets",
head_key="head",
config_type_weights: Optional[Dict[str, float]] = None,
) -> Configurations:
Expand All @@ -110,6 +113,7 @@ def config_from_atoms_list(
virials_key=virials_key,
dipole_key=dipole_key,
charges_key=charges_key,
atomic_targets_key=atomic_targets_key,
head_key=head_key,
config_type_weights=config_type_weights,
)
Expand All @@ -125,6 +129,7 @@ def config_from_atoms(
virials_key="REF_virials",
dipole_key="REF_dipole",
charges_key="REF_charges",
atomic_targets_key="REF_atomic_targets",
head_key="head",
config_type_weights: Optional[Dict[str, float]] = None,
) -> Configuration:
Expand All @@ -139,6 +144,8 @@ def config_from_atoms(
dipole = atoms.info.get(dipole_key, None) # Debye
# Charges default to 0 instead of None if not found
charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit
# Atomic targets default to 0 instead of None if not found
atomic_targets = atoms.arrays.get(atomic_targets_key, np.zeros(len(atoms)))
atomic_numbers = np.array(
[ase.data.atomic_numbers[symbol] for symbol in atoms.symbols]
)
Expand Down Expand Up @@ -181,6 +188,7 @@ def config_from_atoms(
virials=virials,
dipole=dipole,
charges=charges,
atomic_targets=atomic_targets,
weight=weight,
head=head,
energy_weight=energy_weight,
Expand Down Expand Up @@ -219,6 +227,7 @@ def load_from_xyz(
virials_key: str = "REF_virials",
dipole_key: str = "REF_dipole",
charges_key: str = "REF_charges",
atomic_targets_key: str = "REF_atomic_targets",
head_key: str = "head",
head_name: str = "Default",
extract_atomic_energies: bool = False,
Expand Down Expand Up @@ -297,6 +306,7 @@ def load_from_xyz(
virials_key=virials_key,
dipole_key=dipole_key,
charges_key=charges_key,
atomic_targets_key=atomic_targets_key,
head_key=head_key,
)
return atomic_energies_dict, configs
Expand Down Expand Up @@ -331,6 +341,32 @@ def compute_average_E0s(
atomic_energies_dict[z] = 0.0
return atomic_energies_dict

def compute_average_atomic_targets(
collections_train: Configurations, z_table: AtomicNumberTable,
) -> Tuple[Dict[int, float], float]:
"""
Function to compute the average node target and node std of each chemical element
returns a dictionary with averages and a float scale
"""
len_train = len(collections_train)
len_zs = len(z_table)
elementwise_targets = {}
for i in range(len_train):
for j in range(len(collections_train[i].atomic_numbers)):
z = collections_train[i].atomic_numbers[j]
if z not in elementwise_targets.keys():
elementwise_targets[z] = []
elementwise_targets[z].append(collections_train[i].atomic_targets[j])


atomic_energies_dict = {}
atomic_scales = []
for i, z in enumerate(z_table.zs):
atomic_energies_dict[z] = np.mean(elementwise_targets[z])
atomic_scales.append((len(elementwise_targets[z]), np.std(elementwise_targets[z])))
# compute weighted average of scales with tuple element 0 ebing the weight and element 1 the value to average
scale = np.average([x[1] for x in atomic_scales], weights=[x[0] for x in atomic_scales])
return atomic_energies_dict #, scale

def save_dataset_as_HDF5(dataset: List, out_name: str) -> None:
with h5py.File(out_name, "w") as f:
Expand Down
7 changes: 7 additions & 0 deletions mace/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
WeightedEnergyForcesVirialsLoss,
WeightedForcesLoss,
WeightedHuberEnergyForcesStressLoss,
AtomicTargetsLoss
)
from .models import (
MACE,
AtomicDipolesMACE,
AtomicTargetsMACE,
BOTNet,
EnergyDipolesMACE,
ScaleShiftBOTNet,
Expand All @@ -45,6 +47,7 @@
compute_fixed_charge_dipole,
compute_mean_rms_energy_forces,
compute_mean_std_atomic_inter_energy,
compute_mean_std_atomic_targets,
compute_rms_dipoles,
compute_statistics,
)
Expand All @@ -62,6 +65,7 @@
"std_scaling": compute_mean_std_atomic_inter_energy,
"rms_forces_scaling": compute_mean_rms_energy_forces,
"rms_dipoles_scaling": compute_rms_dipoles,
"atomic_targets_std_scaling": compute_mean_std_atomic_targets,
}

gate_dict: Dict[str, Optional[Callable]] = {
Expand Down Expand Up @@ -91,6 +95,7 @@
"BOTNet",
"ScaleShiftBOTNet",
"AtomicDipolesMACE",
"AtomicTargetsMACE",
"EnergyDipolesMACE",
"WeightedEnergyForcesLoss",
"WeightedForcesLoss",
Expand All @@ -100,9 +105,11 @@
"WeightedEnergyForcesDipoleLoss",
"WeightedHuberEnergyForcesStressLoss",
"UniversalLoss",
"AtomicTargetsLoss",
"SymmetricContraction",
"interaction_classes",
"compute_mean_std_atomic_inter_energy",
"compute_mean_std_atomic_targets",
"compute_avg_num_neighbors",
"compute_statistics",
"compute_fixed_charge_dipole",
Expand Down
12 changes: 12 additions & 0 deletions mace/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,15 @@ def __repr__(self):
f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, "
f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})"
)

class AtomicTargetsLoss(torch.nn.Module):
def __init__(self, huber_delta=0.01) -> None:
super().__init__()
self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta)

def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor:
return self.huber_loss(ref["atomic_targets"], pred["atomic_targets"])*1e3

def __repr__(self):
return f"{self.__class__.__name__}(huber_delta={self.huber_loss.delta:.3f})"

Loading