Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions .github/workflows/pre-commit.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,27 @@ on:
push:
branches: [main]

concurrency:
group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }}
cancel-in-progress: true

jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
- name: Check out repo
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
- name: Install requirements
run: |
pip install -U pip
pip install pylint
pip install -U black
pip install .[dev]
pip install wandb
pip install tqdm
- name: Run black
run: |
python -m black .
- uses: pre-commit/action@v3.0.0

- name: Set up uv
uses: astral-sh/setup-uv@v2

- name: Install mace-torch with extras
run: uv pip install .[dev,cueq,wandb] --system

- name: Run pre-commit
run: pre-commit run --all-files --show-diff-on-failure
65 changes: 65 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
name: unit tests
on:
pull_request:
push:
branches: [main]

concurrency:
group: ${{ github.workflow }}-pr-${{ github.event.pull_request.number }}
cancel-in-progress: true

jobs:
pytest-general:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
version:
- { python: "3.8" }
- { python: "3.9" }
- { python: "3.10" }
- { python: "3.11" }
- { python: "3.12" }
- { python: "3.13" }
runs-on: ${{ matrix.os }}

steps:
- name: checkout
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.version.python }}

- name: Set up uv
uses: astral-sh/setup-uv@v2

- name: Install mace-torch
run: uv pip install -e .[dev] --system

- name: Run general unit tests
run: |
pytest tests --ignore=tests/cli/test_cueq_oeq.py

pytest-cueq:
runs-on: ubuntu-latest
steps:
- name: Check out repo
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"

- name: Set up uv
uses: astral-sh/setup-uv@v2

- name: Install mace-torch with cueq
run: uv pip install -e .[dev,cueq] --system

- name: Run cueq-specific tests
run: |
pytest tests/cli/test_cueq_oeq.py -k TestCueq
pytest tests/test_calculator.py
57 changes: 0 additions & 57 deletions .github/workflows/unittest.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ exclude: &exclude_files >

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.5.0
rev: v5.0.0
hooks:
- id: mixed-line-ending
- id: trailing-whitespace
Expand All @@ -23,7 +23,7 @@ repos:
exclude: *exclude_files

- repo: https://github.com/pycqa/isort
rev: 5.13.2
rev: 6.0.1
hooks:
- id: isort
name: Sort imports
Expand Down
13 changes: 6 additions & 7 deletions mace/calculators/foundations_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
mace_mp_names = [None] + list(mace_mp_urls.keys())


def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:
def download_mace_mp_checkpoint(model: Union[str, Path, None] = None) -> str:
"""
Downloads or locates the MACE-MP checkpoint file.

Expand Down Expand Up @@ -90,7 +90,7 @@ def download_mace_mp_checkpoint(model: Union[str, Path] = None) -> str:


def mace_mp(
model: Union[str, Path] = None,
model: Union[str, Path, None] = None,
device: str = "",
default_dtype: str = "float32",
dispersion: bool = False,
Expand Down Expand Up @@ -168,11 +168,10 @@ def mace_mp(
) from exc

print("Using TorchDFTD3Calculator for D3 dispersion corrections")
dtype = torch.float32 if default_dtype == "float32" else torch.float64
d3_calc = TorchDFTD3Calculator(
device=device,
damping=damping,
dtype=dtype,
dtype=mace_calc.dtype,
xc=dispersion_xc,
cutoff=dispersion_cutoff,
**kwargs,
Expand All @@ -182,7 +181,7 @@ def mace_mp(


def mace_off(
model: Union[str, Path] = None,
model: Union[str, Path, None] = None,
device: str = "",
default_dtype: str = "float64",
return_raw_model: bool = False,
Expand Down Expand Up @@ -265,7 +264,7 @@ def mace_off(

def mace_anicc(
device: str = "cuda",
model_path: str = None,
model_path: Union[str, Path, None] = None,
return_raw_model: bool = False,
) -> MACECalculator:
"""
Expand Down Expand Up @@ -318,7 +317,7 @@ def report_progress(block_num, block_size, total_size):


def mace_omol(
model: Union[str, Path] = None,
model: Union[str, Path, None] = None,
device: str = "",
default_dtype: str = "float64",
return_raw_model: bool = False,
Expand Down
106 changes: 74 additions & 32 deletions mace/calculators/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,13 @@ def __init__(
print(
f"Default dtype {default_dtype} does not match model dtype {model_dtype}, converting models to {default_dtype}."
)
if default_dtype == "float64":
self.models = [model.double() for model in self.models]
elif default_dtype == "float32":
self.models = [model.float() for model in self.models]
torch_tools.set_default_dtype(default_dtype)
self.models = [
model.to(dtype=torch_tools.dtype_dict[default_dtype])
for model in self.models
]

self._dtype = torch_tools.dtype_dict[default_dtype]

if enable_cueq:
print("Converting models to CuEq for acceleration")
self.models = [
Expand All @@ -315,22 +317,41 @@ def __init__(
for param in model.parameters():
param.requires_grad = False

@property
def dtype(self):
return self._dtype

def to(self, **kwargs):
self._dtype = kwargs.get("dtype", self._dtype)
self.models = [model.to(**kwargs) for model in self.models]
return self

def _create_result_tensors(
self, model_type: str, num_models: int, num_atoms: int
) -> dict:
"""
Create tensors to store the results of the committee
:param model_type: str, type of model to load
Options: [MACE, DipoleMACE, EnergyDipoleMACE]
:param num_models: int, number of models in the committee
:return: tuple of torch tensors
) -> dict[str, torch.Tensor]:
"""Creates tensors to store the results of the committee.

Args:
model_type: Type of model to load. Must be one of:
- 'MACE'
- 'DipoleMACE'
- 'EnergyDipoleMACE'
num_models: Number of models in the committee.
num_atoms: Number of atoms in the system.

Returns:
dict: Dictionary containing initialized tensors for storing committee results.
"""
dict_of_tensors = {}
if model_type in ["MACE", "EnergyDipoleMACE"]:
energies = torch.zeros(num_models, device=self.device)
node_energy = torch.zeros(num_models, num_atoms, device=self.device)
forces = torch.zeros(num_models, num_atoms, 3, device=self.device)
stress = torch.zeros(num_models, 3, 3, device=self.device)
energies = torch.zeros(num_models, device=self.device, dtype=self.dtype)
node_energy = torch.zeros(
num_models, num_atoms, device=self.device, dtype=self.dtype
)
forces = torch.zeros(
num_models, num_atoms, 3, device=self.device, dtype=self.dtype
)
stress = torch.zeros(num_models, 3, 3, device=self.device, dtype=self.dtype)
dict_of_tensors.update(
{
"energies": energies,
Expand All @@ -340,12 +361,18 @@ def _create_result_tensors(
}
)
if model_type in ["EnergyDipoleMACE", "DipoleMACE", "DipolePolarizabilityMACE"]:
dipole = torch.zeros(num_models, 3, device=self.device)
dipole = torch.zeros(num_models, 3, device=self.device, dtype=self.dtype)
dict_of_tensors.update({"dipole": dipole})
if model_type in ["DipolePolarizabilityMACE"]:
charges = torch.zeros(num_models, num_atoms, device=self.device)
polarizability = torch.zeros(num_models, 3, 3, device=self.device)
polarizability_sh = torch.zeros(num_models, 6, device=self.device)
charges = torch.zeros(
num_models, num_atoms, device=self.device, dtype=self.dtype
)
polarizability = torch.zeros(
num_models, 3, 3, device=self.device, dtype=self.dtype
)
polarizability_sh = torch.zeros(
num_models, 6, device=self.device, dtype=self.dtype
)
dict_of_tensors.update(
{
"charges": charges,
Expand All @@ -370,6 +397,7 @@ def _atoms_to_batch(self, atoms):
z_table=self.z_table,
cutoff=self.r_max,
heads=self.available_heads,
dtype=self.dtype,
)
],
batch_size=1,
Expand All @@ -386,14 +414,19 @@ def _clone_batch(self, batch):
batch_clone["positions"].requires_grad_(True)
return batch_clone

# pylint: disable=dangerous-default-value
def calculate(self, atoms=None, properties=None, system_changes=all_changes):
"""
Calculate properties.
:param atoms: ase.Atoms object
:param properties: [str], properties to be computed, used by ASE internally
:param system_changes: [str], system changes since last calculation, used by ASE internally
:return:
def calculate(self, atoms=None, properties=None, system_changes=tuple(all_changes)):
"""Calculates atomic properties using the MACE model.

Args:
atoms: ASE Atoms object representing the atomic structure.
properties: List of strings specifying the properties to be computed.
Used internally by ASE.
system_changes: List of strings indicating what has changed in the system
since the last calculation. Used internally by ASE.
Defaults to all_changes.

Note:
This method is part of ASE's calculator interface.
"""
# call to base-class to set atoms attribute
Calculator.calculate(self, atoms)
Expand Down Expand Up @@ -594,10 +627,19 @@ def get_hessian(self, atoms=None):

def get_descriptors(self, atoms=None, invariants_only=True, num_layers=-1):
"""Extracts the descriptors from MACE model.
:param atoms: ase.Atoms object
:param invariants_only: bool, if True only the invariant descriptors are returned
:param num_layers: int, number of layers to extract descriptors from, if -1 all layers are used
:return: np.ndarray (num_atoms, num_interactions, invariant_features) of invariant descriptors if num_models is 1 or list[np.ndarray] otherwise

Args:
atoms: ASE Atoms object representing the atomic structure.
invariants_only: If True, only the invariant descriptors are returned.
Defaults to True.
num_layers: Number of layers to extract descriptors from.
If -1, descriptors from all layers are used.
Defaults to -1.

Returns:
Union[np.ndarray, List[np.ndarray]]: If num_models is 1, returns a numpy array
of shape (num_atoms, num_interactions, invariant_features) containing
the invariant descriptors. Otherwise, returns a list of such arrays.
"""
if atoms is None and self.atoms is None:
raise ValueError("atoms not set")
Expand Down
Loading
Loading