
Conversation

hatemhelal (Collaborator) commented Sep 12, 2025

This PR provides a simple fix to enable the use of torch.compile in conjunction with cueq.

The main change is to narrow the scope of the symbolic tracing that was previously applied to the whole NonLinearReadoutBlock, so that it is now used only to simplify the activation constructor.

Also updated the tests to parameterise over enabling cueq.

hatemhelal (Collaborator, Author) commented Sep 12, 2025

@ilyes319, there is still more work to do for compilation, but I wanted to check that the choices for the methods here look OK to you.

ilyes319 (Contributor)

Hmm, good question. @mariogeiger, are these the intended settings?

mariogeiger (Contributor) left a comment

Did I answer the questions, @ilyes319?

shared_weights=shared_weights,
internal_weights=internal_weights,
use_fallback=True,
method="naive",
Contributor

Good.

Just a question: is this for the skipTP? Is one of the inputs always a one-hot vector? If so, why not index into the weights instead of contracting with a one-hot?

Contributor

Yeah, indeed. Does cueq provide the option to pass the weights in the forward?


Hi @ilyes319, since 0.6.0 we support indexing in the linear: you just need to provide the number of weight_classes at init and pass the weight_indices in the forward.
If your indices are sorted you can also use the experimental "indexed_linear" method, but even in the "naive" case this would first index into the weights and then compute a linear for each element, which should still be faster.
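
For illustration, here is a rough, unverified sketch of that usage based on this description; the exact cuet.Linear constructor and the weight_classes / weight_indices keyword names are taken from the comment above and may differ in the real cuequivariance-torch 0.6.0 API.

import cuequivariance as cue
import cuequivariance_torch as cuet
import torch

num_elements = 10  # hypothetical number of weight classes (e.g. chemical elements)
irreps = cue.Irreps("O3", "0e + 1o")

# Declare the number of weight classes at construction time instead of
# contracting the node features with a one-hot element encoding.
linear = cuet.Linear(
    irreps,
    irreps,
    layout=cue.ir_mul,
    weight_classes=num_elements,  # assumed kwarg, named after the comment
    method="naive",               # or the experimental "indexed_linear"
)

node_feats = torch.randn(64, irreps.dim)
node_element = torch.randint(0, num_elements, (64,))

# Pass per-node indices in the forward so the weights are gathered by index
# rather than selected via a one-hot contraction.
out = linear(node_feats, weight_indices=node_element)  # assumed kwarg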

ilyes319 (Contributor)

@mariogeiger Yes thank you!!!

hatemhelal (Collaborator, Author) commented Sep 15, 2025

The latest commit on this PR works in my limited testing, but it really just narrows the scope of the symbolic-tracing workaround.

@mariogeiger, it would be good to have your take on this: the cueq torch ops perform a number of runtime checks on input tensors. This is obviously great for developer UX, but it breaks the symbolic tracing currently used in MACE, since raising runtime errors is a side effect that is incompatible with the static control-flow requirement of symbolic tracing. A few choices are possible here:

  • Make the runtime checks optional: by default they would be on, but compilation in MACE could toggle them off (see the sketch after this list).
  • Functionalize the runtime checks (something like JAX's checkify), though I'm unsure how well, or even whether, that is supported in the torch compiler stack.
  • Figure out a path to not using symbolic tracing in MACE.
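
A minimal sketch of what the first option could look like (the flag name and the check below are hypothetical, not an existing cueq API):

import torch

# Hypothetical module-level switch; a compilation entry point in MACE could
# flip this off before tracing or compiling the model.
RUNTIME_CHECKS_ENABLED = True

def check_inputs(x: torch.Tensor) -> None:
    # Illustrative runtime check that can be disabled for compilation.
    if not RUNTIME_CHECKS_ENABLED:
        return
    if x.dim() != 2:
        raise ValueError(f"expected a 2D tensor, got {x.dim()} dimensions")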

Keen to hear thoughts or suggestions!

mariogeiger (Contributor)

Since you are using the argument "method=...", you should probably update the versions in your setup.cfg:

cueq = cuequivariance-torch>=0.2.0
cueq-cuda-11 = cuequivariance-ops-torch-cu11>=0.2.0
cueq-cuda-12 = cuequivariance-ops-torch-cu12>=0.2.0

->

cueq = cuequivariance-torch>=0.6.0
cueq-cuda-11 = cuequivariance-ops-torch-cu11>=0.6.0
cueq-cuda-12 = cuequivariance-ops-torch-cu12>=0.6.0

phiandark

Hi, we were discussing this: in the past we guarded some of these checks with a wrapper like

import torch

# Skip the runtime input checks whenever the model is being scripted,
# traced, or compiled.
if (
    not torch.jit.is_scripting()
    and not torch.jit.is_tracing()
    and not torch.compiler.is_compiling()
):
    ...  # runtime checks on the input tensors

Would something like this, with the appropriate mode, work to disable the checks for compiled models?

Also, can you tell me exactly what kind of compilation command you're trying to enable here? We're testing some, like torch.compile and torch.jit.script, but maybe we're missing this one, or the right flag to catch the failure.

hatemhelal (Collaborator, Author)

> Also, can you tell me exactly what kind of compilation command you're trying to enable here? We're testing some, like torch.compile and torch.jit.script, but maybe we're missing this one, or the right flag to catch the failure.

At the moment MACE is first passed through torch.fx.symbolic_trace (see here), which uses abstract tensor types to capture the compute graph.
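
For context, a minimal, self-contained illustration of what torch.fx symbolic tracing does (not MACE-specific): the module is executed with Proxy objects rather than real tensors, so any check that needs concrete tensor values breaks the trace.

import torch
import torch.fx

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 4)

    def forward(self, x):
        # During symbolic tracing, x is a torch.fx.Proxy, not a real tensor;
        # a value-dependent check such as `if x.sum() > 0:` would fail here.
        return torch.relu(self.lin(x))

gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)  # the captured compute graph, independent of any input data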

ilyes319 (Contributor)

Hey @hatemhelal, curious what the status of the PR is. Should it be merged, or is there anything else to work on?

I can think of the following points for using this most effectively in MD:

  • Do a bit of padding in the calculators (ASE and MLIAP), creating fake structures in the batch, to limit recompilation (see the sketch after this list).
  • Look in detail at the best compilation modes so we can recommend the best settings to people.
  • Do some benchmarking in MLIAP with different system sizes for MD.
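
A minimal sketch of the padding idea, under the assumption that atom counts are padded up to fixed bucket sizes so torch.compile only sees a few distinct shapes (the helper name and bucket size are hypothetical, and a real calculator would also need to pad or mask the other batch fields):

import math
import torch

def pad_positions(positions: torch.Tensor, bucket: int = 32) -> torch.Tensor:
    # Pad an (n_atoms, 3) positions tensor with zero rows up to the next
    # multiple of `bucket` to limit the number of recompilations.
    n_atoms = positions.shape[0]
    target = math.ceil(n_atoms / bucket) * bucket
    if target == n_atoms:
        return positions
    padding = positions.new_zeros(target - n_atoms, 3)
    return torch.cat([positions, padding], dim=0)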

hatemhelal (Collaborator, Author)

@ilyes319 I think this should be OK to merge; the follow-on discussion is a tangential direction for making this more robust. Also recall that @ThomasWarford was looking to add benchmarks for this in #1184.

I agree with all the follow-on points you suggest, I'm just not sure when I might get to do them, so I can't commit at the moment!

ilyes319 merged commit aa9b6e5 into ACEsuit:develop Oct 28, 2025
8 checks passed
ilyes319 (Contributor)

OK, I merged it for now. @ThomasWarford, tell me if you are up for trying to implement these.

ThomasWarford (Contributor) commented Oct 31, 2025

I tried implementing it today and I get errors running test_compile.py.

The first was fixed by loading gcc (via module load gcc).

Here's the next error:

============================= test session starts ==============================
platform linux -- Python 3.12.0, pytest-8.4.2, pluggy-1.6.0 -- /home/s5f/twarf.s5f/.local/share/mamba/envs/mace/bin/python3.12
cachedir: .pytest_cache
benchmark: 5.2.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace
configfile: pyproject.toml
plugins: benchmark-5.2.0
collecting ... collected 22 items

../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_mace[fp32-cpu] PASSED [  4%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_mace[fp32-cuda] PASSED [  9%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_mace[fp64-cpu] PASSED [ 13%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_mace[fp64-cuda] PASSED [ 18%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_eager_benchmark[fp32-True] PASSED [ 22%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_eager_benchmark[fp32-False] PASSED [ 27%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_eager_benchmark[fp64-True] PASSED [ 31%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_eager_benchmark[fp64-False] PASSED [ 36%]
../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_compile_benchmark[False-fp32-default] FAILED [ 40%]

=================================== FAILURES ===================================
__________________ test_compile_benchmark[False-fp32-default] __________________

benchmark = <pytest_benchmark.fixture.BenchmarkFixture object at 0x400c0a8cb3e0>
compile_mode = 'default', enable_amp = False, enable_cueq = False

    @pytest.mark.skipif(os.name == "nt", reason="Not supported on Windows")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda is not available")
    @pytest.mark.parametrize("compile_mode", ["default", "reduce-overhead", "max-autotune"])
    @pytest.mark.parametrize("enable_amp", [False, True], ids=["fp32", "mixed"])
    @pytest.mark.parametrize("enable_cueq", [False, True])
    def test_compile_benchmark(benchmark, compile_mode, enable_amp, enable_cueq):
        with tools.torch_tools.default_dtype(torch.float32):
            batch = create_batch("cuda")
            torch.compiler.reset()
            model = mace_compile.prepare(create_mace)("cuda", enable_cueq=enable_cueq)
            model = torch.compile(model, mode=compile_mode)
            model = time_func(model)
    
            with torch.autocast("cuda", enabled=enable_amp):
>               benchmark(model, batch, training=True)

/scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py:155: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/pytest_benchmark/fixture.py:179: in __call__
    return self._raw(function_to_benchmark, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/pytest_benchmark/fixture.py:211: in _raw
    duration, iterations, loops_range = self._calibrate_timer(runner)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/pytest_benchmark/fixture.py:365: in _calibrate_timer
    duration = runner(loops_range)
               ^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/pytest_benchmark/fixture.py:133: in runner
    function_to_benchmark(*args, **kwargs)
/scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py:95: in wrapper
    outputs = func(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:414: in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:832: in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/nn/modules/module.py:1775: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/nn/modules/module.py:1786: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1874: in __call__
    result = self._torchdynamo_orig_backend(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1624: in __call__
    result = self._inner_convert(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:688: in __call__
    result = _compile(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1433: in _compile
    guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_utils_internal.py:92: in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1117: in compile_inner
    return _compile_inner(code, one_graph, hooks)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1151: in _compile_inner
    dynamo_output = compile_frame(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1032: in compile_frame
    bytecode, tracer_output = transform_code_object(code, transform)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1592: in transform_code_object
    tracer_output = transformations(instructions, code_options)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1004: in transform
    tracer_output = trace_frame(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:312: in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:815: in trace_frame
    run_tracer()
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:797: in run_tracer
    tracer.run()
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1487: in run
    while self.step():
          ^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1348: in step
    self.dispatch_table[inst.opcode](self, inst)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:904: in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3411: in CALL
    self._call(inst)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3405: in _call
    self.call_function(fn, args, kwargs)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1266: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:212: in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:598: in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:342: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1288: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:4112: in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:4315: in inline_call_
    self.run()
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1487: in run
    while self.step():
          ^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1348: in step
    self.dispatch_table[inst.opcode](self, inst)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:904: in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3411: in CALL
    self._call(inst)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3405: in _call
    self.call_function(fn, args, kwargs)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1266: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:212: in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:598: in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:342: in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1288: in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:4112: in inline_call
    return tracer.inline_call_()
           ^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:4315: in inline_call_
    self.run()
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1487: in run
    while self.step():
          ^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1348: in step
    self.dispatch_table[inst.opcode](self, inst)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:904: in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3411: in CALL
    self._call(inst)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3405: in _call
    self.call_function(fn, args, kwargs)
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1266: in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:212: in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py:1516: in call_function
    tensor_variable = wrap_fx_proxy(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2645: in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2711: in wrap_fx_proxy_cls
    return _wrap_fx_proxy(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2809: in _wrap_fx_proxy
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/utils.py:3478: in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/utils.py:3376: in get_fake_value
    ret_val = wrap_fake_exception(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/utils.py:2864: in wrap_fake_exception
    return fn()
           ^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/utils.py:3377: in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/utils.py:3587: in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_dynamo/utils.py:3546: in run_node
    return node.target(*args, **kwargs)  # type: ignore[operator]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/autograd/__init__.py:503: in grad
    result = _engine_run_backward(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

t_outputs = (FakeTensor(..., device='cuda:0', size=(1,), grad_fn=<ScatterAddBackward0>),)
args = ((FakeTensor(..., device='cuda:0', size=(1,)),), True, True, (FakeTensor(..., device='cuda:0', size=(64, 3)),), True)
kwargs = {'accumulate_grad': False}, attach_logging_hooks = False

    def _engine_run_backward(
        t_outputs: Sequence[Union[torch.Tensor, GradientEdge]],
        *args: Any,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, ...]:
        attach_logging_hooks = log.getEffectiveLevel() <= logging.DEBUG
        if attach_logging_hooks:
            unregister_hooks = _register_logging_hooks_on_whole_graph(t_outputs)
        try:
>           return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
                t_outputs, *args, **kwargs
            )  # Calls into the C++ engine to run the backward pass
E           torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <function grad at 0x4004798f19e0>(*(), **{'outputs': [FakeTensor(..., device='cuda:0', size=(1,), grad_fn=<ScatterAddBackward0>)], 'inputs': [FakeTensor(..., device='cuda:0', size=(64, 3))], 'grad_outputs': [FakeTensor(..., device='cuda:0', size=(1,))], 'retain_graph': True, 'create_graph': True, 'allow_unused': True}): got RuntimeError('One of the differentiated Tensors does not require grad')
E           
E           from user code:
E              File "/scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/mace/modules/models.py", line 574, in forward
E               forces, virials, stress, hessian, edge_forces = get_outputs(
E             File "/scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/mace/modules/utils.py", line 195, in get_outputs
E               compute_forces(
E             File "/scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/mace/modules/utils.py", line 26, in compute_forces
E               gradient = torch.autograd.grad(

/home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/autograd/graph.py:841: TorchRuntimeError
----------------------------- Captured stdout call -----------------------------
Number of atoms 64
=============================== warnings summary ===============================
../../../../../../../../../home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/e3nn/o3/_wigner.py:10
  /home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/e3nn/o3/_wigner.py:10: UserWarning: Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.
    _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))

tests/test_compile.py: 84 warnings
  /home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
    warnings.warn(

tests/test_compile.py::test_mace[fp32-cuda]
  /home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/backends/cuda/__init__.py:131: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)
    return torch._C._get_cublas_allow_tf32()

tests/test_compile.py::test_mace[fp32-cuda]
  /home/s5f/twarf.s5f/.local/share/mamba/envs/mace/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:312: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html

----------------------------------------------------------------------------------------- benchmark: 4 tests ----------------------------------------------------------------------------------------
Name (time in ms)                        Min                Max               Mean             StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_eager_benchmark[fp32-False]     26.3279 (1.0)      83.6342 (2.44)     32.2343 (1.0)      15.7875 (29.95)    26.8318 (1.0)      0.9573 (1.74)          2;2  31.0229 (1.0)          19           1
test_eager_benchmark[fp64-False]     26.5528 (1.01)     84.0796 (2.45)     32.8175 (1.02)     16.2634 (30.85)    27.3286 (1.02)     1.1425 (2.08)          2;2  30.4715 (0.98)         18           1
test_eager_benchmark[fp32-True]      33.1412 (1.26)     34.3125 (1.0)      33.6399 (1.04)      0.5272 (1.0)      33.4487 (1.25)     0.9415 (1.71)          1;0  29.7266 (0.96)          5           1
test_eager_benchmark[fp64-True]      33.5426 (1.27)     35.2076 (1.03)     33.9590 (1.05)      0.7035 (1.33)     33.6642 (1.25)     0.5501 (1.0)           1;1  29.4473 (0.95)          5           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
=========================== short test summary info ============================
FAILED ../../../../../../../../../scratch/s5f/twarf.s5f/old_home/twarf.s5f/maces/mace/tests/test_compile.py::test_compile_benchmark[False-fp32-default]
!!!!!!!!!!!!!!!!!!!!!!!!!! stopping after 1 failures !!!!!!!!!!!!!!!!!!!!!!!!!!!
============= 1 failed, 8 passed, 87 warnings in 276.43s (0:04:36) =============

I believe #1184 should work if this is resolved, but I do still get another error with keys (which I posted there), rather than this error.

This is Python 3.12, CUDA 12.6, PyTorch 2.9.0+cu126.
