Skip to content

Conversation

@Neutrino155
Copy link

@Neutrino155 Neutrino155 commented Sep 12, 2025

Add MACEField: train, finetune, and infer electric-field responses (polarization, BECs, polarizability) from energy

Summary

This PR introduces MACEField — a uniform-electric-field extension that injects an external field (o3.Irreps("1o")) into MACE’s latent node features. By differentiating the energy w.r.t. field and atomic displacements, the model exposes autograd-consistent responses:

  • Polarization: $\qquad \displaystyle \mathbf P = -\frac{1}{\Omega},\frac{\partial E}{\partial \mathbf E} \quad [\mathbf P:\ \mathrm{e}/Å^2]$

  • Born effective charges (BECs): $\qquad \displaystyle Z_{\kappa,\alpha\beta} = \Omega,\frac{\partial P_\beta}{\partial u_{\kappa,\alpha}} \qquad [Z:\ \mathrm{e}]$

  • Polarizability / susceptibility: $\qquad \displaystyle \chi_{\alpha\beta} = \frac{\partial P_\alpha}{\partial E_\beta} = -\frac{1}{\Omega},\frac{\partial^2 E}{\partial E_\alpha,\partial E_\beta} \quad [\chi:\ \mathrm{e}/(\mathrm{V}\cdotÅ)]$

Heads are opt-in, work with batching and per-graph fields, and can be used from scratch or to finetune a foundation model to be field-aware.

What’s new

  • Model: MACEField injects a uniform field ("1o") and mixes it via equivariant tensor products + linear mixing.
  • Calculator: returns polarization, becs, polarizability (when requested) alongside E/F/σ; supports per-graph fields.
  • CLI & Tools: training via run_train.py; batch inference via mace_eval_configs; plotting during training shows new outputs; ASE calculator example included.

Electric field source

  1. Explicit field (--electric-field in CLI as optional input to model) — takes precedence if provided — used mainly for inference and MD
  2. Per-structure field in input .xyz (info["REF_electric_field"]) — used mainly for training
  3. Otherwise: zero field — no electric field provided

Loss function (universal_field) — what’s optimized

We use a multi-task objective UniversalFieldLoss that combines the standard E/F/σ terms as in UniversalLoss with field-response heads.

Polarization folding — resolving Berry-phase branches

As polarization in a periodic crystal is multi-valued: during training we fold the polarization difference onto the nearest branch before computing the loss:

  1. Form the raw difference: $\Delta \mathbf{P} = \mathbf{P}_r - \mathbf{P}_p$ .
  2. Build the polarization-quantum matrix: $\mathbf{Q} = \big[\mathbf{Q}_1;\mathbf{Q}_2;\mathbf{Q}_3\big] = \frac{e}{\Omega}[\mathbf{a}_1;\mathbf{a}_2;\mathbf{a}_3]$ (columns $\mathbf{Q}_i$).
  3. Compute the (real-valued) coefficients $\mathbf{c} = \mathbf{Q}^{-1},\Delta \mathbf{P}$ (solve a 3×3 linear system).
  4. Round to the nearest integers: $\mathbf{n} = \mathrm{round}(\mathbf{c})$.
  5. Folded difference: $\displaystyle \Delta \mathbf{P}^{\text{fold}} = \Delta \mathbf{P} - \mathbf{Q},\mathbf{n}$.

Use $\Delta \mathbf{P}^{\text{fold}}$ in $\mathcal{L}_P$. This makes the objective branch-invariant and avoids discontinuities across ferroelectric paths or mixed-cell datasets.

Example training script (CLI)

torchrun --standalone --nproc_per_node="gpu" ./mace/cli/run_train.py \
  --name="MACEField-BaTiO3" \
  --train_file="BaTiO3-md-traj.xyz" \ 
  --valid_fraction=0.10 \
  --E0s="average" \
  --model="MACEField" \ # Use MACEField model
  --loss="universal_field" \  # Loss for energy, forces, stress, poalrization, becs and polarizability
  --error_table="PerAtomFieldRMSE" \  # Error table for all 6 outputs
  --energy_weight=1.0 \
  --forces_weight=100.0 
  --stress_weight=1.0 \
  --polarization_weight=1.0 \ # new weight for polarization (can be 1x3 vector)
  --becs_weight=100.0 \ # new weight for becs (can be 3x3 matrix)
  --polarizability_weight=100.0 \ # new weight for polarizability (can be 3x3 matrix)
  --compute_forces=True \
  --compute_stress=True \
  --compute_polarization=True \
  --compute_becs=True \
  --compute_polarizability=True \
  --distributed \
  --launcher="torchrun" \
  --device="cuda" \

Targets & shapes (convention):
info["REF_polarization"][3] (e/Ų) · arrays["REF_becs"][n_atoms,3,3] (e) · info["REF_polarizability"][3,3] (e/(V·Å))

Weighting tips: start with forces≈100, polarization≈1, becs≈50–150, polarizability≈50–150

Field source: prefer per-frame info["REF_electric_field"] in the dataset; pass an explicit field only if you want to override such as during finite-field MD

Plotting: polarization, becs and polarizability errors and parities show alongside energy, forces and stress during training with --plot=True

Finetune a foundation model to be field-aware

MACEField can be used to finetune existing foundation models to be "field aware".

For example, with heterogeneous data where our datasets are:

  • becs-polarizabilities.xyz has just becs and polarizability data (use config_energy_weight=0.0, config_forces_weight=0.0, config_stress_weight=0.0 and config_polarization_weight=0.0).
  • polarizations.xyz has just polarization data (use config_energy_weight=0.0, config_forces_weight=0.0, config_stress_weight=0.0, config_becs_weight=0.0 and config_polarizability_weight=0.0)

we can do multiheaded finetuning on the mp-0b3 model using MACEField with a config:

# config.yaml (excerpt)
name: mace-field-mp-0b3-medium-mh
distributed: true
device: cuda
foundation_model: "mace-mp-0b3-medium.model"
model: "MACEField"
loss: "universal_field"
heads:
  mp-becs-polarizabilities: 
    train_file: ["becs-polarizabilities-train.xyz"] 
    valid_file: ["becs-polarizabilities-valid.xyz"] 
  mp-polarizations:
    train_file: ["polarizations-train.xyz"]
    valid_file: ["polarizations-valid.xyz"]

pt_train_file: "mp_traj_selected.xyz"   # optional replay set (e.g. 10000 selected samples)
multiheads_finetuning: true
E0s: foundation

compute_forces: true
compute_becs: true
compute_polarisation: true
compute_polarisability: true
lr: 0.0001
ema: true
ema_decay: 0.999
batch_size: 1

Run:

torchrun --standalone --nproc_per_node="gpu" ./mace/cli/run_train.py --config config.yaml

Note: Still experimental. Maybe interesting to investigate frozen weights, initializing new MACEField weights as 0 or small, etc.

Inference

Batch CLI — mace_eval_configs

mace_eval_configs \
  --configs "mp_traj_selected.xyz" \
  --model "mace-field-mp-0b3-medium-mh.model" \
  --output "mp_traj_selected-MACE.xyz" \
  --head "pt_head" \
  --compute_polarisation \
  --compute_becs \ 
  --compute_polarisability \

Uses per-frame info["REF_electric_field"] unless an explicit field option is provided via --electric-field [E_x, E_y, E_z] (explicit overrides). Writes predictions back to the XYZ.

Finite-field Molecular Dynamics

ASE calculator

from ase import Atoms
from mace.calculators.mace import MACECalculator

atoms = Atoms(...)
calc = MACECalculator(
    model_path="MACEField-BaTiO3.model",
    model_type="MACEField",
    # If set, this overrides atoms.info["electric_field"]
    electric_field=[0.0, 0.0, 0.02],
)
atoms.calc = calc
E  = atoms.get_potential_energy()
F  = atoms.get_forces()
P   = atoms.calc.results.get("polarization")    # e/Ų
Z   = atoms.calc.results.get("becs")            # e
chi = atoms.calc.results.get("polarizability")  # e/(V·Å)

For time-dependent electric field, we can use calc.electric_field = [E_x, E_z, E_z] to update through logger.

Data format

Extended-XYZ per frame: info["REF_electric_field"] as [Ex,Ey,Ez]; optional labels info["REF_polarization"] (e/Ų), info["REF_polarizability"] (e/(V·Å)), and arrays["REF_becs"] (e).

Implementation highlights

  • Field enters as Irreps("1o"), mixed via equivariant tensor products + linear mixing.
  • Derivatives from energy expose P, BECs, χ; these derivatives are computed only when requested.
  • Work's with finetuned model and existing CLI and ASE calculators.

Limitations / TODO

  • Tests: not yet implemented (FD vs autograd for P/Z*/χ; batching; χ symmetry; ASR).
  • LAMMPS: integration not yet implemented (wire uniform field & flags into MLIAP bridge).

TL;DR: Learn and evaluate electric-field responses in MACE. Train from scratch or finetune a foundation model; run batch inference via mace_eval_configs or interactively finite-field MD via ASE — with explicit field overrides, per-graph fields, and polarization folding in loss.

Neutrino155 and others added 30 commits January 24, 2025 10:47
… not using the unit electric field vector for the angular components. Whilst I was at it, I also added a radial component for the electric field for cases where it is not zero.
…tion lattice and make it modulo the polarisation quantum).
…e --device choices so I can use --device="cuda:1"
@Neutrino155
Copy link
Author

@ilyes319 Sorry for the delay on this - finally got the tests done. They appear to pass locally on my end.

There are a few branch conflicts that need to be resolved. I had a look:

  • atomic_data.py is just an additional wildcard on your end, which is fine. We can keep the incoming change.
  • finetuning_utils.py includes additional code to allow a MACEField model to also be finetuned on. This would be good to keep, but there are some changes throughout that I did when debugging some failing tests, so we need to be careful not to break your existing code.
  • mace.py seems to have parts that have been just rewritten in a slightly different way. We can probably accept both here without actual conflicts.

Let me know what you think and if you need anything more from me.

Cheers,
Brad

@Neutrino155
Copy link
Author

Okay, fixed up the tests. Everything now seems green... I am not sure if some tests are timing out.

I extended the LAMMPs code to work with MACEField, so now it can be used to do MD with a static or time-dependent electric field. The lammps_mace.py can output the additional polarization, becs and polarizability easily, but the mliap seems less flexible. So, for now, I can inject the electric field per timestep, but dielectric properties have to be evaluated afterwards from the resulting trajectory.

LAMMPS: constant electric field

Example (Kokkos GPU, single GPU):

export MACE_EFIELD_MODE=env \
export MACE_EFIELD=0,0,0.3 \
lmp -k on g 1 -sf kk -pk kokkos gpu/aware on neigh half newton on -in in.lammps_macefield

This feeds $\mathbf{E} = (0, 0, 0.3)$ V/Å to MACE-Field at every MD step.

LAMMPS: time-dependent electric field (MACE-Field via env var)

MACE-Field’s LAMMPS ML-IAP wrapper can take a time-dependent electric field by updating the environment variable MACE_EFIELD every MD step from LAMMPS equal-style variables.

This approach keeps the wrapper simple: it only needs to re-read os.environ["MACE_EFIELD"] each step.

1) Define the field as LAMMPS equal-style variables

Define Ex, Ey, Ez in your in.lammps (here Ez is sinusoidal in step):

variable        E0     equal 0.30
variable        period equal 2000

variable        Ex equal 0.0
variable        Ey equal 0.0
variable        Ez equal v_E0*sin(2.0*PI*step/v_period)

2) Push Ex,Ey,Ez into MACE_EFIELD every step via python/invoke

This Python snippet runs inside LAMMPS, reads the current values using extract_variable, and updates MACE_EFIELD:

python set_mace_efield here """
import os
from lammps import lammps

def set_mace_efield(lammps_ptr):
    lmp = lammps(ptr=lammps_ptr)
    ex = float(lmp.extract_variable("Ex", None, 0))
    ey = float(lmp.extract_variable("Ey", None, 0))
    ez = float(lmp.extract_variable("Ez", None, 0))
    os.environ["MACE_EFIELD"] = f"{ex},{ey},{ez}"
"""
fix mace_efield all python/invoke 1 end_of_step set_mace_efield

What this does:

  • LAMMPS evaluates Ex/Ey/Ez for the current step.
  • The Python hook converts them to floats and sets MACE_EFIELD="ex,ey,ez".
  • The MACE-Field LAMMPS wrapper reads MACE_EFIELD on the next force call and feeds it to the model.

@Neutrino155
Copy link
Author

Example LAMMPS-MLIAP for polarisation hysteresis for BTO - oscillatory electric field simulation for 10ps for 16875 atom supercell. Ran on 4 L4 Nvidia gpus.

Loop time of 3309.48 on 4 procs for 10000 steps with 16875 atoms

Performance: 0.261 ns/day, 91.930 hours/ns, 3.022 timesteps/s, 50.990 katom-step/s
99.9% CPU use with 4 MPI tasks x 1 OpenMP threads

MPI task timing breakdown:
Section |  min time  |  avg time  |  max time  |%varavg| %total
---------------------------------------------------------------
Pair    | 3156.2     | 3168.6     | 3186.7     |  20.0 | 95.74
Neigh   | 0          | 0          | 0          |   0.0 |  0.00
Comm    | 3.3592     | 21.437     | 33.776     | 242.3 |  0.65
Output  | 101.03     | 104        | 106.22     |  18.9 |  3.14
Modify  | 11.636     | 12.218     | 12.583     |  10.4 |  0.37
Other   |            | 3.204      |            |       |  0.10

Nlocal:        4218.75 ave        4335 max        4110 min
Histogram: 1 0 0 0 2 0 0 0 0 1
Nghost:        9185.25 ave        9429 max        8949 min
Histogram: 1 0 1 0 0 0 0 1 0 1
Neighs:              0 ave           0 max           0 min
Histogram: 4 0 0 0 0 0 0 0 0 0
FullNghs:       686812 ave      703350 max      671400 min
Histogram: 1 0 0 0 2 0 0 0 0 1

Total # of neighbors = 2747250
Ave neighs/atom = 162.8
Neighbor list builds = 0
Dangerous builds = 0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants