From ae256fcce1c68b13d759e987dbc754ac6c7d570b Mon Sep 17 00:00:00 2001 From: JeremyThorn Date: Fri, 1 Nov 2024 16:14:46 +0000 Subject: [PATCH] Updated draft of per-node atomic feature fitting. --- mace/calculators/mace.py | 19 ++++++++ mace/cli/run_train.py | 34 ++++++++++++- mace/data/__init__.py | 2 + mace/data/atomic_data.py | 10 ++++ mace/data/utils.py | 36 ++++++++++++++ mace/modules/__init__.py | 7 +++ mace/modules/loss.py | 12 +++++ mace/modules/models.py | 84 ++++++++++++++++++++++++++++++++ mace/modules/utils.py | 23 +++++++++ mace/tools/arg_parser.py | 15 ++++++ mace/tools/model_script_utils.py | 23 ++++++++- mace/tools/multihead_tools.py | 5 ++ mace/tools/scripts_utils.py | 36 ++++++++++++++ mace/tools/train.py | 31 ++++++++++++ 14 files changed, 334 insertions(+), 3 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 9d307eda7..75b4215af 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -104,6 +104,8 @@ def __init__( "stress", "dipole", ] + elif model_type == "AtomicTargetMACE": + self.implemented_properties = ["atomic_targets"] else: raise ValueError( f"Give a valid model_type: [MACE, DipoleMACE, EnergyDipoleMACE], {model_type} not supported" @@ -150,6 +152,8 @@ def __init__( ) elif model_type == "DipoleMACE": self.implemented_properties.extend(["dipole_var"]) + elif model_type == "AtomicTargetMACE": + self.implemented_properties.extend(["atomic_targets_var"]) if compile_mode is not None: print(f"Torch compile is enabled with mode: {compile_mode}") @@ -232,6 +236,9 @@ def _create_result_tensors( if model_type in ["EnergyDipoleMACE", "DipoleMACE"]: dipole = torch.zeros(num_models, 3, device=self.device) dict_of_tensors.update({"dipole": dipole}) + if model_type == "AtomicTargetMACE": + atomic_targets = torch.zeros(num_models, num_atoms, device=self.device) + dict_of_tensors.update({"atomic_targets": atomic_targets}) return dict_of_tensors def _atoms_to_batch(self, atoms): @@ -299,6 +306,8 @@ def 
calculate(self, atoms=None, properties=None, system_changes=all_changes): ret_tensors["stress"][i] = out["stress"].detach() if self.model_type in ["DipoleMACE", "EnergyDipoleMACE"]: ret_tensors["dipole"][i] = out["dipole"].detach() + if self.model_type == "AtomicTargetMACE": + ret_tensors["atomic_targets"][i] = out["atomic_targets"].detach() self.results = {} if self.model_type in ["MACE", "EnergyDipoleMACE"]: @@ -354,6 +363,16 @@ def calculate(self, atoms=None, properties=None, system_changes=all_changes): .cpu() .numpy() ) + if self.model_type == "AtomicTargetMACE": + self.results["atomic_targets"] = ( + torch.mean(ret_tensors["atomic_targets"], dim=0).cpu().numpy() + ) + if self.num_models > 1: + self.results["atomic_targets_var"] = ( + torch.var(ret_tensors["atomic_targets"], dim=0, unbiased=False) + .cpu() + .numpy() + ) def get_hessian(self, atoms=None): if atoms is None and self.atoms is None: diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 8cab392ed..af359c78e 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -40,6 +40,7 @@ dict_to_array, extract_config_mace_model, get_atomic_energies, + get_atomic_targets, get_avg_num_neighbors, get_config_type_weights, get_dataset_from_xyz, @@ -211,6 +212,7 @@ def run(args: argparse.Namespace) -> None: virials_key=head_config.virials_key, dipole_key=head_config.dipole_key, charges_key=head_config.charges_key, + atomic_targets_key=head_config.atomic_targets_key, head_name=head_config.head_name, keep_isolated_atoms=head_config.keep_isolated_atoms, ) @@ -281,6 +283,7 @@ def run(args: argparse.Namespace) -> None: virials_key=args.virials_key, dipole_key=args.dipole_key, charges_key=args.charges_key, + atomic_targets_key=args.atomic_targets_key, head_name="pt_head", keep_isolated_atoms=args.keep_isolated_atoms, ) @@ -298,6 +301,7 @@ def run(args: argparse.Namespace) -> None: virials_key=args.virials_key, dipole_key=args.dipole_key, charges_key=args.charges_key, + 
atomic_targets_key=args.atomic_targets_key, keep_isolated_atoms=args.keep_isolated_atoms, collections=collections, avg_num_neighbors=model_foundation.interactions[0].avg_num_neighbors, @@ -339,10 +343,21 @@ def run(args: argparse.Namespace) -> None: z_table = AtomicNumberTable(sorted(list(all_atomic_numbers))) logging.info(f"Atomic Numbers used: {z_table.zs}") + # Atomic targets + #atomic_targets_dict = {} + #for head_config in head_configs: + # if head_config.atomic_targets_dict is None or len(head_config.atomic_targets_dict) == 0: + # if check_path_ase_read(head_config.train_file): + # atomic_targets_dict[head_config.head_name] = get_atomic_targets( + # head_config.collections.train, head_config.z_table + # ) + # else: + # atomic_targets_dict[head_config.head_name] = head_config.atomic_targets_dict + # Atomic energies atomic_energies_dict = {} for head_config in head_configs: - if head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0: + if (head_config.atomic_energies_dict is None or len(head_config.atomic_energies_dict) == 0) and args.model != "AtomicTargetsMACE": assert head_config.E0s is not None, "Atomic energies must be provided" if check_path_ase_read(head_config.train_file) and head_config.E0s.lower() != "foundation": atomic_energies_dict[head_config.head_name] = get_atomic_energies( @@ -381,14 +396,28 @@ def run(args: argparse.Namespace) -> None: if args.model == "AtomicDipolesMACE": atomic_energies = None + #atomic_targets = None dipole_only = True + targets_only = False args.compute_dipole = True args.compute_energy = False args.compute_forces = False args.compute_virials = False args.compute_stress = False + elif args.model == "AtomicTargetsMACE": + args.scaling = "atomic_targets_std_scaling" + atomic_energies = None + #atomic_targets = dict_to_array(atomic_targets_dict, heads) + dipole_only = False + targets_only = True + args.compute_dipole = False + args.compute_energy = True + args.compute_forces = False + 
args.compute_virials = False + args.compute_stress = False else: dipole_only = False + targets_only = False if args.model == "EnergyDipolesMACE": args.compute_dipole = True args.compute_energy = True @@ -402,6 +431,7 @@ def run(args: argparse.Namespace) -> None: # [atomic_energies_dict[z] for z in z_table.zs] # ) atomic_energies = dict_to_array(atomic_energies_dict, heads) + #atomic_targets = None for head_config in head_configs: try: logging.info(f"Atomic Energies used (z: eV) for head {head_config.head_name}: " + "{" + ", ".join([f"{z}: {atomic_energies_dict[head_config.head_name][z]}" for z in head_config.z_table.zs]) + "}") @@ -498,7 +528,7 @@ def run(args: argparse.Namespace) -> None: generator=torch.Generator().manual_seed(args.seed), ) - loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) + loss_fn = get_loss_fn(args, dipole_only, targets_only, args.compute_dipole) args.avg_num_neighbors = get_avg_num_neighbors(head_configs, args, train_loader, device) # Model diff --git a/mace/data/__init__.py b/mace/data/__init__.py index c10a36982..7f049180e 100644 --- a/mace/data/__init__.py +++ b/mace/data/__init__.py @@ -13,6 +13,7 @@ save_configurations_as_HDF5, save_dataset_as_HDF5, test_config_types, + compute_average_atomic_targets, ) __all__ = [ @@ -31,4 +32,5 @@ "dataset_from_sharded_hdf5", "save_AtomicData_to_HDF5", "save_configurations_as_HDF5", + "compute_average_atomic_targets", ] diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index cb4edd94e..9192d73e8 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -37,6 +37,7 @@ class AtomicData(torch_geometric.data.Data): virials: torch.Tensor dipole: torch.Tensor charges: torch.Tensor + atomic_targets: torch.Tensor weight: torch.Tensor energy_weight: torch.Tensor forces_weight: torch.Tensor @@ -63,6 +64,7 @@ def __init__( virials: Optional[torch.Tensor], # [1,3,3] dipole: Optional[torch.Tensor], # [, 3] charges: Optional[torch.Tensor], # [n_nodes, ] + atomic_targets: 
Optional[torch.Tensor], # [n_nodes, ] ): # Check shapes num_nodes = node_attrs.shape[0] @@ -85,6 +87,7 @@ def __init__( assert virials is None or virials.shape == (1, 3, 3) assert dipole is None or dipole.shape[-1] == 3 assert charges is None or charges.shape == (num_nodes,) + assert atomic_targets is None or atomic_targets.shape == (num_nodes,) # Aggregate data data = { "num_nodes": num_nodes, @@ -106,6 +109,7 @@ def __init__( "virials": virials, "dipole": dipole, "charges": charges, + "atomic_targets": atomic_targets, } super().__init__(**data) @@ -204,6 +208,11 @@ def from_config( if config.charges is not None else None ) + atomic_targets = ( + torch.tensor(config.atomic_targets, dtype=torch.get_default_dtype()) + if config.atomic_targets is not None + else None + ) return cls( edge_index=torch.tensor(edge_index, dtype=torch.long), @@ -224,6 +233,7 @@ def from_config( virials=virials, dipole=dipole, charges=charges, + atomic_targets=atomic_targets, ) diff --git a/mace/data/utils.py b/mace/data/utils.py index bb8e54484..04edd9100 100644 --- a/mace/data/utils.py +++ b/mace/data/utils.py @@ -21,6 +21,7 @@ Stress = np.ndarray # [6, ], [3,3], [9, ] Virials = np.ndarray # [6, ], [3,3], [9, ] Charges = np.ndarray # [..., 1] +AtomicTargets = np.ndarray # [..., 1] Cell = np.ndarray # [3,3] Pbc = tuple # (3,) @@ -38,6 +39,7 @@ class Configuration: virials: Optional[Virials] = None # eV dipole: Optional[Vector] = None # Debye charges: Optional[Charges] = None # atomic unit + atomic_targets: Optional[AtomicTargets] = None cell: Optional[Cell] = None pbc: Optional[Pbc] = None @@ -92,6 +94,7 @@ def config_from_atoms_list( virials_key="REF_virials", dipole_key="REF_dipole", charges_key="REF_charges", + atomic_targets_key="REF_atomic_targets", head_key="head", config_type_weights: Optional[Dict[str, float]] = None, ) -> Configurations: @@ -110,6 +113,7 @@ def config_from_atoms_list( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + 
atomic_targets_key=atomic_targets_key, head_key=head_key, config_type_weights=config_type_weights, ) @@ -125,6 +129,7 @@ def config_from_atoms( virials_key="REF_virials", dipole_key="REF_dipole", charges_key="REF_charges", + atomic_targets_key="REF_atomic_targets", head_key="head", config_type_weights: Optional[Dict[str, float]] = None, ) -> Configuration: @@ -139,6 +144,8 @@ def config_from_atoms( dipole = atoms.info.get(dipole_key, None) # Debye # Charges default to 0 instead of None if not found charges = atoms.arrays.get(charges_key, np.zeros(len(atoms))) # atomic unit + # Atomic targets default to 0 instead of None if not found + atomic_targets = atoms.arrays.get(atomic_targets_key, np.zeros(len(atoms))) atomic_numbers = np.array( [ase.data.atomic_numbers[symbol] for symbol in atoms.symbols] ) @@ -181,6 +188,7 @@ def config_from_atoms( virials=virials, dipole=dipole, charges=charges, + atomic_targets=atomic_targets, weight=weight, head=head, energy_weight=energy_weight, @@ -219,6 +227,7 @@ def load_from_xyz( virials_key: str = "REF_virials", dipole_key: str = "REF_dipole", charges_key: str = "REF_charges", + atomic_targets_key: str = "REF_atomic_targets", head_key: str = "head", head_name: str = "Default", extract_atomic_energies: bool = False, @@ -297,6 +306,7 @@ def load_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_targets_key=atomic_targets_key, head_key=head_key, ) return atomic_energies_dict, configs @@ -331,6 +341,32 @@ def compute_average_E0s( atomic_energies_dict[z] = 0.0 return atomic_energies_dict +def compute_average_atomic_targets( + collections_train: Configurations, z_table: AtomicNumberTable, +) -> Tuple[Dict[int, float], float]: + """ + Function to compute the average node target and node std of each chemical element + returns a dictionary with averages and a float scale + """ + len_train = len(collections_train) + len_zs = len(z_table) + elementwise_targets = {} + for i in range(len_train): + for 
j in range(len(collections_train[i].atomic_numbers)): + z = collections_train[i].atomic_numbers[j] + if z not in elementwise_targets.keys(): + elementwise_targets[z] = [] + elementwise_targets[z].append(collections_train[i].atomic_targets[j]) + + + atomic_energies_dict = {} + atomic_scales = [] + for i, z in enumerate(z_table.zs): + atomic_energies_dict[z] = np.mean(elementwise_targets[z]) + atomic_scales.append((len(elementwise_targets[z]), np.std(elementwise_targets[z]))) + # compute weighted average of scales with tuple element 0 being the weight and element 1 the value to average + scale = np.average([x[1] for x in atomic_scales], weights=[x[0] for x in atomic_scales]) + return atomic_energies_dict #, scale def save_dataset_as_HDF5(dataset: List, out_name: str) -> None: with h5py.File(out_name, "w") as f: diff --git a/mace/modules/__init__.py b/mace/modules/__init__.py index 9278130fd..7aca0bed1 100644 --- a/mace/modules/__init__.py +++ b/mace/modules/__init__.py @@ -29,10 +29,12 @@ WeightedEnergyForcesVirialsLoss, WeightedForcesLoss, WeightedHuberEnergyForcesStressLoss, + AtomicTargetsLoss ) from .models import ( MACE, AtomicDipolesMACE, + AtomicTargetsMACE, BOTNet, EnergyDipolesMACE, ScaleShiftBOTNet, @@ -45,6 +47,7 @@ compute_fixed_charge_dipole, compute_mean_rms_energy_forces, compute_mean_std_atomic_inter_energy, + compute_mean_std_atomic_targets, compute_rms_dipoles, compute_statistics, ) @@ -62,6 +65,7 @@ "std_scaling": compute_mean_std_atomic_inter_energy, "rms_forces_scaling": compute_mean_rms_energy_forces, "rms_dipoles_scaling": compute_rms_dipoles, + "atomic_targets_std_scaling": compute_mean_std_atomic_targets, } gate_dict: Dict[str, Optional[Callable]] = { @@ -91,6 +95,7 @@ "BOTNet", "ScaleShiftBOTNet", "AtomicDipolesMACE", + "AtomicTargetsMACE", "EnergyDipolesMACE", "WeightedEnergyForcesLoss", "WeightedForcesLoss", @@ -100,9 +105,11 @@ "WeightedEnergyForcesDipoleLoss", "WeightedHuberEnergyForcesStressLoss", "UniversalLoss", + "AtomicTargetsLoss", + 
"SymmetricContraction", "interaction_classes", "compute_mean_std_atomic_inter_energy", + "compute_mean_std_atomic_targets", "compute_avg_num_neighbors", "compute_statistics", "compute_fixed_charge_dipole", diff --git a/mace/modules/loss.py b/mace/modules/loss.py index a7e28c550..80002a0de 100644 --- a/mace/modules/loss.py +++ b/mace/modules/loss.py @@ -381,3 +381,15 @@ def __repr__(self): f"{self.__class__.__name__}(energy_weight={self.energy_weight:.3f}, " f"forces_weight={self.forces_weight:.3f}, dipole_weight={self.dipole_weight:.3f})" ) + +class AtomicTargetsLoss(torch.nn.Module): + def __init__(self, huber_delta=0.01) -> None: + super().__init__() + self.huber_loss = torch.nn.HuberLoss(reduction="mean", delta=huber_delta) + + def forward(self, ref: Batch, pred: TensorDict) -> torch.Tensor: + return self.huber_loss(ref["atomic_targets"], pred["atomic_targets"])*1e3 + + def __repr__(self): + return f"{self.__class__.__name__}(huber_delta={self.huber_loss.delta:.3f})" + diff --git a/mace/modules/models.py b/mace/modules/models.py index c0d8ab430..726ec8d1d 100644 --- a/mace/modules/models.py +++ b/mace/modules/models.py @@ -453,6 +453,90 @@ def forward( return output +@compile_mode("script") +class AtomicTargetsMACE(MACE): + def __init__( + self, + atomic_inter_scale: float, + atomic_inter_shift: float, + **kwargs, + ): + super().__init__(**kwargs) + self.scale_shift = ScaleShiftBlock( + scale=atomic_inter_scale, shift=atomic_inter_shift + ) + + def forward( + self, + data: Dict[str, torch.Tensor], + training: bool = False, + compute_force: bool = True, + compute_virials: bool = False, + compute_stress: bool = False, + compute_displacement: bool = False, + compute_hessian: bool = False, + ) -> Dict[str, Optional[torch.Tensor]]: + # Setup + data["positions"].requires_grad_(True) + data["node_attrs"].requires_grad_(True) + num_graphs = data["ptr"].numel() - 1 + num_atoms_arange = torch.arange(data["positions"].shape[0]) + node_heads = ( + 
data["head"][data["batch"]] + if "head" in data + else torch.zeros_like(data["batch"]) + ) + + + # Embeddings + node_feats = self.node_embedding(data["node_attrs"]) + vectors, lengths = get_edge_vectors_and_lengths( + positions=data["positions"], + edge_index=data["edge_index"], + shifts=data["shifts"], + ) + edge_attrs = self.spherical_harmonics(vectors) + edge_feats = self.radial_embedding( + lengths, data["node_attrs"], data["edge_index"], self.atomic_numbers + ) + + # Interactions + node_es_list = [] + node_feats_list = [] + for interaction, product, readout in zip( + self.interactions, self.products, self.readouts + ): + node_feats, sc = interaction( + node_attrs=data["node_attrs"], + node_feats=node_feats, + edge_attrs=edge_attrs, + edge_feats=edge_feats, + edge_index=data["edge_index"], + ) + node_feats = product( + node_feats=node_feats, sc=sc, node_attrs=data["node_attrs"] + ) + node_feats_list.append(node_feats) + node_es_list.append( + readout(node_feats, node_heads)[num_atoms_arange, node_heads] + ) # {[n_nodes, ], } + + # Concatenate node features + node_feats_out = torch.cat(node_feats_list, dim=-1) + # Sum over interactions + node_inter_es = torch.sum( + torch.stack(node_es_list, dim=0), dim=0 + ) # [n_nodes, ] + node_inter_es = self.scale_shift(node_inter_es, node_heads) + + # Add E_0 and (scaled) interaction energy + node_energy = node_inter_es + output = { + "atomic_targets": node_energy, + "node_feats": node_feats_out, + } + + return output class BOTNet(torch.nn.Module): def __init__( diff --git a/mace/modules/utils.py b/mace/modules/utils.py index d0a1e5f67..2214052b3 100644 --- a/mace/modules/utils.py +++ b/mace/modules/utils.py @@ -250,6 +250,29 @@ def extract_invariant(x: torch.Tensor, num_layers: int, num_features: int, l_max out.append(x[:, -num_features:]) return torch.cat(out, dim=-1) +def compute_mean_std_atomic_targets( + data_loader: torch.utils.data.DataLoader, + z_table +) -> Tuple[float, float]: + + elementwise_targets = {} + for 
batch in data_loader: + for i, z in enumerate(z_table.zs): + z_targets = batch.atomic_targets[batch.node_attrs.bool()[:,i]] + if z not in elementwise_targets.keys(): + elementwise_targets[z] = [] + elementwise_targets[z].append(z_targets) + + atomic_energies_dict = {} + atomic_scales = [] + for z in elementwise_targets.keys(): + elementwise_targets[z] = torch.cat(elementwise_targets[z], dim=0) + atomic_energies_dict[z] = torch.mean(elementwise_targets[z]) + atomic_scales.append((len(elementwise_targets[z]), torch.std(elementwise_targets[z]))) + # compute weighted average of scales with tuple element 0 being the weight and element 1 the value to average + scale = np.average([x[1] for x in atomic_scales], weights=[x[0] for x in atomic_scales]) + + return atomic_energies_dict, scale def compute_mean_std_atomic_inter_energy( data_loader: torch.utils.data.DataLoader, diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py index 11a6d2f30..ba70a9398 100644 --- a/mace/tools/arg_parser.py +++ b/mace/tools/arg_parser.py @@ -83,6 +83,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: help="Type of error table produced at the end of the training", type=str, choices=[ + "AtomicTargetsPerAtomRMSE", "PerAtomRMSE", "TotalRMSE", "PerAtomRMSEstressvirials", @@ -108,6 +109,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "ScaleShiftBOTNet", "AtomicDipolesMACE", "EnergyDipolesMACE", + "AtomicTargetsMACE", ], ) parser.add_argument( @@ -427,6 +429,12 @@ def build_default_arg_parser() -> argparse.ArgumentParser: type=str, default="REF_charges", ) + parser.add_argument( + "--atomic_targets_key", + help="Key of atomic targets in training xyz", + type=str, + default="REF_atomic_targets", + ) # Loss and optimization parser.add_argument( @@ -443,6 +451,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser: "huber", "universal", "energy_forces_dipole", + "atomic_targets", ], ) parser.add_argument( @@ -792,6 +801,12 @@ def 
build_preprocess_arg_parser() -> argparse.ArgumentParser: type=str, default="REF_charges", ) + parser.add_argument( + "--atomic_targets_key", + help="Key of atomic targets in training xyz", + type=str, + default="REF_atomic_targets", + ) parser.add_argument( "--atomic_numbers", help="List of atomic numbers", diff --git a/mace/tools/model_script_utils.py b/mace/tools/model_script_utils.py index 8e8c28770..2029ef19d 100644 --- a/mace/tools/model_script_utils.py +++ b/mace/tools/model_script_utils.py @@ -33,10 +33,16 @@ def configure_model( if args.scaling == "no_scaling": args.std = 1.0 logging.info("No scaling selected") - elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE": + elif (args.mean is None or args.std is None) and args.model != "AtomicDipolesMACE" and args.model != "AtomicTargetsMACE": args.mean, args.std = modules.scaling_classes[args.scaling]( train_loader, atomic_energies ) + elif args.model == "AtomicTargetsMACE": + _, args.std = modules.scaling_classes[args.scaling]( + train_loader, z_table + ) + args.mean = 0.0 + atomic_energies = 0.0 # Build model if model_foundation is not None and args.model in ["MACE", "ScaleShiftMACE"]: @@ -178,6 +184,21 @@ def _build_model( radial_type=args.radial_type, heads=heads, ) + if args.model == "AtomicTargetsMACE": + return modules.AtomicTargetsMACE( + **model_config, + pair_repulsion=args.pair_repulsion, + distance_transform=args.distance_transform, + correlation=args.correlation, + gate=modules.gate_dict[args.gate], + interaction_cls_first=modules.interaction_classes[args.interaction_first], + MLP_irreps=o3.Irreps(args.MLP_irreps), + atomic_inter_scale=args.std, + atomic_inter_shift=args.mean, + radial_MLP=ast.literal_eval(args.radial_MLP), + radial_type=args.radial_type, + heads=heads, + ) if args.model == "FoundationMACE": return modules.ScaleShiftMACE(**model_config_foundation) if args.model == "ScaleShiftBOTNet": diff --git a/mace/tools/multihead_tools.py 
b/mace/tools/multihead_tools.py index ffde107ff..3a7dfee6f 100644 --- a/mace/tools/multihead_tools.py +++ b/mace/tools/multihead_tools.py @@ -32,6 +32,7 @@ class HeadConfig: virials_key: Optional[str] = None dipole_key: Optional[str] = None charges_key: Optional[str] = None + atomic_targets_key: Optional[str] = None keep_isolated_atoms: Optional[bool] = None atomic_numbers: Optional[Union[List[int], List[str]]] = None mean: Optional[float] = None @@ -42,6 +43,7 @@ class HeadConfig: train_loader: torch.utils.data.DataLoader = None z_table: Optional[Any] = None atomic_energies_dict: Optional[Dict[str, float]] = None + atomic_targets_dict: Optional[Dict[str, float]] = None def dict_head_to_dataclass( @@ -71,6 +73,7 @@ def dict_head_to_dataclass( virials_key=head.get("virials_key", args.virials_key), dipole_key=head.get("dipole_key", args.dipole_key), charges_key=head.get("charges_key", args.charges_key), + atomic_targets_key=head.get("atomic_targets_key", args.atomic_targets_key), keep_isolated_atoms=head.get("keep_isolated_atoms", args.keep_isolated_atoms), ) @@ -92,6 +95,7 @@ def prepare_default_head(args: argparse.Namespace) -> Dict[str, Any]: "virials_key": args.virials_key, "dipole_key": args.dipole_key, "charges_key": args.charges_key, + "atomic_targets_key": args.atomic_targets_key, "keep_isolated_atoms": args.keep_isolated_atoms, } } @@ -176,6 +180,7 @@ def assemble_mp_data( virials_key=args.virials_key, dipole_key=args.dipole_key, charges_key=args.charges_key, + atomic_targets_key=args.atomic_targets_key, keep_isolated_atoms=args.keep_isolated_atoms, ) return collections_mp diff --git a/mace/tools/scripts_utils.py b/mace/tools/scripts_utils.py index ac9d09fb5..2e2608ae6 100644 --- a/mace/tools/scripts_utils.py +++ b/mace/tools/scripts_utils.py @@ -48,6 +48,7 @@ def get_dataset_from_xyz( virials_key: str = "virials", dipole_key: str = "dipoles", charges_key: str = "charges", + atomic_targets_key: str = "atomic_targets", head_key: str = "head", ) -> 
Tuple[SubsetCollection, Optional[Dict[int, float]]]: """Load training and test dataset from xyz file""" @@ -60,6 +61,7 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_targets_key=atomic_targets_key, head_key=head_key, extract_atomic_energies=True, keep_isolated_atoms=keep_isolated_atoms, @@ -78,6 +80,7 @@ def get_dataset_from_xyz( virials_key=virials_key, dipole_key=dipole_key, charges_key=charges_key, + atomic_targets_key=atomic_targets_key, head_key=head_key, extract_atomic_energies=False, head_name=head_name, @@ -105,6 +108,7 @@ def get_dataset_from_xyz( stress_key=stress_key, virials_key=virials_key, charges_key=charges_key, + atomic_targets_key=atomic_targets_key, head_key=head_key, extract_atomic_energies=False, head_name=head_name, @@ -348,6 +352,18 @@ def get_atomic_energies(E0s, train_collection, z_table) -> dict: ) return atomic_energies_dict +def get_atomic_targets(train_collection, z_table) -> Tuple[Dict, float]: + # catch if collections.train not defined above + try: + assert train_collection is not None + atomic_targets_dict = data.compute_average_atomic_targets( + train_collection, z_table + ) + except Exception as e: + raise RuntimeError( + f"Could not compute average atomic targets if no training xyz given, error {e} occured" + ) from e + return atomic_targets_dict def get_avg_num_neighbors(head_configs, args, train_loader, device): if all(head_config.compute_avg_num_neighbors for head_config in head_configs): @@ -384,6 +400,7 @@ def get_avg_num_neighbors(head_configs, args, train_loader, device): def get_loss_fn( args: argparse.Namespace, dipole_only: bool, + targets_only: bool, compute_dipole: bool, ) -> torch.nn.Module: if args.loss == "weighted": @@ -432,6 +449,11 @@ def get_loss_fn( forces_weight=args.forces_weight, dipole_weight=args.dipole_weight, ) + elif args.loss == "atomic_targets": + assert dipole_only is False and targets_only is True + loss_fn = modules.AtomicTargetsLoss( + 
huber_delta=args.huber_delta, + ) else: loss_fn = modules.WeightedEnergyForcesLoss(energy_weight=1.0, forces_weight=1.0) return loss_fn @@ -702,6 +724,12 @@ def create_error_table( "RMSE F / meV / A", "relative F RMSE %", ] + elif table_type == "AtomicTargetsPerAtomRMSE": + table.field_names = [ + "config_type", + "RMSE Target / atom", + "relative Target RMSE %", + ] elif table_type == "PerAtomRMSEstressvirials": table.field_names = [ "config_type", @@ -796,6 +824,14 @@ def create_error_table( f"{metrics['rel_rmse_f']:8.2f}", ] ) + elif table_type == "AtomicTargetsPerAtomRMSE": + table.add_row( + [ + name, + f"{metrics['rmse_atomic_target_per_atom'] * 1000:8.1f}", + f"{metrics['rel_rmse_atomic_target']:8.2f}", + ] + ) elif ( table_type == "PerAtomRMSEstressvirials" and metrics["rmse_stress"] is not None diff --git a/mace/tools/train.py b/mace/tools/train.py index 8e293beee..550e553d1 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -62,6 +62,11 @@ def valid_err_log( logging.info( f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_E_per_atom={error_e:8.1f} meV, RMSE_F={error_f:8.1f} meV / A" ) + elif log_errors == "AtomicTargetsPerAtomRMSE": + error_atom_targets = eval_metrics["rmse_atomic_target_per_atom"] * 1e3 + logging.info( + f"{inintial_phrase}: head: {valid_loader_name}, loss={valid_loss:8.4f}, RMSE_atomic_target_per_atom={error_atom_targets:8.1f}" + ) elif ( log_errors == "PerAtomRMSEstressvirials" and eval_metrics["rmse_stress"] is not None @@ -187,6 +192,7 @@ def train( output_args=output_args, device=device, ) + print(eval_metrics) valid_err_log( valid_loss_head, eval_metrics, logger, log_errors, None, valid_loader_name ) @@ -452,6 +458,10 @@ def __init__(self, loss_fn: torch.nn.Module): self.add_state("mus", default=[], dist_reduce_fx="cat") self.add_state("delta_mus", default=[], dist_reduce_fx="cat") self.add_state("delta_mus_per_atom", default=[], dist_reduce_fx="cat") + self.add_state("AtomicTargets_computed", 
default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("atomic_targets", default=[], dist_reduce_fx="cat") + self.add_state("delta_atomic_targets", default=[], dist_reduce_fx="cat") + self.add_state("delta_atomic_targets_per_atom", default=[], dist_reduce_fx="cat") def update(self, batch, output): # pylint: disable=arguments-differ loss = self.loss_fn(pred=output, ref=batch) @@ -486,6 +496,17 @@ def update(self, batch, output): # pylint: disable=arguments-differ (batch.dipole - output["dipole"]) / (batch.ptr[1:] - batch.ptr[:-1]).unsqueeze(-1) ) + if output.get("atomic_targets") is not None and batch.atomic_targets is not None: + self.AtomicTargets_computed += 1.0 + self.atomic_targets.append(batch.atomic_targets) + self.delta_atomic_targets.append(batch.atomic_targets - output["atomic_targets"]) + + for i in range(len(batch.ptr)-1): + config_delta = batch.atomic_targets[batch.ptr[i]:batch.ptr[i+1]] - output["atomic_targets"][batch.ptr[i]:batch.ptr[i+1]] + num_atoms = batch.ptr[i+1] - batch.ptr[i] + self.delta_atomic_targets_per_atom.append( + config_delta / num_atoms + ) def convert(self, delta: Union[torch.Tensor, List[torch.Tensor]]) -> np.ndarray: if isinstance(delta, list): @@ -534,5 +555,15 @@ def compute(self): aux["rmse_mu_per_atom"] = compute_rmse(delta_mus_per_atom) aux["rel_rmse_mu"] = compute_rel_rmse(delta_mus, mus) aux["q95_mu"] = compute_q95(delta_mus) + if self.AtomicTargets_computed: + atomic_targets = self.convert(self.atomic_targets) + delta_atomic_targets = self.convert(self.delta_atomic_targets) + delta_atomic_targets_per_atom = self.convert(self.delta_atomic_targets_per_atom) + aux["mae_atomic_target"] = compute_mae(delta_atomic_targets) + aux["mae_atomic_target_per_atom"] = compute_mae(delta_atomic_targets_per_atom) + aux["rmse_atomic_target"] = compute_rmse(delta_atomic_targets) + aux["rmse_atomic_target_per_atom"] = compute_rmse(delta_atomic_targets_per_atom) + aux["rel_rmse_atomic_target"] = 
compute_rel_rmse(delta_atomic_targets, atomic_targets) + aux["q95_atomic_target"] = compute_q95(delta_atomic_targets) return aux["loss"], aux