High-performance molecular active learning with JAX. Built with Flax NNX (the modern Flax API) and jraph for efficient graph batching, achieving a ~400x speedup over naive implementations.
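jraph represents a batch of molecules as one large graph made of disconnected pieces. Node arrays are concatenated, each graph's `senders` and `receivers` indices are shifted by the number of nodes that come before it, and per-graph node and edge counts are recorded so the pieces can be separated again. The NumPy sketch below mimics that idea with plain dicts. `batch_graphs` is a hypothetical helper for illustration, not jraph's implementation: the real `jraph.batch` operates on `jraph.GraphsTuple` and also handles edge and global features.

```python
import numpy as np


def batch_graphs(graphs):
    """Merge a list of graphs into one disjoint-union graph.

    Hypothetical helper mirroring the idea behind jraph.batch (not its code).
    Each graph is a dict with 'nodes' (n_i, F), 'senders' (e_i,), 'receivers' (e_i,).
    """
    nodes, senders, receivers, n_node, n_edge = [], [], [], [], []
    offset = 0  # number of nodes already placed in the batch
    for g in graphs:
        num_nodes = g["nodes"].shape[0]
        nodes.append(g["nodes"])
        # Shift edge endpoints so they index into the concatenated node array.
        senders.append(g["senders"] + offset)
        receivers.append(g["receivers"] + offset)
        n_node.append(num_nodes)
        n_edge.append(len(g["senders"]))
        offset += num_nodes
    return {
        "nodes": np.concatenate(nodes, axis=0),
        "senders": np.concatenate(senders),
        "receivers": np.concatenate(receivers),
        "n_node": np.array(n_node),  # lets per-graph readouts split the batch again
        "n_edge": np.array(n_edge),
    }


# Two toy molecules with 6 features per atom; edges are listed in both directions.
g1 = {"nodes": np.ones((2, 6)), "senders": np.array([0, 1]), "receivers": np.array([1, 0])}
g2 = {"nodes": np.zeros((3, 6)), "senders": np.array([0, 1, 1, 2]), "receivers": np.array([1, 0, 2, 1])}
batch = batch_graphs([g1, g2])
print(batch["senders"])  # [0 1 2 3 3 4]
print(batch["n_node"])   # [2 3]
```

Because the result is a single graph, one message-passing call processes every molecule in the batch, which is what makes batching worthwhile on accelerators.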
```bash
# Using uv (recommended)
git clone https://github.com/HFooladi/molax
cd molax
uv pip install -e .

# Or with pip
pip install -e .
```

```python
from molax.utils.data import MolecularDataset
from molax.models.gcn import GCNConfig, UncertaintyGCN
from flax import nnx
import jraph

# Load and batch data
dataset = MolecularDataset('datasets/esol.csv')
train_data, test_data = dataset.split(test_size=0.2, seed=42)
train_graphs = jraph.batch(train_data.graphs)

# Create model with uncertainty
config = GCNConfig(node_features=6, hidden_features=[64, 64], out_features=1)
model = UncertaintyGCN(config, rngs=nnx.Rngs(0))

# Get predictions with uncertainty
mean, variance = model(train_graphs, training=True)
```

See the Core Concepts guide for the batch-once-then-mask pattern that enables the ~400x speedup.
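One way to read the batch-once-then-mask idea, sketched here under our own assumptions rather than taken from molax: batch the whole molecule pool into a single graph once, compute predictions for all of it, and restrict the training loss to the labeled molecules with a boolean mask. Because every active-learning round then passes arrays of the same shape, a `jax.jit`-compiled loss is traced once instead of once per round. The function name `masked_mse` and the toy arrays are hypothetical; the Core Concepts guide describes what molax actually does.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_mse(predictions, targets, labeled_mask):
    """Mean squared error over labeled molecules only (hypothetical helper).

    predictions, targets: shape (num_graphs,), computed for the whole batched pool.
    labeled_mask: bool array of shape (num_graphs,), True for labeled molecules.
    """
    sq_err = (predictions - targets) ** 2
    # Zero out unlabeled entries instead of slicing, so array shapes never change.
    sq_err = jnp.where(labeled_mask, sq_err, 0.0)
    # Guard against division by zero before anything has been labeled.
    return sq_err.sum() / jnp.maximum(labeled_mask.sum(), 1)


preds = jnp.array([1.0, 2.0, 3.0, 4.0])
targets = jnp.array([1.0, 0.0, 3.0, 0.0])
mask = jnp.array([True, True, False, False])
print(masked_mse(preds, targets, mask))  # 2.0

# Labeling another molecule only flips a mask entry; shapes are unchanged,
# so the jitted function is reused without recompiling.
mask = mask.at[2].set(True)
print(masked_mse(preds, targets, mask))
```

In a real loop, `predictions` would come from the model applied to the batched pool, as in the quick-start example; this sketch does not attempt to reproduce the reported speedup.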
- Multiple uncertainty methods: MC Dropout, Deep Ensembles, Evidential Deep Learning
- Calibration metrics: ECE, calibration curves, reliability diagrams
- Acquisition functions: Uncertainty sampling, diversity sampling, combined strategies
- GPU-accelerated: Full JAX/Flax NNX integration with JIT compilation
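As a rough illustration of the acquisition strategies listed above, here is a NumPy sketch of uncertainty sampling (take the molecules with the largest predicted variance), diversity sampling (greedy farthest-point selection over some molecule embedding), and one possible combination (shortlist the most uncertain candidates, then pick a diverse subset of them). The function names, the Euclidean distance, and the 3x shortlist default are assumptions for illustration, not molax's API.

```python
import numpy as np

# Hypothetical acquisition helpers; molax's own functions may differ.


def uncertainty_sampling(variance, k):
    """Return indices of the k pool molecules with the largest predicted variance."""
    return np.argsort(-variance)[:k]


def diversity_sampling(embeddings, k, candidates=None):
    """Greedy farthest-point (k-center) selection among `candidates`."""
    if candidates is None:
        candidates = np.arange(len(embeddings))
    candidates = np.asarray(candidates)
    if k <= 0 or len(candidates) == 0:
        return candidates[:0]
    points = embeddings[candidates]
    selected = [0]  # seed with the first candidate
    # Distance from each candidate to its nearest already-selected point.
    min_dist = np.linalg.norm(points - points[0], axis=1)
    min_dist[0] = -np.inf  # never pick the same candidate twice
    while len(selected) < min(k, len(candidates)):
        nxt = int(np.argmax(min_dist))
        selected.append(nxt)
        min_dist = np.minimum(min_dist, np.linalg.norm(points - points[nxt], axis=1))
        min_dist[nxt] = -np.inf
    return candidates[selected]


def combined_sampling(variance, embeddings, k, shortlist=None):
    """Shortlist the most uncertain molecules, then choose a diverse subset of them."""
    if shortlist is None:
        shortlist = 3 * k
    top = uncertainty_sampling(variance, shortlist)  # most uncertain first
    return diversity_sampling(embeddings, k, candidates=top)


# Toy pool: molecules 0 and 1 are near-duplicates in embedding space.
variance = np.array([0.9, 0.8, 0.1, 0.7, 0.05])
embeddings = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [4.0, 4.0], [9.0, 9.0]])
print(uncertainty_sampling(variance, 2))                          # [0 1]
print(combined_sampling(variance, embeddings, k=2, shortlist=3))  # [0 3]
```

On this toy pool, pure uncertainty sampling spends both picks on two near-identical molecules, while the combined strategy keeps the most uncertain one and replaces its twin with a distant but still uncertain candidate.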
```bibtex
@software{molax2025,
  title={molax: Molecular Active Learning with JAX},
  author={Hosein Fooladi},
  year={2025},
  url={https://github.com/hfooladi/molax}
}
```

MIT License