# Variational Learning Finds Flatter Solutions at the Edge of Stability

This repository contains the code and experiments for the paper "Variational Learning Finds Flatter Solutions at the Edge of Stability" by Avrajit Ghosh et al.
📄 Paper: arXiv:2506.12903
## Abstract

Variational Learning (VL) has recently gained popularity for training deep neural networks and is competitive with standard learning methods. Part of its empirical success can be explained by theories such as PAC-Bayes bounds, minimum description length, and marginal likelihood, but few tools exist to unravel the implicit regularization at play. Here, we analyze the implicit regularization of VL through the Edge of Stability (EoS) framework. EoS has previously been used to show that gradient descent can find flat solutions; we extend this result to VL and show that it can find even flatter solutions.
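In the EoS framework, "sharpness" means the largest eigenvalue λ_max of the loss Hessian, and gradient descent with step size η is stable only while λ_max ≤ 2/η. As a minimal illustration (not code from this repository), the sketch below estimates the sharpness of a toy quadratic loss by power iteration; on real networks one would use Hessian-vector products from autograd rather than an explicit Hessian matrix.

```python
import numpy as np

def sharpness(hessian, iters=100, seed=0):
    """Estimate the largest eigenvalue of `hessian` (the "sharpness")
    by power iteration. A toy stand-in for the Hessian-vector-product
    estimators used on real networks."""
    v = np.random.default_rng(seed).normal(size=hessian.shape[0])
    for _ in range(iters):
        hv = hessian @ v              # Hessian-vector product
        v = hv / np.linalg.norm(hv)   # renormalize
    return float(v @ (hessian @ v))   # Rayleigh quotient ~ lambda_max

# Toy quadratic loss L(w) = 0.5 * w^T H w with eigenvalues {1, 3, 10}
H = np.diag([1.0, 3.0, 10.0])
lam_max = sharpness(H)      # converges to the top eigenvalue, 10
eta_crit = 2.0 / lam_max    # GD is unstable above this step size
```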
## Repository Structure

```
edge-of-stability/
├── src/                       # Core source code
│   ├── archs.py               # Network architectures
│   ├── gd.py                  # Gradient descent implementations
│   ├── utilities.py           # Utility functions
│   └── network_stability.py   # Network stability analysis
├── data/                      # Data directory (gitignored)
├── figures/                   # Generated figures (gitignored)
├── training_logs/             # Training logs (gitignored)
├── requirements.txt           # Python dependencies
├── README.md                  # This file
└── .gitignore                 # Git ignore rules
```
## Requirements

- Python 3.8+
- CUDA (for GPU training)
## Installation

- Clone the repository:

  ```bash
  git clone git@github.com:Avra98/eos-ivon.git
  cd eos-ivon
  ```

- Create a virtual environment:

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows: venv\Scripts\activate
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

## Usage

```bash
# Run EoS analysis for different networks
python src/network_stability.py --arch fc-tanh --dataset cifar10
python src/network_stability.py --arch resnet20 --dataset cifar10

# Run VL experiments
python src/gd.py --method ivon --lr 0.1 --mc_samples 10

# Run quadratic dynamics analysis
python quad_dyn_sigma_n.py
```

Key scripts:

- `src/gd.py`: Main training script with VL and standard GD
- `src/network_stability.py`: Network stability analysis
- `quad_dyn_sigma_n.py`: Quadratic dynamics analysis
- `src/utilities.py`: Utility functions for analysis
## Experiments

- Analysis of sharpness dynamics during training
- Comparison between standard GD and VL
- Network-specific critical sharpness values
- Posterior covariance analysis
- Effect of the number of Monte Carlo samples
- Flatness of the solutions found by each method
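The role of the `--mc_samples` parameter can be illustrated with a simplified variational update: weights are sampled from a Gaussian posterior around the current mean, and the descent direction is the gradient averaged over those samples. This is only a sketch of the general idea; the repository uses the IVON optimizer, whose exact update also adapts the posterior variance.

```python
import numpy as np

def vl_step(w, grad_fn, sigma=0.01, lr=0.05, mc_samples=10, rng=None):
    """One simplified variational-learning step: average the gradient over
    `mc_samples` weights drawn from N(w, sigma^2 I), then descend.
    A sketch of the idea only -- not the paper's exact IVON update."""
    rng = rng if rng is not None else np.random.default_rng(0)
    g = np.zeros_like(w)
    for _ in range(mc_samples):
        eps = rng.normal(size=w.shape)
        g += grad_fn(w + sigma * eps)   # gradient at a perturbed weight
    return w - lr * g / mc_samples      # descend on the MC-averaged gradient

# Toy quadratic loss L(w) = 5 * w^2 with gradient 10 * w
grad_fn = lambda w: 10.0 * w
rng = np.random.default_rng(0)
w = np.array([1.0])
for _ in range(50):
    w = vl_step(w, grad_fn, rng=rng)
# w has contracted toward the minimum at 0 (up to sampling noise)
```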
## Theory

- Theoretical foundations
- Dynamics of VL on quadratic objectives
- Connection to the EoS framework
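The quadratic analysis rests on the classical stability threshold: on L(w) = λw²/2, gradient descent iterates w ← (1 − ηλ)w, which converges exactly when η < 2/λ. A tiny self-contained demo of that threshold (independent of `quad_dyn_sigma_n.py`):

```python
def gd_on_quadratic(lam, lr, steps=100, w0=1.0):
    """Run GD on L(w) = 0.5 * lam * w**2. The update is
    w <- (1 - lr * lam) * w, which contracts iff |1 - lr*lam| < 1,
    i.e. iff lr < 2 / lam."""
    w = w0
    for _ in range(steps):
        w -= lr * lam * w
    return abs(w)

lam = 10.0                                # curvature, so 2/lam = 0.2
stable = gd_on_quadratic(lam, lr=0.19)    # just below threshold: shrinks
unstable = gd_on_quadratic(lam, lr=0.21)  # just above threshold: blows up
```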
## Key Findings

- **Flatter Solutions**: VL finds flatter solutions than standard GD
- **Posterior Covariance Control**: controlling the posterior covariance affects solution flatness
- **Monte Carlo Samples**: the number of MC samples influences the flatness of the solutions found
- **Network Generalization**: results hold across different architectures (ResNet, ViT, FC networks)
## Reproducing the Figures

The main paper figures can be reproduced by running:

```bash
# Figure 1: EoS dynamics comparison
python src/network_stability.py --generate_figures

# Figure 2: VL sharpness analysis
python src/gd.py --analysis sharpness --method ivon

# Figure 3: Quadratic dynamics
python quad_dyn_sigma_n.py --plot_dynamics
```

All experiment configurations are documented in the respective script files. Key parameters:

- `--lr`: learning rate
- `--mc_samples`: number of Monte Carlo samples for VL
- `--arch`: network architecture
- `--dataset`: dataset to use
## Citation

If you use this code in your research, please cite:

```bibtex
@article{ghosh2025variational,
  title={Variational Learning Finds Flatter Solutions at the Edge of Stability},
  author={Ghosh, Avrajit and Cong, Bai and Yokota, Rio and Ravishankar, Saiprasad and Wang, Rongrong and Tao, Molei and Khan, Mohammad Emtiyaz and Möllenhoff, Thomas},
  journal={arXiv preprint arXiv:2506.12903},
  year={2025}
}
```

## License

This project is licensed under the MIT License; see the LICENSE file for details.
## Contact

For questions about this repository, please contact:
- Avrajit Ghosh: GitHub
- Paper: arXiv:2506.12903
## Acknowledgments

We gratefully acknowledge support from the Simons Foundation and all contributors to this work.