
[NeurIPS 2025 Spotlight] CoLT — Official PyTorch Implementation

CoLT: The conditional localization test for assessing the accuracy of neural posterior estimates
Tianyu Chen*, Vansh Bansal*, and James G. Scott

Abstract: We consider the problem of validating whether a neural posterior estimate $q(\theta \mid x)$ is an accurate approximation to the true, unknown posterior $p(\theta \mid x)$. Existing methods for evaluating the quality of an NPE are largely derived from classifier-based tests or divergence measures, but these suffer from several practical drawbacks. As an alternative, we introduce the Conditional Localization Test (CoLT), a principled method designed to detect discrepancies between $p(\theta \mid x)$ and $q(\theta \mid x)$ across the full range of conditioning inputs. Rather than relying on exhaustive comparisons or density estimation at every $x$, CoLT learns a localization function that adaptively selects points $\theta_l(x)$ where the neural posterior $q$ deviates most strongly from the true posterior $p$ for that $x$. This approach is particularly advantageous in typical simulation-based inference settings, where only a single draw $\theta \sim p(\theta \mid x)$ from the true posterior is observed for each conditioning input, but where the neural posterior $q(\theta \mid x)$ can be sampled an arbitrary number of times. Our theoretical results establish necessary and sufficient conditions for assessing distributional equality across all $x$, offering both rigorous guarantees and practical scalability. Empirically, we demonstrate that CoLT not only performs better than existing methods at comparing $p$ and $q$, but also pinpoints regions of significant divergence, providing actionable insights for model refinement. These properties position CoLT as a state-of-the-art solution for validating neural posterior estimates.

[Figure: four panels on the tree example — (1) true posterior, (2) α = 1.5, (3) α = 4, (4) power curve across α.]

Figure. Results on the toy tree-shaped example. As the perturbation level α increases, the approximate posterior becomes blurrier and deviates from the true manifold shown in Panel 1. CoLT with a learned metric embedding maintains strong statistical power even for modest perturbations (Panel 2), whereas C2ST, SBC, and TARP all perform poorly even at much larger ones such as α = 4 (Panels 3 and 4).

Introduction

This repository contains code for CoLT diagnostics of neural posterior estimates, with comparisons to common baselines (TARP, SBC, C2ST). The main entry point, train.py, samples synthetic SBI problems, trains the CoLT localization components, and evaluates against the baselines.

Experiments

Requirements

We provide a conda environment file:

```bash
conda env create -f environment.yml
conda activate colt
```

Core dependencies include PyTorch, NumPy/SciPy, Matplotlib, Click, geomloss (Sinkhorn), tqdm, and pyro-ppl (Student-t tails in data/approximations.py).
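
As a quick, throwaway check (not part of the repo) that the core dependencies resolved:

```python
# Hypothetical sanity-check snippet; package names taken from the list above.
import torch
import numpy as np
import scipy
import geomloss
import pyro

print("torch:", torch.__version__)
print("numpy:", np.__version__)
print("scipy:", scipy.__version__)
print("pyro:", pyro.__version__)
print("CUDA available:", torch.cuda.is_available())
```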

Running

Minimal single-file demo (toy misspecification). All important components are included in this single file, so if you only want to test CoLT on your own data, it is the only file you need to read:

```bash
python minimal_colt_demo.py --n_dim 5 --n_sim 256 --n_posterior 64 --epochs 200 --dist sinkhorn --out_dir results_minimal_colt
```
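
The expected data layout is defined inside minimal_colt_demo.py itself; as a rough, non-authoritative sketch of the single-draw-per-x setting that the flags above suggest (the shape conventions here are our assumption, not the script's documented interface):

```python
import torch

# Hypothetical shapes implied by --n_sim 256, --n_dim 5, --n_posterior 64.
# Check minimal_colt_demo.py for the actual interface before adapting it.
n_sim, n_dim, n_posterior = 256, 5, 64

x = torch.randn(n_sim, n_dim)                     # conditioning inputs
theta_p = torch.randn(n_sim, n_dim)               # ONE true draw theta ~ p(theta | x) per x
theta_q = torch.randn(n_sim, n_posterior, n_dim)  # many draws from q(theta | x) per x
```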

Main experiment runner (train.py) example. This script replicates the main experiments in the paper, together with the other baseline methods:

```bash
python train.py \
  --output_dir results \
  --device cpu \
  --epochs 200 \
  --c2st_epochs 200 \
  --n_sim 256 \
  --n_posterior_per_x 64 \
  --n_eval 20 \
  --x_dim 10 \
  --theta_dim 10 \
  --approx_sampling_mathod perturb_var \
  --alpha 0.2
```
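
Before reaching for the cluster sweep scripts described below, a minimal driver for sweeping the perturbation level α might look like the following (our sketch, not a repo utility; flags mirror the example command above):

```python
import subprocess

# Hypothetical sweep over the perturbation level alpha; the flag value
# perturb_var is taken from the Perturbation Methods table below.
for alpha in [0.0, 0.1, 0.2, 0.4, 0.8]:
    subprocess.run(
        [
            "python", "train.py",
            "--output_dir", f"results/alpha_{alpha}",
            "--device", "cpu",
            "--epochs", "200",
            "--c2st_epochs", "200",
            "--n_sim", "256",
            "--n_posterior_per_x", "64",
            "--n_eval", "20",
            "--x_dim", "10",
            "--theta_dim", "10",
            "--approx_sampling_mathod", "perturb_var",
            "--alpha", str(alpha),
        ],
        check=True,
    )
```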

Sweep scripts for larger grids live in run_bash_curve/ and run_bash_non_curve/ (they assume a multi-GPU machine and paths like /scratch/...; edit as needed).

Code Organization

  • train.py: main experiment runner (samples data, trains CoLT + baselines, evaluates).
  • training/: CoLT implementation (loss, networks, training loop).
  • baselines/: TARP, SBC, C2ST.
  • data/: synthetic priors/posteriors/approximate posteriors + sampler utilities.
  • minimal_colt_demo.py: minimal CoLT demo script.

Perturbation Methods

Change the perturbation type by replacing the flag value according to the following table:

```bash
python train.py \
  --approx_sampling_mathod perturb_var
```

Table: Six types of perturbations used to assess the sensitivity of CoLT

| $p(\theta \mid x)$ | $q(\theta \mid x)$ | Explanation | approx_sampling_mathod |
| --- | --- | --- | --- |
| $\mathcal{N}(\mu_x, \Sigma_x)$ | $\mathcal{N}((1+\alpha)\mu_x, \Sigma_x)$ | Mean Shift: introduces a systematic bias by shifting the mean. | perturb_mean |
| $\mathcal{N}(\mu_x, \Sigma_x)$ | $\mathcal{N}(\mu_x, (1+\alpha)\Sigma_x)$ | Covariance Scaling: uniformly inflates the variance. | perturb_var |
| $\mathcal{N}(\mu_x, \Sigma_x)$ | $\mathcal{N}(\mu_x, \Sigma_x + \alpha \Delta)$ | Anisotropic Covariance Perturbation: adds variability along the minimum-variance eigenvector of $\Sigma_x$: $\Delta = \mathbf{v}_{\min} \mathbf{v}_{\min}^\top$. | distort_var |
| $\mathcal{N}(\mu_x, \Sigma_x)$ | $t_\nu(\mu_x, \Sigma_x)$ | Tail Adjustment via $t$-Distribution: introduces heavier tails, with degrees of freedom $\nu = 1/(\alpha + \epsilon)$, approaching Gaussian as $\alpha \to 0$. | tail |
| $\mathcal{N}(\mu_x, \Sigma_x)$ | $(1-\alpha)\,\mathcal{N}(\mu_x, \Sigma_x) + \alpha\,\mathcal{N}(-\mu_x, \Sigma_x)$ | Additional Modes: $q$ introduces spurious multimodality. | mixture |
| $(1-\alpha)\,\mathcal{N}(\mu_x, \Sigma_x) + \alpha\,\mathcal{N}(-\mu_x, \Sigma_x)$ | $\mathcal{N}(\mu_x, \Sigma_x)$ | Mode Collapse: $q$ loses its multimodal structure. | collapse |
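
As a concrete illustration of the first two rows of the table, here is a minimal sketch in plain PyTorch (our own example, independent of the repo's data/approximations.py) of how such perturbed samplers can be written:

```python
import torch

def sample_p(mu, cov, n):
    """Draws from the true posterior N(mu_x, Sigma_x)."""
    return torch.distributions.MultivariateNormal(mu, cov).sample((n,))

def sample_q_mean_shift(mu, cov, n, alpha):
    """perturb_mean: systematic bias via a shifted mean, N((1+alpha) mu_x, Sigma_x)."""
    return torch.distributions.MultivariateNormal((1 + alpha) * mu, cov).sample((n,))

def sample_q_var_scaled(mu, cov, n, alpha):
    """perturb_var: uniformly inflated covariance, N(mu_x, (1+alpha) Sigma_x)."""
    return torch.distributions.MultivariateNormal(mu, (1 + alpha) * cov).sample((n,))

mu, cov = torch.zeros(2), torch.eye(2)
theta_p = sample_p(mu, cov, 64)                  # reference draws from p
theta_q = sample_q_mean_shift(mu, cov, 64, 0.2)  # perturbed draws at alpha = 0.2
```

The remaining rows follow the same pattern, swapping in the corresponding $q$ from the table.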

Citation

If you find this code useful, please cite:

```bibtex
@article{chen2025colt,
  title={CoLT: The conditional localization test for assessing the accuracy of neural posterior estimates},
  author={Chen, Tianyu and Bansal, Vansh and Scott, James G},
  journal={arXiv preprint arXiv:2507.17030},
  year={2025}
}
```
