Benchmarking Backends for Reinforcement Learning: Flax.NNX (JAX) vs Flax.Linen (JAX) vs PyTorch
In this project I present high-quality implementations of Proximal Policy Optimization (PPO) in multiple frameworks, using GPU environments implemented in JAX from the Gymnax repository.
I compare the performance of the frameworks and analyse where the differences come from.
The purpose of this repository is to inform the choice of framework for new RL projects, including my own. It is also a good starting point if you want to start a project in a framework you haven't used before, because you can compare its code to a framework you might be more familiar with.
Overall, Linen is the fastest, NNX is marginally slower, and Torch is much slower. In all cases, the differences are less pronounced with larger models and more complex environments. Note that even the most complex environments in this experiment are still small and fast. So if an RL application requires large models and an environment that takes a lot of time to compute (even more so for CPU environments), the choice of framework should matter less. Also, NNX differs from Linen mainly in its larger overheads; beyond those, their performance is almost identical for most of training. NNX and Linen should therefore perform almost identically in practical applications.
- Overall metrics:
  - For small models, Linen is 1.2x faster than NNX and 6x faster than Torch
  - For large models, Linen is 1.05x faster than NNX and 1.7x faster than Torch
  - If we ignore overheads, Linen is only 1.02x faster than NNX
- Ease of use:
  - Torch is the easiest to use: good documentation and easy-to-understand code.
  - Linen is the hardest to use: decent documentation, but the hardest to code in. Anything that needs to be efficient and compiled with jax.jit must be purely functional and stateless. It also requires specific patterns, such as keeping the model parameters separate from the model structure, and JAX control flow primitives (e.g. jax.lax.cond, jax.lax.scan) instead of native Python if/for control flow to achieve GPU efficiency. Although this is more difficult to code in, it's necessary for efficiency, and I like it.
  - NNX allows a nicer stateful, object-oriented approach than Linen. It still requires special control flow like Linen, but that's necessary for efficiency. However, NNX is still in beta, and that really shows in its documentation, which is poor and has gaps.
Out of these frameworks, I am tempted to use NNX in my next personal project: as it matures, it has the best potential in my eyes. However, I want to add a comparison to Equinox before making a decision.
Note that Linen and NNX have somewhat of an unfair advantage in this comparison, because the environments are implemented in JAX (Gymnax) and run on the GPU. This means that we need to transfer tensors from JAX to Torch using DLPack, which could contribute to Torch's slowdown. However, if we used a CPU environment instead, its overhead would most likely completely overshadow the performance differences between these frameworks. In other words, implementing environments that can run on the GPU is likely even more important than the choice of framework.
The charts visualise the distribution of the results alongside a scatter plot of the individual runs.
Each scatter plot has a small x axis that shows the environment's complexity. Complexity is measured by benchmarking the time it takes to perform a fixed number of environment steps. Each environment's step time is then divided by that of the fastest environment, giving 1x for the fastest environment and 30.4x for the slowest.
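A minimal sketch of this normalisation, assuming the benchmarked step times have already been collected into a dictionary. The function name and the environment names and timings are hypothetical placeholders for illustration, not measurements from this benchmark.

```python
def relative_complexity(step_times: dict[str, float]) -> dict[str, float]:
    """Normalise benchmarked step times so that the fastest environment is 1x.

    step_times maps an environment name to the wall-clock time (seconds)
    it took to run a fixed number of environment steps.
    """
    fastest = min(step_times.values())
    return {name: t / fastest for name, t in step_times.items()}


# Hypothetical timings for illustration only (not results from this benchmark).
timings = {"env_a": 0.5, "env_b": 1.0, "env_c": 15.2}
print(relative_complexity(timings))  # {'env_a': 1.0, 'env_b': 2.0, 'env_c': 30.4}
```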
Model size is distinguished by marker shape and colour shade, and the sizes are visually separated into groups along the x axis (s/m/l).
The distribution is calculated across all the environments and all the model sizes shown in the scatter.
In the first experiment we time the total duration of training. We vary the model size (small, medium, large model) and the environment that we are training on. See the Method section for details of the run.

Key observations from the experiment:
- Overall framework performance (quantified in the next chart):
  - Linen is the fastest
  - NNX is marginally slower
  - Torch is much slower
- Model size influence:
  - Positive correlation between model size and total duration: larger models take longer to run.
  - Torch exhibits a smaller proportional increase in total duration as model size grows compared to Linen/NNX. This suggests the presence of larger fixed overheads in Torch.
- Environment complexity influence:
  - Positive correlation between environment complexity and total duration: complex environments take longer to run.
  - Torch exhibits a smaller proportional increase in total duration as environment complexity grows compared to Linen/NNX. This also suggests the presence of larger fixed overheads in Torch.
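The link between a large fixed overhead and a small proportional increase can be illustrated with a toy cost model, total = fixed overhead + cost per unit of work × work. This is a simplification assumed for illustration, and the overhead and cost values below are hypothetical, not fitted to the measurements.

```python
def total_duration(overhead: float, cost_per_unit: float, work: float) -> float:
    """Toy cost model: a fixed overhead plus time proportional to the work."""
    return overhead + cost_per_unit * work


def proportional_increase(overhead: float, cost_per_unit: float,
                          small_work: float, large_work: float) -> float:
    """Ratio of the total duration for a large workload vs. a small one."""
    return (total_duration(overhead, cost_per_unit, large_work)
            / total_duration(overhead, cost_per_unit, small_work))


# Hypothetical numbers: both frameworks have the same cost per unit of work,
# but the second one pays a ten times larger fixed overhead.
low_overhead = proportional_increase(1.0, 1.0, small_work=1.0, large_work=4.0)
high_overhead = proportional_increase(10.0, 1.0, small_work=1.0, large_work=4.0)
print(low_overhead, high_overhead)  # 2.5 vs. about 1.27

# The same toy model shows why the speedup between the two shrinks as work grows:
for work in (1.0, 4.0, 100.0):
    speedup = total_duration(10.0, 1.0, work) / total_duration(1.0, 1.0, work)
    print(work, round(speedup, 2))  # falls from 5.5 to 2.8 to about 1.09
```

Under this model, the framework with the larger overhead grows less in relative terms, and its gap to the faster framework closes as the shared work dominates.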
To quantify differences between frameworks, we plot the speedup when moving from one framework to another. For each model size and environment, we compute speedups as the ratio of total durations for all pairwise combinations of individual runs between the two frameworks (3x3 = 9 comparisons, since each configuration is repeated three times).
To quantify the values in this chart, we calculate the geometric mean and median speedup for each model size and framework pair:
| Geometric Mean Speedup | Small | Medium | Large |
|---|---|---|---|
| NNX → Linen | 1.25x | 1.21x | 1.05x |
| Torch → Linen | 5.88x | 4.62x | 1.69x |
| Torch → NNX | 4.70x | 3.81x | 1.60x |

| Median Speedup | Small | Medium | Large |
|---|---|---|---|
| NNX → Linen | 1.20x | 1.16x | 1.04x |
| Torch → Linen | 6.11x | 4.77x | 1.69x |
| Torch → NNX | 5.07x | 4.11x | 1.62x |
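The speedup summary can be sketched as follows, assuming the pairwise ratios for one model size are pooled across environments before the geometric mean and median are taken. The function names and the durations are hypothetical, chosen only to make the sketch runnable.

```python
import itertools
import statistics


def pairwise_speedups(slow_runs: list[float], fast_runs: list[float]) -> list[float]:
    """Speedups for every pairing of a run from each framework.

    Each value is duration(slow framework) / duration(fast framework), so values
    above 1 mean the second framework is faster. With three repeats per
    framework this gives 3 x 3 = 9 ratios.
    """
    return [slow / fast for slow, fast in itertools.product(slow_runs, fast_runs)]


def summarise(speedups: list[float]) -> tuple[float, float]:
    """Return the (geometric mean, median) of a list of speedup ratios."""
    return statistics.geometric_mean(speedups), statistics.median(speedups)


# Hypothetical total durations (seconds) for one model size, three repeats each.
torch_runs = {"env_a": [60.0, 62.0, 61.0], "env_b": [90.0, 88.0, 91.0]}
linen_runs = {"env_a": [10.0, 11.0, 10.5], "env_b": [30.0, 29.0, 31.0]}

pooled: list[float] = []
for env in torch_runs:
    pooled.extend(pairwise_speedups(torch_runs[env], linen_runs[env]))

geo_mean, median = summarise(pooled)
print(f"Torch -> Linen: geometric mean {geo_mean:.2f}x, median {median:.2f}x")
```

The geometric mean suits ratios: the geometric mean of B/A is exactly the reciprocal of that of A/B, so the summary does not depend on which framework is placed in the denominator.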
Key observations:
- Model size influence:
  - Larger models lead to a smaller speedup (closer to 1x), i.e. more similar performance across all frameworks. This suggests that all frameworks perform the model computations themselves with similar efficiency, and that the differences in framework performance come from elsewhere.
- Environment complexity influence:
  - For NNX → Linen, we observe a larger speedup with more complex environments.
  - For Torch → Linen and Torch → NNX, we observe a smaller speedup with more complex environments. Moreover, on large models, the effect of environment complexity is weaker.
  - This inconsistency suggests that the source of the speedup differs between framework pairs. We explore this further in Experiment 2.
Let us examine the overhead in the runs from Experiment 1. We compare the duration of the first iteration (index 0) with the average duration of the next six iterations (indices 1:7) and the average duration of the remaining iterations (indices 7:100).
To better inspect the overhead, let us also plot the difference between the iteration durations for all pairwise combinations of runs with the same model size and environment. We draw this on a symlog axis that is linear for absolute values below 0.1 and logarithmic above.
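A minimal sketch of the phase split and the pairwise differences, assuming each run is stored as a list of per-iteration durations in seconds. The function names and the example durations are hypothetical, not taken from the repository.

```python
import itertools
import statistics


def overhead_phases(durations: list[float]) -> dict[str, float]:
    """Split per-iteration durations (seconds) into three phases.

    "first":  iteration 0, which includes compilation and warm-up
    "warmup": mean of iterations 1..6 (Python slice 1:7)
    "steady": mean of iterations 7 to the end (slice 7:, i.e. 7:100 for 100 iterations)
    """
    if len(durations) < 8:
        raise ValueError("need at least 8 iterations to fill all three phases")
    return {
        "first": durations[0],
        "warmup": statistics.mean(durations[1:7]),
        "steady": statistics.mean(durations[7:]),
    }


def pairwise_phase_differences(runs_a: list[list[float]],
                               runs_b: list[list[float]],
                               phase: str) -> list[float]:
    """Phase duration of framework A minus framework B, for every pair of runs.

    runs_a and runs_b hold repeated runs with the same model size and environment.
    """
    phases_a = [overhead_phases(run)[phase] for run in runs_a]
    phases_b = [overhead_phases(run)[phase] for run in runs_b]
    return [a - b for a, b in itertools.product(phases_a, phases_b)]


# Hypothetical durations: a slow first iteration, a short warm-up, then steady state.
run_a = [5.0] + [0.3] * 6 + [0.2] * 93
run_b = [1.0] + [0.2] * 99
print(overhead_phases(run_a))
print(pairwise_phase_differences([run_a, run_a], [run_b], "first"))  # [4.0, 4.0]
```

If the plots are drawn with matplotlib, recent versions express the axis described above as ax.set_yscale("symlog", linthresh=0.1); whether the original charts were produced this way is an assumption.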
Key observations:
- Initial overhead in iteration 0:
  - The initial overhead is similarly large for Linen and NNX.
  - Torch has the smallest overhead for most environments. This could be because jit compilation takes a long time for Linen and NNX.
  - The runs with the two most complex environments (Freeway-MinAtar and Asterix-MinAtar) have a much larger overhead than any other runs, across all frameworks. There must therefore be an environment-related overhead, which could also stem from the jax.jit compilation used within the Gymnax environments.
  - Larger models have a slightly larger overhead.
- Secondary overhead in iterations 1:7:
  - Linen has zero secondary overhead for runs with some environments, and a small overhead for others. The secondary overheads in Linen therefore appear to be exclusively environment related.
  - NNX has secondary overheads that grow with model size and environment complexity.
  - Torch has no secondary overhead: its iteration time stabilises immediately from the second iteration.
Let us also examine the relative speedup between frameworks using only the durations of iterations 7:100.
We see the same pattern as in Experiment 1, with one notable exception: NNX and Linen have almost identical performance. This means that NNX differs from Linen only in its larger overhead and performs almost identically afterwards. As training runs get longer, which is likely in practical applications, the difference between NNX and Linen should therefore shrink.
Next, let us inspect the durations of the rollout step and the update step within a single iteration separately. Rollout is heavy on environment computation and light on model usage, while update doesn't use the environment at all and performs a lot of model computation.
To time these functions properly, we must force a synchronisation around their execution to ensure that they have finished computing when the time is taken. This can reduce overall efficiency; however, it does not interrupt operations within jitted or otherwise compiled functions, only around them.
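A framework-agnostic sketch of timing with synchronisation, assuming a caller-supplied sync function that blocks until all queued device work has finished. The name timed and its sync parameter are hypothetical, not the repository's code. With JAX, sync could be jax.block_until_ready; with PyTorch on CUDA, something like lambda _: torch.cuda.synchronize().

```python
import time
from typing import Any, Callable


def timed(fn: Callable[..., Any], *args: Any,
          sync: Callable[[Any], Any] = lambda tree: tree) -> tuple[Any, float]:
    """Call fn(*args) and return (output, elapsed wall-clock seconds).

    Both JAX and PyTorch dispatch GPU work asynchronously, so without `sync`
    the timer would only measure how long it took to queue the work.
    `sync` is called on the inputs before starting the clock (so earlier
    queued work is not counted) and on the output before stopping it.
    The default is a no-op, which is only correct for synchronous code.
    """
    sync(args)
    start = time.perf_counter()
    out = fn(*args)
    sync(out)
    return out, time.perf_counter() - start


# Usage with a hypothetical stand-in for a rollout step (plain Python, no GPU).
def rollout_stub(n: int) -> int:
    return sum(range(n))


result, seconds = timed(rollout_stub, 1_000_000)
print(result, f"{seconds:.4f}s")
```

The synchronisation only happens at the boundaries of the timed call, which matches the point above: work inside a compiled function is not interrupted.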
In this experiment, we run all environments only with the small model, while the medium and large models are run only with the Acrobot-v1 environment.
First, let us plot the average rollout and update duration from iterations 7:100.
Then let's compare the speedup between frameworks.
Key observations:
- Environment influence:
  - More complex environments lead to a larger rollout duration across all frameworks.
  - More complex environments lead to larger update durations only for some environments: the 4 most complex ones, which are the MinAtar (miniature Atari) environments. These have a much larger observation space (several hundred features instead of 6 or fewer), so their models have larger input layers, which increases the model's overall computational cost.
- Model size influence:
  - The model size has a much larger effect on update duration than on rollout duration, as expected.
- Framework comparison:
  - NNX → Linen: NNX is slightly slower than Linen in rollout duration, but slightly faster in update duration. This effect is very small, though; Linen and NNX perform almost identically.
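A worked example of why a larger observation space increases model cost: a dense layer has in_dim × out_dim weights plus out_dim biases, so the input layer grows linearly with the number of observation features. The hidden sizes, action count, and the 400-feature input (e.g. a flattened 10x10x4 grid) are hypothetical values chosen for illustration, not the benchmark's configuration.

```python
def mlp_param_count(layer_sizes: list[int]) -> int:
    """Number of weights and biases in a fully connected network.

    layer_sizes = [input_dim, hidden_1, ..., output_dim]; each dense layer has
    in_dim * out_dim weights plus out_dim biases.
    """
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))


hidden = [64, 64]  # hypothetical hidden sizes
n_actions = 3      # hypothetical number of discrete actions

small_obs = mlp_param_count([6, *hidden, n_actions])    # 6 observation features
large_obs = mlp_param_count([400, *hidden, n_actions])  # several hundred features
print(small_obs, large_obs)  # 4803 30019
```

With these assumed sizes the same network is roughly six times larger for the several-hundred-feature input, almost entirely due to the first layer.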
Lastly, let us examine how the different compilation methods affect each framework by turning the compilation off. Since the jit operations are applied directly to the rollout and update functions, we examine their average durations for iterations 7:100.
- For Linen, we use jax.jit for the entire rollout and update functions, which include the Gymnax environment computations. When it is turned off, Gymnax still automatically applies some jax.jit compilation to the environment.
- For NNX, the rollout and update steps are compiled using nnx.jit and nnx.cached_partial.
- For Torch, the model is compiled using torch.compile, and the environment is explicitly wrapped in jax.jit.
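One way to structure such on/off configurations is a small wrapper that applies a compiler only when a flag is set. This is a hypothetical sketch rather than the repository's code: compile_fn stands for whichever compiler a framework uses (for example jax.jit, nnx.jit or torch.compile), and a stand-in compiler is used so the sketch runs without JAX or PyTorch. JAX also offers a global switch, the jax.disable_jit() context manager, which could be used instead.

```python
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def maybe_compile(fn: F, enabled: bool, compile_fn: Callable[[F], F]) -> F:
    """Return compile_fn(fn) when enabled, otherwise fn unchanged."""
    return compile_fn(fn) if enabled else fn


def fake_compiler(fn):
    """Hypothetical stand-in for a real compiler: wraps fn and marks it as compiled."""
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapper.compiled = True
    return wrapper


def update_step(x: float) -> float:
    # Hypothetical stand-in for an update step.
    return x + 1.0


compiled_update = maybe_compile(update_step, enabled=True, compile_fn=fake_compiler)
eager_update = maybe_compile(update_step, enabled=False, compile_fn=fake_compiler)
print(compiled_update(1.0), eager_update is update_step)  # 2.0 True
```

Keeping the flag outside the step functions means the same rollout and update code can be benchmarked with and without compilation.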
Key observations:
- Torch is mostly unaffected by compilation for both rollout and update durations. There is only a slight increase in rollout duration when the environment's jax.jit compilation is turned off.
- For NNX, turning off nnx.cached_partial slightly increases both rollout and update durations.
- For both NNX and Linen, turning off jit makes performance drop massively. Both frameworks then behave almost identically, and both are slower than Torch (with the exception of a single environment).
This suggests that the speedup of Linen and NNX over Torch is indeed due to jit compilation, which lets XLA fuse many small GPU operations into a few larger kernels instead of launching each one separately. All of this benefit is lost when jit is not used. Moreover, since these frameworks are not designed to be used without jit, their performance then drops even below Torch.
All previous experiments have been run on Linux (Ubuntu), with the Nvidia 508.126.09 driver inside Docker containers.
I also ran the same experiments on Windows, inside Docker containers running in WSL2 (although I only ran each experiment once instead of three times).
To my surprise, performance on Windows was much worse than on native Linux. The full set of Windows plots can be found in the results directory, while the following two plots compare the overall performance on Experiment 1:
To quantify the comparison between the operating systems:
- We can see that Windows is slower across all frameworks, but the slowdown is stronger for Torch than it is for Linen and NNX.
- Runs with larger models are affected less by the OS across all frameworks.
- The OS slowdown varies with environments without a clear trend.
- Implement PPO in PyTorch
- Implement PPO in Jax, Flax.Linen
- Implement PPO in Jax, Flax.NNX
- Implement entrypoints and benchmarking experiments
- Test and debug everything
- Finalise Documentation and this readme
- Run all benchmarks
- Analyse results and present findings in readme
- Also compare to Equinox
- Also compare to CPU environments.
- Also compare to large and expensive GPU environments.
- Also compare to PyTorch jit








