From 6fd4ff81ed83ad43185d04188307b8f37ee0171c Mon Sep 17 00:00:00 2001 From: Venkat Kapil Date: Wed, 3 Dec 2025 11:19:39 +0000 Subject: [PATCH 1/6] added code skeleton --- mace/data/atomic_data.py | 64 +++++++++++++++++++++++++++------------ mace/data/neighborhood.py | 41 ++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 20 deletions(-) diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index fdd8ec0b8..a70bb7ece 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -17,7 +17,8 @@ voigt_to_matrix, ) -from .neighborhood import get_neighborhood +from ..tools.torch_geometric import Batch +from .neighborhood import get_neighborhood, get_neighborhood_batched from .utils import Configuration @@ -153,12 +154,21 @@ def from_config( ) -> "AtomicData": if heads is None: heads = ["Default"] - edge_index, shifts, unit_shifts, cell = get_neighborhood( - positions=config.positions, - cutoff=cutoff, - pbc=deepcopy(config.pbc), - cell=deepcopy(config.cell), - ) + + edge_index = kwargs.pop("edge_index", None) + shifts = kwargs.pop("shifts", None) + unit_shifts = kwargs.pop("unit_shifts", None) + + if edge_index is None or shifts is None or unit_shifts is None: + edge_index, shifts, unit_shifts, cell = get_neighborhood( + positions=config.positions, + cutoff=cutoff, + pbc=deepcopy(config.pbc), + cell=deepcopy(config.cell), + ) + else: + cell = deepcopy(config.cell) + indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table) one_hot = to_one_hot( torch.tensor(indices, dtype=torch.long).unsqueeze(-1), @@ -395,15 +405,31 @@ def from_config( return cls(**cls_kwargs) -def get_data_loader( - dataset: Sequence[AtomicData], - batch_size: int, - shuffle=True, - drop_last=False, -) -> torch.utils.data.DataLoader: - return torch_geometric.dataloader.DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - ) +# Adding a collate function + +def atomicdata_collate(configs, z_table, cutoff, mode, heads=None): + + batched_edge_index, batched_shifts, batched_unit_shifts, batched_cells = get_neighborhood_batched(configs, cutoff=cutoff) + data_list = [] + + for i, config in enumerate(configs): + # slice per-config pieces + edge_index_i = batched_edge_index[i] + shifts_i = batched_shifts[i] + unit_shifts_i = batched_unit_shifts[i] + cell_i = batched_cells[i] + + # helper that skips its own get_neighborhood + atomic_i = AtomicData.from_config_with_edges( + config=config, + z_table=z_table, + cutoff=cutoff, + heads=heads, + edge_index=edge_index_i, + shifts=shifts_i, + unit_shifts=unit_shifts_i, + cell=cell_i, + ) + data_list.append(atomic_i) + + return Batch.from_data_list(data_list) \ No newline at end of file diff --git a/mace/data/neighborhood.py b/mace/data/neighborhood.py index 03728969d..eb95f8026 100644 --- a/mace/data/neighborhood.py +++ b/mace/data/neighborhood.py @@ -1,9 +1,48 @@ -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Sequence import numpy as np from matscipy.neighbours import neighbour_list +def get_neighborhood_batched( + positions_list: Sequence[np.ndarray], # list of [num_positions_i, 3] + cutoff: float, + pbc_list: Optional[Sequence[Optional[Tuple[bool, bool, bool]]]] = None, + cell_list: Optional[Sequence[Optional[np.ndarray]]] = None, # list of [3, 3] + true_self_interaction: bool = False, +) -> Tuple[ + List[np.ndarray], # edge_index_list + List[np.ndarray], # shifts_list + List[np.ndarray], # unit_shifts_list + List[np.ndarray], # cell_list_out +]: + """ + For now: trivial batched version that just loops over structures and + calls get_neighborhood for each one. + """ + if pbc_list is None: + pbc_list = [None] * len(positions_list) + if cell_list is None: + cell_list = [None] * len(positions_list) + + edge_index_list: List[np.ndarray] = [] + shifts_list: List[np.ndarray] = [] + unit_shifts_list: List[np.ndarray] = [] + cell_list_out: List[np.ndarray] = [] + + for positions, pbc, cell in zip(positions_list, pbc_list, cell_list): + edge_index, shifts, unit_shifts, cell_out = get_neighborhood( + positions=positions, + cutoff=cutoff, + pbc=pbc, + cell=cell, + true_self_interaction=true_self_interaction, + ) + edge_index_list.append(edge_index) + shifts_list.append(shifts) + unit_shifts_list.append(unit_shifts) + cell_list_out.append(cell_out) + return edge_index_list, shifts_list, unit_shifts_list, cell_list_out def get_neighborhood( positions: np.ndarray, # [num_positions, 3] cutoff: float, From af78660fbb751166773fae6a9a691f9b494e6ed1 Mon Sep 17 00:00:00 2001 From: Venkat Kapil Date: Wed, 3 Dec 2025 14:15:38 +0000 Subject: [PATCH 2/6] implementation of atomicdata_collate in run_train.py --- mace/cli/run_train.py | 4 ++++ mace/data/atomic_data.py | 45 ++++++++++++++++++++++------------- mace/tools/run_train_utils.py | 7 +----- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 919a30b85..9243cd3c8 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -19,6 +19,7 @@ from torch.optim import LBFGS from torch.utils.data import ConcatDataset from torch_ema import ExponentialMovingAverage +from functools import partial import mace from mace import data, tools @@ -34,6 +35,7 @@ from mace.cli.convert_oeq_e3nn import run as run_oeq_to_e3nn from mace.cli.visualise_train import TrainingPlotter from mace.data import KeySpecification, update_keyspec_from_kwargs +from mace.data.atomic_data import atomicdata_collate from mace.tools import torch_geometric from mace.tools.distributed_tools import init_distributed from mace.tools.model_script_utils import configure_model @@ -636,6 +638,8 @@ def run(args) -> None: dataset_size = len(train_sets[head_config.head_name]) logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}") + collate = partial(atomicdata_collate, z_table=z_table, cutoff=args.r_max, heads=heads) + train_loader_head = torch_geometric.dataloader.DataLoader( dataset=train_sets[head_config.head_name], batch_size=args.batch_size, diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index a70bb7ece..a46f98d2c 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -407,28 +407,39 @@ def from_config( # Adding a collate function -def atomicdata_collate(configs, z_table, cutoff, mode, heads=None): +def atomicdata_collate(batch, z_table, cutoff, mode, heads=None): + # If they are already AtomicData (e.g. HDF5/LMDB), just batch them + if isinstance(batch[0], AtomicData): + return Batch.from_data_list(batch) + + # Otherwise assume they are Configuration objects + configs = batch + + positions_list = [cfg.positions for cfg in configs] + pbc_list = [cfg.pbc for cfg in configs] + cell_list = [cfg.cell for cfg in configs] + + edge_indices, shifts_list, unit_shifts_list, cells_list = get_neighborhood_batched( + positions_list=positions_list, + cutoff=cutoff, + pbc_list=pbc_list, + cell_list=cell_list, + true_self_interaction=False, + ) - batched_edge_index, batched_shifts, batched_unit_shifts, batched_cells = get_neighborhood_batched(configs, cutoff=cutoff) data_list = [] - - for i, config in enumerate(configs): - # slice per-config pieces - edge_index_i = batched_edge_index[i] - shifts_i = batched_shifts[i] - unit_shifts_i = batched_unit_shifts[i] - cell_i = batched_cells[i] - - # helper that skips its own get_neighborhood - atomic_i = AtomicData.from_config_with_edges( - config=config, + for cfg, edge_index, shifts, unit_shifts, cell in zip( + configs, edge_indices, shifts_list, unit_shifts_list, cells_list + ): + atomic_i = AtomicData.from_config( + config=cfg, z_table=z_table, cutoff=cutoff, heads=heads, - edge_index=edge_index_i, - shifts=shifts_i, - unit_shifts=unit_shifts_i, - cell=cell_i, + edge_index=edge_index, + shifts=shifts, + unit_shifts=unit_shifts, + cell=cell, ) data_list.append(atomic_i) diff --git a/mace/tools/run_train_utils.py b/mace/tools/run_train_utils.py index ce37e0edc..549a2ed41 100644 --- a/mace/tools/run_train_utils.py +++ b/mace/tools/run_train_utils.py @@ -67,12 +67,7 @@ def load_dataset_for_path( assert ( collection is not None ), "Collection must be provided for ASE readable files" - return [ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=r_max, heads=heads - ) - for config in collection - ] + return list(collection) filepath = Path(file_path) if filepath.is_dir(): From 608ffeda5ff63444fcfcc999c41bcfce6ddcdb20 Mon Sep 17 00:00:00 2001 From: Venkat Kapil Date: Wed, 3 Dec 2025 15:10:26 +0000 Subject: [PATCH 3/6] bug fixes --- mace/cli/run_train.py | 5 +++++ mace/data/atomic_data.py | 2 +- mace/tools/torch_geometric/dataloader.py | 7 ++++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 9243cd3c8..0d2d3ba61 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -648,6 +648,7 @@ def run(args) -> None: pin_memory=args.pin_memory, num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), + collate_fn=collate ) head_config.train_loader = train_loader_head @@ -675,6 +676,9 @@ def run(args) -> None: ) valid_samplers[head] = valid_sampler + + collate = partial(atomicdata_collate, z_table=z_table, cutoff=args.r_max, heads=heads) + train_loader = torch_geometric.dataloader.DataLoader( dataset=train_set, batch_size=args.batch_size, @@ -684,6 +688,7 @@ def run(args) -> None: pin_memory=args.pin_memory, num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), + collate_fn=collate ) valid_loaders = {heads[i]: None for i in range(len(heads))} diff --git a/mace/data/atomic_data.py b/mace/data/atomic_data.py index a46f98d2c..efbbd4755 100644 --- a/mace/data/atomic_data.py +++ b/mace/data/atomic_data.py @@ -407,7 +407,7 @@ def from_config( # Adding a collate function -def atomicdata_collate(batch, z_table, cutoff, mode, heads=None): +def atomicdata_collate(batch, z_table, cutoff, heads=None): # If they are already AtomicData (e.g. HDF5/LMDB), just batch them if isinstance(batch[0], AtomicData): return Batch.from_data_list(batch) diff --git a/mace/tools/torch_geometric/dataloader.py b/mace/tools/torch_geometric/dataloader.py index 396b7e728..e264cbe4c 100644 --- a/mace/tools/torch_geometric/dataloader.py +++ b/mace/tools/torch_geometric/dataloader.py @@ -69,8 +69,13 @@ def __init__( shuffle: bool = False, follow_batch: Optional[List[str]] = [None], exclude_keys: Optional[List[str]] = [None], + collate_fn=None, **kwargs, ): + + if collate_fn is None: + collate_fn = Collater(follow_batch, exclude_keys) + if "collate_fn" in kwargs: del kwargs["collate_fn"] @@ -82,6 +87,6 @@ def __init__( dataset, batch_size, shuffle, - collate_fn=Collater(follow_batch, exclude_keys), + collate_fn=collate_fn, **kwargs, ) From d8fb7266e64a69d9971916458d30588dfa0debb8 Mon Sep 17 00:00:00 2001 From: Venkat Kapil Date: Wed, 3 Dec 2025 17:48:07 +0000 Subject: [PATCH 4/6] batched calculator first attempt --- mace/calculators/mace.py | 20 ++++++++++++-------- mace/cli/run_train.py | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index 9fe83a38d..dc027e243 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -19,6 +19,7 @@ from ase.calculators.calculator import Calculator, all_changes from ase.stress import full_3x3_to_voigt_6_stress from e3nn import o3 +from functools import partial from mace import data as mace_data from mace.modules.utils import extract_invariant @@ -383,19 +384,22 @@ def _atoms_to_batch(self, atoms): config = mace_data.config_from_atoms( atoms, key_specification=keyspec, head_name=self.head ) + + collate = partial( + mace_data.atomicdata_collate, + z_table=self.z_table, + cutoff=self.r_max, + heads=self.available_heads, + ) + data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - mace_data.AtomicData.from_config( - config, - z_table=self.z_table, - cutoff=self.r_max, - heads=self.available_heads, - ) - ], + dataset=[config], # pass the Configuration, not AtomicData batch_size=1, shuffle=False, drop_last=False, + collate_fn=collate, ) + batch = next(iter(data_loader)).to(self.device) return batch diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py index 0d2d3ba61..737e59a18 100644 --- a/mace/cli/run_train.py +++ b/mace/cli/run_train.py @@ -695,6 +695,14 @@ def run(args) -> None: if not isinstance(valid_sets, dict): valid_sets = {"Default": valid_sets} for head, valid_set in valid_sets.items(): + + collate_valid = partial( + atomicdata_collate, + z_table=z_table, + cutoff=args.r_max, + heads=heads, + ) + valid_loaders[head] = torch_geometric.dataloader.DataLoader( dataset=valid_set, batch_size=args.valid_batch_size, @@ -704,6 +712,7 @@ def run(args) -> None: pin_memory=args.pin_memory, num_workers=args.num_workers, generator=torch.Generator().manual_seed(args.seed), + collate_fn=collate_valid, ) loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole) @@ -951,6 +960,14 @@ def run(args) -> None: drop_last = test_set.drop_last except AttributeError as e: # pylint: disable=W0612 drop_last = False + + collate_test = partial( + atomicdata_collate, + z_table=z_table, + cutoff=args.r_max, + heads=heads, + ) + test_loader = torch_geometric.dataloader.DataLoader( test_set, batch_size=args.valid_batch_size, @@ -958,6 +975,7 @@ def run(args) -> None: drop_last=drop_last, num_workers=args.num_workers, pin_memory=args.pin_memory, + collate_fn=collate_test, ) test_data_loader[test_name] = test_loader if stop_first_test: From 1d0710d22ec381815c4ef6c5453cb426a1498c5e Mon Sep 17 00:00:00 2001 From: Venkat Kapil Date: Wed, 3 Dec 2025 18:03:39 +0000 Subject: [PATCH 5/6] fixed import error --- mace/calculators/mace.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mace/calculators/mace.py b/mace/calculators/mace.py index dc027e243..4b9635640 100644 --- a/mace/calculators/mace.py +++ b/mace/calculators/mace.py @@ -26,6 +26,7 @@ from mace.tools import torch_geometric, torch_tools, utils from mace.tools.compile import prepare from mace.tools.scripts_utils import extract_model +from mace.data.atomic_data import atomicdata_collate try: from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq @@ -386,7 +387,7 @@ def _atoms_to_batch(self, atoms): ) collate = partial( - mace_data.atomicdata_collate, + atomicdata_collate, z_table=self.z_table, cutoff=self.r_max, heads=self.available_heads, From 042e536be602459bc835c014d503534e9f3f125e Mon Sep 17 00:00:00 2001 From: Venkat Kapil Date: Thu, 4 Dec 2025 10:36:00 +0000 Subject: [PATCH 6/6] added collate to eval_configs --- mace/cli/eval_configs.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mace/cli/eval_configs.py b/mace/cli/eval_configs.py index 50aeb3ac6..645531ba4 100644 --- a/mace/cli/eval_configs.py +++ b/mace/cli/eval_configs.py @@ -12,12 +12,14 @@ import numpy as np import torch from e3nn import o3 +from functools import partial + from mace import data from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq from mace.modules.utils import extract_invariant from mace.tools import torch_geometric, torch_tools, utils - +from mace.data.atomic_data import atomicdata_collate def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( @@ -165,15 +167,16 @@ def run(args: argparse.Namespace) -> None: heads = None data_loader = torch_geometric.dataloader.DataLoader( - dataset=[ - data.AtomicData.from_config( - config, z_table=z_table, cutoff=float(model.r_max), heads=heads - ) - for config in configs - ], + dataset=configs, # list of Configuration batch_size=args.batch_size, shuffle=False, drop_last=False, + collate_fn=partial( + atomicdata_collate, + z_table=z_table, + cutoff=float(model.r_max), + heads=heads, + ), ) # Collect data @@ -328,6 +331,5 @@ def run(args: argparse.Namespace) -> None: # Write atoms to output path ase.io.write(args.output, images=atoms_list, format="extxyz") - if __name__ == "__main__": main()