Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
from ase.calculators.calculator import Calculator, all_changes
from ase.stress import full_3x3_to_voigt_6_stress
from e3nn import o3
from functools import partial

from mace import data as mace_data
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
from mace.tools.compile import prepare
from mace.tools.scripts_utils import extract_model
from mace.data.atomic_data import atomicdata_collate

try:
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
Expand Down Expand Up @@ -383,19 +385,22 @@ def _atoms_to_batch(self, atoms):
config = mace_data.config_from_atoms(
atoms, key_specification=keyspec, head_name=self.head
)

collate = partial(
atomicdata_collate,
z_table=self.z_table,
cutoff=self.r_max,
heads=self.available_heads,
)

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
mace_data.AtomicData.from_config(
config,
z_table=self.z_table,
cutoff=self.r_max,
heads=self.available_heads,
)
],
dataset=[config], # pass the Configuration, not AtomicData
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=collate,
)

batch = next(iter(data_loader)).to(self.device)
return batch

Expand Down
18 changes: 10 additions & 8 deletions mace/cli/eval_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import numpy as np
import torch
from e3nn import o3
from functools import partial


from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils

from mace.data.atomic_data import atomicdata_collate

def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -165,15 +167,16 @@ def run(args: argparse.Namespace) -> None:
heads = None

data_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
config, z_table=z_table, cutoff=float(model.r_max), heads=heads
)
for config in configs
],
dataset=configs, # list of Configuration
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
collate_fn=partial(
atomicdata_collate,
z_table=z_table,
cutoff=float(model.r_max),
heads=heads,
),
)

# Collect data
Expand Down Expand Up @@ -328,6 +331,5 @@ def run(args: argparse.Namespace) -> None:
# Write atoms to output path
ase.io.write(args.output, images=atoms_list, format="extxyz")


if __name__ == "__main__":
main()
27 changes: 27 additions & 0 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch.optim import LBFGS
from torch.utils.data import ConcatDataset
from torch_ema import ExponentialMovingAverage
from functools import partial

import mace
from mace import data, tools
Expand All @@ -34,6 +35,7 @@
from mace.cli.convert_oeq_e3nn import run as run_oeq_to_e3nn
from mace.cli.visualise_train import TrainingPlotter
from mace.data import KeySpecification, update_keyspec_from_kwargs
from mace.data.atomic_data import atomicdata_collate
from mace.tools import torch_geometric
from mace.tools.distributed_tools import init_distributed
from mace.tools.model_script_utils import configure_model
Expand Down Expand Up @@ -636,6 +638,8 @@ def run(args) -> None:
dataset_size = len(train_sets[head_config.head_name])
logging.info(f"Head '{head_config.head_name}' training dataset size: {dataset_size}")

collate = partial(atomicdata_collate, z_table=z_table, cutoff=args.r_max, heads=heads)

train_loader_head = torch_geometric.dataloader.DataLoader(
dataset=train_sets[head_config.head_name],
batch_size=args.batch_size,
Expand All @@ -644,6 +648,7 @@ def run(args) -> None:
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
collate_fn=collate
)
head_config.train_loader = train_loader_head

Expand Down Expand Up @@ -671,6 +676,9 @@ def run(args) -> None:
)
valid_samplers[head] = valid_sampler


collate = partial(atomicdata_collate, z_table=z_table, cutoff=args.r_max, heads=heads)

train_loader = torch_geometric.dataloader.DataLoader(
dataset=train_set,
batch_size=args.batch_size,
Expand All @@ -680,12 +688,21 @@ def run(args) -> None:
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
collate_fn=collate
)

valid_loaders = {heads[i]: None for i in range(len(heads))}
if not isinstance(valid_sets, dict):
valid_sets = {"Default": valid_sets}
for head, valid_set in valid_sets.items():

collate_valid = partial(
atomicdata_collate,
z_table=z_table,
cutoff=args.r_max,
heads=heads,
)

valid_loaders[head] = torch_geometric.dataloader.DataLoader(
dataset=valid_set,
batch_size=args.valid_batch_size,
Expand All @@ -695,6 +712,7 @@ def run(args) -> None:
pin_memory=args.pin_memory,
num_workers=args.num_workers,
generator=torch.Generator().manual_seed(args.seed),
collate_fn=collate_valid,
)

loss_fn = get_loss_fn(args, dipole_only, args.compute_dipole)
Expand Down Expand Up @@ -942,13 +960,22 @@ def run(args) -> None:
drop_last = test_set.drop_last
except AttributeError as e: # pylint: disable=W0612
drop_last = False

collate_test = partial(
atomicdata_collate,
z_table=z_table,
cutoff=args.r_max,
heads=heads,
)

test_loader = torch_geometric.dataloader.DataLoader(
test_set,
batch_size=args.valid_batch_size,
shuffle=(test_sampler is None),
drop_last=drop_last,
num_workers=args.num_workers,
pin_memory=args.pin_memory,
collate_fn=collate_test,
)
test_data_loader[test_name] = test_loader
if stop_first_test:
Expand Down
73 changes: 55 additions & 18 deletions mace/data/atomic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
voigt_to_matrix,
)

from .neighborhood import get_neighborhood
from ..tools.torch_geometric import Batch
from .neighborhood import get_neighborhood, get_neighborhood_batched
from .utils import Configuration


Expand Down Expand Up @@ -153,12 +154,21 @@ def from_config(
) -> "AtomicData":
if heads is None:
heads = ["Default"]
edge_index, shifts, unit_shifts, cell = get_neighborhood(
positions=config.positions,
cutoff=cutoff,
pbc=deepcopy(config.pbc),
cell=deepcopy(config.cell),
)

edge_index = kwargs.pop("edge_index", None)
shifts = kwargs.pop("shifts", None)
unit_shifts = kwargs.pop("unit_shifts", None)

if edge_index is None or shifts is None or unit_shifts is None:
edge_index, shifts, unit_shifts, cell = get_neighborhood(
positions=config.positions,
cutoff=cutoff,
pbc=deepcopy(config.pbc),
cell=deepcopy(config.cell),
)
else:
cell = deepcopy(config.cell)

indices = atomic_numbers_to_indices(config.atomic_numbers, z_table=z_table)
one_hot = to_one_hot(
torch.tensor(indices, dtype=torch.long).unsqueeze(-1),
Expand Down Expand Up @@ -395,15 +405,42 @@ def from_config(
return cls(**cls_kwargs)


def get_data_loader(
dataset: Sequence[AtomicData],
batch_size: int,
shuffle=True,
drop_last=False,
) -> torch.utils.data.DataLoader:
return torch_geometric.dataloader.DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
# Adding a collate function

def atomicdata_collate(batch, z_table, cutoff, heads=None):
    """Collate a list of dataset samples into a single ``Batch``.

    Samples that are already ``AtomicData`` (e.g. read back from
    HDF5/LMDB datasets) are batched directly. Otherwise the samples are
    assumed to be ``Configuration`` objects: their neighbor lists are
    computed here (batched, at collate time) before conversion to
    ``AtomicData``.
    """
    # Pre-built graph samples need no neighbor-list computation.
    if isinstance(batch[0], AtomicData):
        return Batch.from_data_list(batch)

    # Treat the samples as Configuration objects and compute all
    # neighbor lists in one batched call.
    neighborhoods = get_neighborhood_batched(
        positions_list=[cfg.positions for cfg in batch],
        cutoff=cutoff,
        pbc_list=[cfg.pbc for cfg in batch],
        cell_list=[cfg.cell for cfg in batch],
        true_self_interaction=False,
    )

    data_list = [
        AtomicData.from_config(
            config=cfg,
            z_table=z_table,
            cutoff=cutoff,
            heads=heads,
            # Pre-computed graph data, consumed by from_config so it can
            # skip its own get_neighborhood call.
            edge_index=edge_index,
            shifts=shifts,
            unit_shifts=unit_shifts,
            # NOTE(review): from_config also derives `cell` itself when
            # neighbor data is absent — confirm this kwarg is consumed
            # rather than duplicated downstream.
            cell=cell,
        )
        for cfg, edge_index, shifts, unit_shifts, cell in zip(batch, *neighborhoods)
    ]
    return Batch.from_data_list(data_list)
41 changes: 40 additions & 1 deletion mace/data/neighborhood.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,48 @@
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Sequence

import numpy as np
from matscipy.neighbours import neighbour_list

def get_neighborhood_batched(
    positions_list: Sequence[np.ndarray],  # list of [num_positions_i, 3]
    cutoff: float,
    pbc_list: Optional[Sequence[Optional[Tuple[bool, bool, bool]]]] = None,
    cell_list: Optional[Sequence[Optional[np.ndarray]]] = None,  # list of [3, 3]
    true_self_interaction: bool = False,
) -> Tuple[
    List[np.ndarray],  # edge_index_list
    List[np.ndarray],  # shifts_list
    List[np.ndarray],  # unit_shifts_list
    List[np.ndarray],  # cell_list_out
]:
    """Compute neighbor lists for a batch of structures.

    For now this is a trivial batched implementation that loops over the
    structures and calls :func:`get_neighborhood` on each one.

    Args:
        positions_list: per-structure atomic positions, each ``[n_i, 3]``.
        cutoff: neighbor cutoff radius shared by all structures.
        pbc_list: per-structure periodic-boundary flags, or ``None`` to
            use ``get_neighborhood``'s default for every structure.
        cell_list: per-structure ``[3, 3]`` cells, or ``None`` likewise.
        true_self_interaction: forwarded to ``get_neighborhood``.

    Returns:
        Tuple of four parallel lists (one entry per structure):
        edge indices, shift vectors, unit shifts, and processed cells.

    Raises:
        ValueError: if ``pbc_list`` or ``cell_list`` is given with a
            length different from ``positions_list``.
    """
    num_structures = len(positions_list)
    if pbc_list is None:
        pbc_list = [None] * num_structures
    if cell_list is None:
        cell_list = [None] * num_structures
    # Guard explicitly: zip() below would otherwise silently truncate the
    # batch when a caller passes mismatched-length lists, dropping
    # structures without any error.
    if len(pbc_list) != num_structures or len(cell_list) != num_structures:
        raise ValueError(
            f"pbc_list (len {len(pbc_list)}) and cell_list (len {len(cell_list)}) "
            f"must match positions_list (len {num_structures})"
        )

    edge_index_list: List[np.ndarray] = []
    shifts_list: List[np.ndarray] = []
    unit_shifts_list: List[np.ndarray] = []
    cell_list_out: List[np.ndarray] = []

    for positions, pbc, cell in zip(positions_list, pbc_list, cell_list):
        edge_index, shifts, unit_shifts, cell_out = get_neighborhood(
            positions=positions,
            cutoff=cutoff,
            pbc=pbc,
            cell=cell,
            true_self_interaction=true_self_interaction,
        )
        edge_index_list.append(edge_index)
        shifts_list.append(shifts)
        unit_shifts_list.append(unit_shifts)
        cell_list_out.append(cell_out)

    return edge_index_list, shifts_list, unit_shifts_list, cell_list_out
def get_neighborhood(
positions: np.ndarray, # [num_positions, 3]
cutoff: float,
Expand Down
7 changes: 1 addition & 6 deletions mace/tools/run_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,7 @@ def load_dataset_for_path(
assert (
collection is not None
), "Collection must be provided for ASE readable files"
return [
data.AtomicData.from_config(
config, z_table=z_table, cutoff=r_max, heads=heads
)
for config in collection
]
return list(collection)

filepath = Path(file_path)
if filepath.is_dir():
Expand Down
7 changes: 6 additions & 1 deletion mace/tools/torch_geometric/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ def __init__(
shuffle: bool = False,
follow_batch: Optional[List[str]] = [None],
exclude_keys: Optional[List[str]] = [None],
collate_fn=None,
**kwargs,
):

if collate_fn is None:
collate_fn = Collater(follow_batch, exclude_keys)

if "collate_fn" in kwargs:
del kwargs["collate_fn"]

Expand All @@ -82,6 +87,6 @@ def __init__(
dataset,
batch_size,
shuffle,
collate_fn=Collater(follow_batch, exclude_keys),
collate_fn=collate_fn,
**kwargs,
)
Loading