@ThomasWarford (Contributor) commented Sep 15, 2025

Builds upon some of the tests written by @hatemhelal; these provide a quick way to gauge CUEQ speedups.

The tests with compile_mode='default' failed before these changes and still fail after them.

The tests with size-9 fail due to OOM errors.

FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[None-False-float64-9]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float32-3]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float32-5]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float32-7]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float32-9]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float64-3]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float64-5]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float64-7]
FAILED ../../../../../../../home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py::test_inference[default-False-float64-9]

Successful results:

-------------------------------------------------------------------------------------------- benchmark: 23 tests ---------------------------------------------------------------------------------------------
Name (time in ms)                               Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_inference[None-False-float32-3]        27.7014 (1.0)       29.5667 (1.0)       28.7893 (1.0)      0.3778 (2.05)      28.7741 (1.0)      0.5306 (1.93)         11;1  34.7351 (1.0)          39           1
test_inference[None-True-float32-3]         33.8343 (1.22)      35.2593 (1.19)      34.4274 (1.20)     0.3798 (2.07)      34.3398 (1.19)     0.4507 (1.64)          9;0  29.0467 (0.84)         29           1
test_inference[None-True-float32-5]         33.9597 (1.23)      35.3746 (1.20)      34.4707 (1.20)     0.3798 (2.07)      34.3375 (1.19)     0.5680 (2.06)          9;0  29.0101 (0.84)         29           1
test_inference[None-True-float64-3]         34.6224 (1.25)      36.2247 (1.23)      35.1721 (1.22)     0.4141 (2.25)      35.0142 (1.22)     0.6165 (2.24)          8;0  28.4316 (0.82)         29           1
test_inference[None-True-float64-5]         35.0530 (1.27)      36.9004 (1.25)      35.6038 (1.24)     0.4151 (2.26)      35.4669 (1.23)     0.3593 (1.31)          7;3  28.0869 (0.81)         28           1
test_inference[None-False-float64-3]        35.8275 (1.29)      37.5772 (1.27)      36.7414 (1.28)     0.3697 (2.01)      36.7767 (1.28)     0.4904 (1.78)          5;0  27.2172 (0.78)         27           1
test_inference[default-True-float32-5]      38.0831 (1.37)      40.0626 (1.35)      38.7183 (1.34)     0.4924 (2.68)      38.6180 (1.34)     0.3231 (1.17)          8;3  25.8276 (0.74)         25           1
test_inference[default-True-float32-3]      38.5790 (1.39)      39.6432 (1.34)      38.9915 (1.35)     0.3194 (1.74)      38.9113 (1.35)     0.5700 (2.07)         10;0  25.6466 (0.74)         25           1
test_inference[default-True-float64-5]      38.9013 (1.40)      40.7125 (1.38)      39.5701 (1.37)     0.4651 (2.53)      39.3596 (1.37)     0.5400 (1.96)          8;1  25.2716 (0.73)         25           1
test_inference[None-True-float32-7]         39.0070 (1.41)      40.4221 (1.37)      39.4768 (1.37)     0.3889 (2.11)      39.3502 (1.37)     0.2962 (1.08)          7;3  25.3313 (0.73)         26           1
test_inference[default-True-float64-3]      39.0108 (1.41)      40.4404 (1.37)      39.4829 (1.37)     0.4299 (2.34)      39.3444 (1.37)     0.6518 (2.37)          7;0  25.3274 (0.73)         25           1
test_inference[default-True-float32-7]      42.9031 (1.55)      44.4469 (1.50)      43.7136 (1.52)     0.3679 (2.00)      43.7452 (1.52)     0.4087 (1.49)          6;0  22.8762 (0.66)         23           1
test_inference[None-True-float64-7]         53.7957 (1.94)      55.4345 (1.87)      54.2767 (1.89)     0.4749 (2.58)      54.0652 (1.88)     0.5911 (2.15)          5;1  18.4241 (0.53)         19           1
test_inference[default-True-float64-7]      59.5455 (2.15)      61.9425 (2.10)      60.8628 (2.11)     0.5706 (3.10)      60.7941 (2.11)     0.5755 (2.09)          6;1  16.4304 (0.47)         20           1
test_inference[None-True-float32-9]         64.2816 (2.32)      65.5146 (2.22)      64.8037 (2.25)     0.3707 (2.02)      64.7167 (2.25)     0.5260 (1.91)          5;0  15.4312 (0.44)         16           1
test_inference[default-True-float32-9]      70.8785 (2.56)      72.6247 (2.46)      71.4664 (2.48)     0.4474 (2.43)      71.4312 (2.48)     0.5766 (2.10)          5;1  13.9926 (0.40)         18           1
test_inference[None-False-float32-5]        87.3152 (3.15)      88.3607 (2.99)      87.8746 (3.05)     0.3117 (1.70)      87.8210 (3.05)     0.4349 (1.58)          5;0  11.3799 (0.33)         15           1
test_inference[None-True-float64-9]         99.3786 (3.59)     101.1536 (3.42)      99.9854 (3.47)     0.5816 (3.16)      99.7814 (3.47)     0.8345 (3.03)          4;0  10.0015 (0.29)         11           1
test_inference[default-True-float64-9]     110.3348 (3.98)     112.9138 (3.82)     111.7370 (3.88)     0.7625 (4.15)     111.8403 (3.89)     0.8639 (3.14)          5;0   8.9496 (0.26)         14           1
test_inference[None-False-float64-5]       135.2081 (4.88)     136.3062 (4.61)     135.7444 (4.72)     0.4129 (2.25)     135.7572 (4.72)     0.7787 (2.83)          3;0   7.3668 (0.21)         10           1
test_inference[None-False-float32-7]       232.1830 (8.38)     232.7591 (7.87)     232.4950 (8.08)     0.1839 (1.0)      232.5885 (8.08)     0.2752 (1.0)           2;0   4.3012 (0.12)          9           1
test_inference[None-False-float64-7]       366.0413 (13.21)    369.6881 (12.50)    367.5816 (12.77)    1.2717 (6.92)     367.2931 (12.76)    2.0629 (7.50)          2;0   2.7205 (0.08)          8           1
test_inference[None-False-float32-9]       488.5533 (17.64)    490.0517 (16.57)    489.0649 (16.99)    0.4914 (2.67)     489.0746 (17.00)    0.6100 (2.22)          2;0   2.0447 (0.06)          8           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
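The cueq speedup at a given system size can be read off the table by dividing the eager (`None-False`) float32 mean by the corresponding cueq (`None-True`) mean. A quick sketch using the mean times copied from the table (values in ms):

```python
# Mean inference times (ms) from the float32, compile_mode=None rows of the
# benchmark table above, keyed by system size.
eager_mean_ms = {3: 28.7893, 5: 87.8746, 7: 232.4950, 9: 489.0649}  # enable_cueq=False
cueq_mean_ms = {3: 34.4274, 5: 34.4707, 7: 39.4768, 9: 64.8037}     # enable_cueq=True

# Speedup factor from enabling cueq: >1 means cueq is faster.
speedup = {size: eager_mean_ms[size] / cueq_mean_ms[size] for size in eager_mean_ms}

for size, factor in sorted(speedup.items()):
    print(f"size {size}: {factor:.2f}x")
```

By this measure cueq is slightly slower at size 3 but gives roughly a 2.5x speedup at size 5, rising to about 7.5x at size 9.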

@ThomasWarford marked this pull request as ready for review September 15, 2025 21:19
@hatemhelal (Collaborator)
Could you share a stack trace for one of the failing compile test cases? I thought this was working; could you double-check that you have the latest develop branch? (#1170 is the PR you want.)

@hatemhelal (Collaborator)
I realised this was unclear: you will need PR #1175 as well to combine cueq with compile, but I wanted to check whether you are seeing compilation fail without cueq.

@ThomasWarford (Contributor, Author)
The error persists with the latest commits from develop. The errors occur only for `enable_cueq=False` and `compile_mode='default'`, which seems unexpected. The `*_zeroed` keys seem to be causing the problem.

___________________ test_inference[default-False-float64-5] ____________________

benchmark = <pytest_benchmark.fixture.BenchmarkFixture object at 0x400918d88170>
size = 5, dtype = 'float64', enable_cueq = False, compile_mode = 'default'
device = 'cuda'

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
    @pytest.mark.benchmark(warmup=True, warmup_iterations=4, min_rounds=8)
    @pytest.mark.parametrize("size", (3, 5, 7, 9))
    @pytest.mark.parametrize("dtype", ["float32", "float64"])
    @pytest.mark.parametrize("enable_cueq", [False, True])
    @pytest.mark.parametrize("compile_mode", [None, "default"])
    def test_inference(
        benchmark, size: int, dtype: str, enable_cueq: bool, compile_mode: Optional[str],device: str = "cuda"
    ):
        if not is_mace_full_bench() and compile_mode is not None:
            pytest.skip("Skipping long running benchmark, set MACE_FULL_BENCH=1 to execute")
    
        with torch_tools.default_dtype(dtype):
>           model = load_mace_mp_medium(dtype, enable_cueq, compile_mode, device)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

/home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py:43: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/home/s5f/twarf.s5f/maces/mace/tests/test_benchmark.py:56: in load_mace_mp_medium
    calc = mace_mp(
/home/s5f/twarf.s5f/maces/mace/mace/calculators/foundations_models.py:166: in mace_mp
    mace_calc = MACECalculator(
/home/s5f/twarf.s5f/maces/mace/mace/calculators/mace.py:225: in __init__
    prepare(extract_model)(model=model, map_location=device),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/maces/mace/mace/tools/compile.py:45: in wrapper
    model = func(*args, **kwargs)
            ^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/maces/mace/mace/tools/scripts_utils.py:426: in extract_model
    model_copy.load_state_dict(model.state_dict())
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = ScaleShiftMACE(
  (node_embedding): LinearNodeEmbeddingBlock(
    (linear): Linear(89x0e -> 128x0e | 11392 weights)
  ...(linear_2): Linear(16x0e -> 1x0e | 16 weights)
    )
  )
  (scale_shift): ScaleShiftBlock(scale=0.8042, shift=0.1641)
)
state_dict = OrderedDict({'atomic_numbers': tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       ...'cuda:0'), 'scale_shift.scale': tensor(0.8042, device='cuda:0'), 'scale_shift.shift': tensor(0.1641, device='cuda:0')})
strict = True, assign = False

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
    ):
        r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
    
        If :attr:`strict` is ``True``, then
        the keys of :attr:`state_dict` must exactly match the keys returned
        by this module's :meth:`~torch.nn.Module.state_dict` function.
    
        .. warning::
            If :attr:`assign` is ``True`` the optimizer must be created after
            the call to :attr:`load_state_dict` unless
            :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.
    
        Args:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            strict (bool, optional): whether to strictly enforce that the keys
                in :attr:`state_dict` match the keys returned by this module's
                :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
            assign (bool, optional): When set to ``False``, the properties of the tensors
                in the current module are preserved whereas setting it to ``True`` preserves
                properties of the Tensors in the state dict. The only
                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
                for which the value from the module is preserved.
                Default: ``False``
    
        Returns:
            ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
                * **missing_keys** is a list of str containing any keys that are expected
                    by this module but missing from the provided ``state_dict``.
                * **unexpected_keys** is a list of str containing the keys that are not
                    expected by this module but present in the provided ``state_dict``.
    
        Note:
            If a parameter or buffer is registered as ``None`` and its corresponding key
            exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
            ``RuntimeError``.
        """
        if not isinstance(state_dict, Mapping):
            raise TypeError(
                f"Expected state_dict to be dict-like, got {type(state_dict)}."
            )
    
        missing_keys: List[str] = []
        unexpected_keys: List[str] = []
        error_msgs: List[str] = []
    
        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, "_metadata", None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            # mypy isn't aware that "_metadata" exists in state_dict
            state_dict._metadata = metadata  # type: ignore[attr-defined]
    
        def load(module, local_state_dict, prefix=""):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            if assign:
                local_metadata["assign_to_params_buffers"] = assign
            module._load_from_state_dict(
                local_state_dict,
                prefix,
                local_metadata,
                True,
                missing_keys,
                unexpected_keys,
                error_msgs,
            )
            for name, child in module._modules.items():
                if child is not None:
                    child_prefix = prefix + name + "."
                    child_state_dict = {
                        k: v
                        for k, v in local_state_dict.items()
                        if k.startswith(child_prefix)
                    }
                    load(child, child_state_dict, child_prefix)  # noqa: F821
    
            # Note that the hook can modify missing_keys and unexpected_keys.
            incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
            for hook in module._load_state_dict_post_hooks.values():
                out = hook(module, incompatible_keys)
                assert out is None, (
                    "Hooks registered with ``register_load_state_dict_post_hook`` are not"
                    "expected to return new values, if incompatible_keys need to be modified,"
                    "it should be done inplace."
                )
    
        load(self, state_dict)
        del load
    
        if strict:
            if len(unexpected_keys) > 0:
                error_msgs.insert(
                    0,
                    "Unexpected key(s) in state_dict: {}. ".format(
                        ", ".join(f'"{k}"' for k in unexpected_keys)
                    ),
                )
            if len(missing_keys) > 0:
                error_msgs.insert(
                    0,
                    "Missing key(s) in state_dict: {}. ".format(
                        ", ".join(f'"{k}"' for k in missing_keys)
                    ),
                )
    
        if len(error_msgs) > 0:
>           raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    self.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
E           RuntimeError: Error(s) in loading state_dict for ScaleShiftMACE:
E           	Missing key(s) in state_dict: "products.0.symmetric_contractions.contractions.0.weights_0_zeroed", "products.0.symmetric_contractions.contractions.0.weights_1_zeroed", "products.0.symmetric_contractions.contractions.0.weights_max_zeroed", "products.0.symmetric_contractions.contractions.1.weights_0_zeroed", "products.0.symmetric_contractions.contractions.1.weights_1_zeroed", "products.0.symmetric_contractions.contractions.1.weights_max_zeroed", "products.1.symmetric_contractions.contractions.0.weights_0_zeroed", "products.1.symmetric_contractions.contractions.0.weights_1_zeroed", "products.1.symmetric_contractions.contractions.0.weights_max_zeroed".

../../miniforge3/envs/gpu/lib/python3.12/site-packages/torch/nn/modules/module.py:2581: RuntimeError
----------------------------- Captured stdout call -----------------------------
Using Materials Project MACE for MACECalculator with /home/s5f/twarf.s5f/.cache/mace/20231203mace128L1_epoch199model
Using float64 for MACECalculator, which is slower but more accurate. Recommended for geometry optimization.
Torch compile is enabled with mode: default
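For context, the strict key check in `torch.nn.Module.load_state_dict` (shown in the traceback above) raises whenever the destination module expects keys that the source state_dict does not provide. A minimal pure-Python sketch of that bookkeeping, using hypothetical key names modelled on the error message:

```python
# Pure-Python sketch of the missing/unexpected key bookkeeping that
# load_state_dict performs under strict=True (see the traceback above).
# The key names below are illustrative, modelled on the error message.

def strict_key_check(expected_keys, provided_keys):
    """Return (missing, unexpected) keys, as strict state_dict loading computes them."""
    missing = [k for k in expected_keys if k not in provided_keys]
    unexpected = [k for k in provided_keys if k not in expected_keys]
    return missing, unexpected

# The fresh model copy registers extra "*_zeroed" buffers...
expected = [
    "scale_shift.scale",
    "products.0.symmetric_contractions.contractions.0.weights_0_zeroed",
]
# ...but the source model's state_dict only carries the original keys.
provided = ["scale_shift.scale"]

missing, unexpected = strict_key_check(expected, provided)
print("missing:", missing)        # the *_zeroed buffer is reported as missing
print("unexpected:", unexpected)  # nothing extra in the provided state_dict
```

Under strict=True any non-empty `missing` list triggers the `RuntimeError` above, which is consistent with the extra `*_zeroed` buffers being registered on the model copy but absent from the saved state_dict.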

@ThomasWarford (Contributor, Author)
I can wait on that PR before investigating further, if that makes sense.
