diff --git a/cascade/learning/mace.py b/cascade/learning/mace.py
index 9c38e06..90f55a2 100644
--- a/cascade/learning/mace.py
+++ b/cascade/learning/mace.py
@@ -29,6 +29,34 @@
     """Just the model, which we require being the MACE which includes scale shifting logic"""
 
 
+# TODO (wardlt): Use https://github.com/ACEsuit/mace/pull/830 when merged
+def freeze_layers(model: torch.nn.Module, n: int = 4) -> None:
+    """
+    Freezes the first `n` layers of a model. If `n` is negative, freezes the last `|n|` layers.
+    Args:
+        model (torch.nn.Module): The model.
+        n (int): The number of layers to freeze.
+    """
+    layers = list(model.children())
+    num_layers = len(layers)
+
+    logging.info(f"Total layers in model: {num_layers}")
+
+    if abs(n) > num_layers:
+        logging.warning(
+            f"Requested {n} layers, but model only has {num_layers}. Adjusting `n` to fit the model."
+        )
+        n = num_layers if n > 0 else -num_layers
+
+    frozen_layers = layers[:n] if n > 0 else layers[n:]
+
+    logging.info(f"Freezing {len(frozen_layers)} layers.")
+
+    for layer in frozen_layers:
+        for param in layer.parameters():
+            param.requires_grad = False
+
+
 def atoms_to_loader(atoms: list[Atoms], batch_size: int, z_table: AtomicNumberTable, r_max: float, **kwargs):
     """
     Make a data loader from a list of ASE atoms objects
@@ -125,7 +153,29 @@ def train(self,
               stress_weight: float = 100,
               reset_weights: bool = False,
               patience: int | None = None,
-              **kwargs) -> tuple[bytes, pd.DataFrame]:
+              num_freeze: int | None = None
+              ) -> tuple[bytes, pd.DataFrame]:
+        """Train a model
+
+        Args:
+            model_msg: Model to be retrained
+            train_data: Structures used for training
+            valid_data: Structures used for validation
+            num_epochs: Number of training epochs
+            device: Device (e.g., 'cuda', 'cpu') used for training
+            batch_size: Batch size during training
+            learning_rate: Initial learning rate for optimizer
+            huber_deltas: Delta parameters for the loss functions for energy and force
+            force_weight: Amount of weight to use for the force part of the loss function
+            stress_weight: Amount of weight to use for the stress part of the loss function
+            reset_weights: Whether to reset the weights before training
+            patience: Halt training after validation error increases for this many epochs
+            num_freeze: Number of layers to freeze. Starts from the top of the model (node embedding).
+                See: `Radova et al. `_
+        Returns:
+            - model: Retrained model
+            - history: Training history
+        """
         # Load the model
         model = self.get_model(model_msg)
 
@@ -142,6 +192,10 @@ def train(self,
         for p in model.parameters():
             p.requires_grad = True
 
+        # Freeze desired layers
+        if num_freeze is not None:
+            freeze_layers(model, num_freeze)
+
         # Convert the training data from ASE -> MACE Configs
         train_loader = atoms_to_loader(train_data, batch_size, z_table, r_max, shuffle=True, drop_last=True)
         valid_loader = atoms_to_loader(valid_data, batch_size, z_table, r_max, shuffle=False, drop_last=True)
diff --git a/environment.yml b/environment.yml
index 38f0eec..1e7e776 100644
--- a/environment.yml
+++ b/environment.yml
@@ -5,29 +5,20 @@ channels:
 
 dependencies:
   # Core dependencies
   - python==3.11
-  - matplotlib
-  - scikit-learn>=1
-  - jupyterlab
-  - pandas
-  - pytest
-  - flake8
-  - pip
 
   # Computational chemistry
   - packmol
   - cp2k
-  # For nersc's jupyterlab
-  - ipykernel
-
-# Pip packages for all of them
+  # Pip packages for the Python modules
+  - pip
   - pip:
-    - git+https://gitlab.com/ase/ase.git
-    - git+https://github.com/ACEsuit/mace.git
-    - torch
-    - mlflow
-    - pytorch-ignite
-    - python-git-info
-    - tqdm
+    # General utilities
+    - matplotlib
     - papermill
-    - -e .
+    - scikit-learn>=1
+    - jupyterlab
+    - pandas
+
+    # ML side
+    - -e .[mace,chgnet,ani]
diff --git a/pyproject.toml b/pyproject.toml
index 2c2b1d3..c15f486 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,5 +48,5 @@ chgnet = [
 ]
 mace = [
     'mace-torch',
-    'ignite'
+    'pytorch-ignite'
 ]
\ No newline at end of file
diff --git a/tests/learning/test_mace.py b/tests/learning/test_mace.py
index 5ba9f7d..c24cefb 100644
--- a/tests/learning/test_mace.py
+++ b/tests/learning/test_mace.py
@@ -48,3 +48,13 @@ def test_inference(mace, example_data):
     assert np.isclose(atoms.get_potential_energy(), energy[0]).all()
     assert np.isclose(atoms.get_forces(), forces[0]).all()
     assert np.isclose(atoms.get_stress(voigt=False), stresses[0]).all()
+
+
+def test_freeze(example_data, mace):
+    # Train with the first two layers frozen
+    mi = MACEInterface()
+    model_msg, _ = mi.train(mace, example_data, example_data, 2, batch_size=2, patience=1, num_freeze=2)
+    model: MACEState = mi.get_model(model_msg)
+    is_trainable = [all(y.requires_grad for y in x.parameters()) for x in model.children()]
+    assert not is_trainable[0]
+    assert all(is_trainable[4:6])  # Only check a few of the unfrozen layers; some later ones are never trainable
diff --git a/tests/test_proxima.py b/tests/test_proxima.py
index 4b75bc8..b284531 100644
--- a/tests/test_proxima.py
+++ b/tests/test_proxima.py
@@ -59,7 +59,6 @@ def simple_proxima(simple_model, target_calc, tmpdir):
 def initialized_db(simple_proxima, target_calc, starting_frame):
     """Initialize the database"""
     # Compute a set of initial calcs if required
-    global _initial_calcs
     if len(_initial_calcs) == 0:
         for i in range(12):
             new_frame = starting_frame.copy()
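
Reviewer note: below is a minimal sketch (not part of the patch) of how the new
`freeze_layers` helper behaves, assuming the patched `cascade.learning.mace` is
importable. The `torch.nn.Sequential` stand-in and its layer sizes are
hypothetical; a real MACE model simply exposes its blocks (node embedding
first) through `model.children()` in the same way, which is what
`num_freeze=2` in `test_freeze` relies on.

    import torch

    from cascade.learning.mace import freeze_layers

    # Toy stand-in with three children, counted from the top of the model
    model = torch.nn.Sequential(
        torch.nn.Linear(4, 8),  # child 0: frozen by n=2
        torch.nn.ReLU(),        # child 1: no parameters, but still counts as a layer
        torch.nn.Linear(8, 1),  # child 2: left trainable
    )
    freeze_layers(model, n=2)

    assert all(not p.requires_grad for p in model[0].parameters())
    assert all(p.requires_grad for p in model[2].parameters())

    # A negative `n` freezes from the bottom instead, e.g. freeze_layers(model, n=-2)
    # would freeze the ReLU and the final Linear while leaving child 0 trainable.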