
molax


High-performance molecular active learning with JAX. Built on Flax NNX (the modern Flax API) and jraph for efficient graph batching, achieving a ~400x speedup over naive implementations.

Documentation | API Reference

Installation

# Using uv (recommended)
git clone https://github.com/HFooladi/molax
cd molax
uv pip install -e .

# Or with pip
pip install -e .

Quick Start

import jraph
from flax import nnx

from molax.models.gcn import GCNConfig, UncertaintyGCN
from molax.utils.data import MolecularDataset

# Load and batch data
dataset = MolecularDataset('datasets/esol.csv')
train_data, test_data = dataset.split(test_size=0.2, seed=42)
train_graphs = jraph.batch(train_data.graphs)

# Create model with uncertainty
config = GCNConfig(node_features=6, hidden_features=[64, 64], out_features=1)
model = UncertaintyGCN(config, rngs=nnx.Rngs(0))

# Get predictions with uncertainty
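# (training=True presumably keeps dropout active here, which MC-dropout-style
# uncertainty estimation relies on)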
mean, variance = model(train_graphs, training=True)

See the Core Concepts guide for the batch-once-then-mask pattern that enables the ~400x speedup.
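As a rough, hypothetical sketch of that pattern (jraph.batch is real; masked_mse and labeled below are names invented here, not molax's API): the pool is batched once into a single GraphsTuple with fixed array shapes, and each active-learning round selects the labeled subset with a boolean mask instead of re-batching, so JAX's JIT compiles the training step only once.

import jax.numpy as jnp
import jraph

# Batch the entire pool once; array shapes stay fixed across AL rounds.
pool_graphs = jraph.batch(train_data.graphs)

# Boolean per-graph mask marking which molecules are labeled so far
# (here: an arbitrary initial seed of 10 graphs).
labeled = jnp.zeros(len(train_data.graphs), dtype=bool).at[:10].set(True)

def masked_mse(preds, targets, mask):
    # Mean squared error over labeled graphs only: unlabeled graphs are
    # zeroed out, so the batch (and its compiled shapes) never changes.
    sq_err = (preds.squeeze() - targets) ** 2
    return jnp.where(mask, sq_err, 0.0).sum() / mask.sum()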

Features

  • Multiple uncertainty methods: MC Dropout, Deep Ensembles, Evidential Deep Learning
  • Calibration metrics: ECE, calibration curves, reliability diagrams (see the calibration sketch after this list)
  • Acquisition functions: Uncertainty sampling, diversity sampling, combined strategies (see the acquisition sketch after this list)
  • GPU-accelerated: Full JAX/Flax NNX integration with JIT compilation
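Minimal sketches of the calibration and acquisition bullets, assuming the Quick Start model; acquire_top_k and regression_ece are names invented here rather than molax's API. The first scores unlabeled graphs by predictive variance and returns the k most uncertain; the second measures regression calibration by comparing nominal coverage levels with observed CDF coverage (in the spirit of Kuleshov et al., 2018).

import jax.numpy as jnp
from jax.scipy.stats import norm

def acquire_top_k(model, pool_graphs, unlabeled_mask, k=10):
    # Uncertainty sampling: score each graph by predictive variance and
    # take the k highest-variance unlabeled candidates.
    _, variance = model(pool_graphs, training=True)
    scores = jnp.where(unlabeled_mask, variance.squeeze(), -jnp.inf)
    return jnp.argsort(scores)[-k:]

def regression_ece(mean, var, targets, n_bins=10):
    # Calibration error for Gaussian predictions: for each nominal level p,
    # compare p with the fraction of targets whose predictive CDF is <= p.
    cdf = norm.cdf(targets, loc=mean.squeeze(), scale=jnp.sqrt(var.squeeze()))
    levels = jnp.linspace(0.0, 1.0, n_bins + 1)[1:]
    observed = jnp.array([(cdf <= p).mean() for p in levels])
    return jnp.abs(observed - levels).mean()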

Citation

@software{molax2025,
  title={molax: Molecular Active Learning with JAX},
  author={Hosein Fooladi},
  year={2025},
  url={https://github.com/HFooladi/molax}
}

License

MIT License