CoLT: The conditional localization test for assessing the accuracy of neural posterior estimates
Tianyu Chen*, Vansh Bansal*, and James G. Scott
Abstract: We consider the problem of validating whether a neural posterior estimate is an accurate approximation of the true, unknown posterior.
Figure. Results on the toy tree-shaped example. As the perturbation level α increases, the distribution becomes blurrier and deviates from the true manifold shown in Panel A. CoLT with a learned metric embedding maintains strong statistical power even for modest perturbations (Panel B), whereas C2ST, SBC, and TARP all perform poorly even at much larger perturbations such as α = 4 (Panels C and D).
This repository contains code for CoLT diagnostics of neural posterior estimates, with comparisons to common baselines (TARP, SBC, C2ST). The main entrypoint train.py samples synthetic SBI problems, trains the CoLT localization components, and evaluates against baselines.
We provide a conda environment file:
conda env create -f environment.yml
conda activate colt

Core dependencies include PyTorch, NumPy/SciPy, Matplotlib, Click, geomloss (Sinkhorn), tqdm, and pyro-ppl (Student-t tails in data/approximations.py).
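If you would rather not use conda, the same dependencies can be installed with pip. The package list below is our reading of the dependency list above, not a pinned requirements file:

```bash
pip install torch numpy scipy matplotlib click geomloss tqdm pyro-ppl
```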
Minimal single-file demo (toy misspecification). All important components are included in this single file; if you just want to test CoLT on your own data, this is the only file you need to look at:
python minimal_colt_demo.py --n_dim 5 --n_sim 256 --n_posterior 64 --epochs 200 --dist sinkhorn --out_dir results_minimal_colt

Main experiment runner (train.py) example. These files replicate the main experiments in the paper, together with the other baseline methods.
python train.py \
--output_dir results \
--device cpu \
--epochs 200 \
--c2st_epochs 200 \
--n_sim 256 \
--n_posterior_per_x 64 \
--n_eval 20 \
--x_dim 10 \
--theta_dim 10 \
--approx_sampling_mathod perturbed_var \
--alpha 0.2

Sweep scripts for larger grids live in run_bash_curve/ and run_bash_non_curve/ (they assume multi-GPU setups and paths like /scratch/...; edit as needed). A hand-rolled alternative is sketched below.
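If you prefer not to adapt those scripts, a sweep over perturbation levels can be written as a small shell loop. This is an illustrative sketch using the flags documented above, not one of the provided scripts; the alpha grid and output layout are arbitrary choices:

```bash
# Hypothetical sweep over perturbation levels; adjust flags to your setup.
for alpha in 0.5 1 2 4; do
  python train.py \
    --output_dir "results/alpha_${alpha}" \
    --device cpu \
    --epochs 200 \
    --n_sim 256 \
    --approx_sampling_mathod perturbed_var \
    --alpha "${alpha}"
done
```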
- train.py: main experiment runner (samples data, trains CoLT and the baselines, evaluates).
- training/: CoLT implementation (loss, networks, training loop).
- baselines/: TARP, SBC, C2ST.
- data/: synthetic priors/posteriors/approximate posteriors, plus sampler utilities.
- minimal_colt_demo.py: minimal CoLT demo script.
Change the perturbation type by replacing the value according to the following table:
python train.py \
--approx_sampling_mathod perturbed_var

Table: Six types of perturbations used to assess the sensitivity of CoLT
| Explanation | approx_sampling_mathod |
|---|---|
| Mean Shift: introduces a systematic bias by shifting the mean. | perturb_mean |
| Covariance Scaling: uniformly inflates the variance. | perturb_var |
| Anisotropic Covariance Perturbation: adds variability along the minimum-variance eigenvector of the covariance. | distort_var |
| Tail Adjustment: replaces Gaussian tails with heavier Student-t tails. | tail |
| Additional Modes: adds spurious mixture components to the approximation. | mixture |
| Mode Collapse: collapses the approximation onto a subset of the posterior's modes. | collapse |
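For intuition, the first two perturbations amount to simple transformations of the posterior draws. The NumPy sketch below is purely illustrative and is not the repository's implementation (which lives in data/approximations.py); the array shapes and the exact role of alpha are assumptions:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical posterior draws: shape (n_posterior_per_x, theta_dim), matching the flags above.
samples = rng.standard_normal((64, 10))
alpha = 0.2  # perturbation level, analogous to --alpha

# Mean Shift (perturb_mean): add a systematic bias to every draw.
mean_shifted = samples + alpha

# Covariance Scaling (perturb_var): uniformly inflate the spread about the sample mean.
center = samples.mean(axis=0)
var_scaled = center + np.sqrt(1.0 + alpha) * (samples - center)
```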
If you find this code useful, please cite:
@article{chen2025colt,
title={CoLT: The conditional localization test for assessing the accuracy of neural posterior estimates},
author={Chen, Tianyu and Bansal, Vansh and Scott, James G},
journal={arXiv preprint arXiv:2507.17030},
year={2025}
}


