Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion cascade/learning/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@
"""Just the model, which we require being the MACE which includes scale shifting logic"""


# TODO (wardlt): Use https://github.com/ACEsuit/mace/pull/830 when merged
def freeze_layers(model: torch.nn.Module, n: int = 4) -> None:
    """Freeze a contiguous run of a model's top-level child modules.

    "Layers" here are the direct children of ``model`` in the order returned by
    ``model.children()``. A frozen layer has ``requires_grad`` set to ``False`` on
    every parameter it contains. Layers which are not selected are left untouched,
    so parameters frozen before this call stay frozen.

    Args:
        model (torch.nn.Module): The model, which is modified in place.
        n (int): The number of layers to freeze. A positive value freezes the first
            ``n`` layers, a negative value freezes the last ``|n|`` layers, and zero
            freezes nothing. Values whose magnitude exceeds the number of layers are
            clamped (with a warning) so that every layer is frozen.
    """
    layers = list(model.children())
    num_layers = len(layers)

    logging.info(f"Total layers in model: {num_layers}")

    if abs(n) > num_layers:
        logging.warning(
            f"Requested {n} layers, but model only has {num_layers}. Adjusting `n` to fit the model."
        )
        n = num_layers if n > 0 else -num_layers

    # Zero must be handled on its own: `layers[0:]` selects *every* layer,
    # which would freeze the whole model instead of none of it
    if n > 0:
        frozen_layers = layers[:n]
    elif n < 0:
        frozen_layers = layers[n:]
    else:
        frozen_layers = []

    logging.info(f"Freezing {len(frozen_layers)} layers.")

    for layer in frozen_layers:
        for param in layer.parameters():
            param.requires_grad = False


def atoms_to_loader(atoms: list[Atoms], batch_size: int, z_table: AtomicNumberTable, r_max: float, **kwargs):
"""
Make a data loader from a list of ASE atoms objects
Expand Down Expand Up @@ -125,7 +153,29 @@ def train(self,
stress_weight: float = 100,
reset_weights: bool = False,
patience: int | None = None,
**kwargs) -> tuple[bytes, pd.DataFrame]:
num_freeze: int | None = None
) -> tuple[bytes, pd.DataFrame]:
"""Train a model

Args:
model_msg: Model to be retrained
train_data: Structures used for training
valid_data: Structures used for validation
num_epochs: Number of training epochs
device: Device (e.g., 'cuda', 'cpu') used for training
batch_size: Batch size during training
learning_rate: Initial learning rate for optimizer
huber_deltas: Delta parameters for the loss functions for energy and force
force_weight: Amount of weight to use for the force part of the loss function
stress_weight: Amount of weight to use for the stress part of the loss function
reset_weights: Whether to reset the weights before training
patience: Halt training after validation error increases for this many epochs
num_freeze: Number of layers to freeze. Starts from the top of the model (node embedding)
See: `Radova et al. <https://arxiv.org/html/2502.15582v1>`_
Returns:
- model: Retrained model
- history: Training history
"""

# Load the model
model = self.get_model(model_msg)
Expand All @@ -142,6 +192,10 @@ def train(self,
for p in model.parameters():
p.requires_grad = True

# Freeze desired layers
if num_freeze is not None:
freeze_layers(model, num_freeze)

# Convert the training data from ASE -> MACE Configs
train_loader = atoms_to_loader(train_data, batch_size, z_table, r_max, shuffle=True, drop_last=True)
valid_loader = atoms_to_loader(valid_data, batch_size, z_table, r_max, shuffle=False, drop_last=True)
Expand Down
29 changes: 10 additions & 19 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,20 @@ channels:
dependencies:
# Core dependencies
- python==3.11
- matplotlib
- scikit-learn>=1
- jupyterlab
- pandas
- pytest
- flake8
- pip

# Computational chemistry
- packmol
- cp2k

# For nersc's jupyterlab
- ipykernel

# Pip packages for all of them
# Pip packages for the Python modules
- pip
- pip:
- git+https://gitlab.com/ase/ase.git
- git+https://github.com/ACEsuit/mace.git
- torch
- mlflow
- pytorch-ignite
- python-git-info
- tqdm
# General utilities
- matplotlib
- papermill
- -e .
- scikit-learn>=1
- jupyterlab
- pandas

# ML side
- -e .[mace,chgnet,ani]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,5 @@ chgnet = [
]
mace = [
'mace-torch',
'ignite'
'pytorch-ignite'
]
10 changes: 10 additions & 0 deletions tests/learning/test_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,13 @@ def test_inference(mace, example_data):
assert np.isclose(atoms.get_potential_energy(), energy[0]).all()
assert np.isclose(atoms.get_forces(), forces[0]).all()
assert np.isclose(atoms.get_stress(voigt=False), stresses[0]).all()


def test_freeze(example_data, mace):
    """Check which top-level layers remain trainable after training with ``num_freeze=2``."""
    # Retrain briefly, asking for the first two layers to be frozen
    interface = MACEInterface()
    model_msg, _ = interface.train(mace, example_data, example_data, 2, batch_size=2, patience=1, num_freeze=2)
    model: MACEState = interface.get_model(model_msg)

    # One flag per top-level child: True only when every one of its parameters requires a gradient
    is_trainable = []
    for child in model.children():
        is_trainable.append(all(param.requires_grad for param in child.parameters()))

    # The first layer must be frozen
    assert not is_trainable[0]
    assert all(is_trainable[4:6])  # Layers >4 include some layers which are not trainable
1 change: 0 additions & 1 deletion tests/test_proxima.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def simple_proxima(simple_model, target_calc, tmpdir):
def initialized_db(simple_proxima, target_calc, starting_frame):
"""Initialize the database"""
# Compute a set of initial calcs if required
global _initial_calcs
if len(_initial_calcs) == 0:
for i in range(12):
new_frame = starting_frame.copy()
Expand Down
Loading