DiffusionLab



A no-frills JAX library providing core abstractions for diffusion models. DiffusionLab focuses on corruption/forward processes for continuous, discrete, and simplicial (soft-categorical) data, designed to be easy to understand, modify, and extend for research.

The API is PyTree-first: data, times, and predictions can be arbitrarily nested (e.g. multimodal dicts), and everything composes naturally with JAX transformations.
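For example, a multimodal batch can be an ordinary nested dict of arrays; JAX treats it as a PyTree, so shapes, dtypes, and transformations all compose. The sketch below uses only jax itself and is illustrative; the exact containers DiffusionLab's multimodal processes expect may differ.

import jax
import jax.numpy as jnp

# A nested batch mixing a continuous and a discrete modality.
batch = {
    "image": jnp.zeros((256, 28, 28, 1)),         # continuous leaves
    "label": jnp.zeros((256,), dtype=jnp.int32),  # discrete leaves
}

# PyTree-first code composes with jax.tree.map, jit, vmap, and friends.
shapes = jax.tree.map(lambda leaf: leaf.shape, batch)
# {'image': (256, 28, 28, 1), 'label': (256,)}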

Features

  • Corruption processes for continuous (Gaussian), discrete (absorbing/uniform), and simplicial (Dirichlet) data
  • Reverse-process samplers including Euler, DDM, and discrete ancestral sampling
  • Training losses with configurable SNR weighting and time-sampling schedules
  • Reference model architectures: MLP and DiT (Diffusion Transformer)
  • Classifier-free guidance out of the box (see the sketch after this list)
  • Multimodal support: jointly corrupt and denoise mixed continuous/discrete data with a single model
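As a taste of the guidance hook above, here is a hypothetical sketch of wrapping a conditional predictor with the CFG class from diffusionlab.sampling.guidance. The argument names (predictor, guidance_scale) are illustrative guesses, not the confirmed constructor signature; check the class before use.

from diffusionlab.sampling.guidance import CFG

# Hypothetical: wrap any conditional predictor (e.g. the DiffusionMLP from
# the Quick start below) so that sampling blends conditional and
# unconditional predictions. Argument names are assumed for illustration.
guided = CFG(predictor=model, guidance_scale=3.0)

# `guided` would then be passed wherever a plain predictor is expected,
# e.g. as the predictor argument to `sample` in the Quick start.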

Installation

Install from PyPI:

pip install diffusionlab

Or with uv:

uv add diffusionlab

From source

git clone https://github.com/druvpai/DiffusionLab.git
cd DiffusionLab
uv sync

Quick start

Set up and sample from a flow-matching model on 2D data in ~30 lines:

import jax
import jax.numpy as jnp
from diffusionlab.processes.gaussian import FlowMatching
from diffusionlab.models.mlp import DiffusionMLP
from diffusionlab.sampling.continuous import EulerSampler
from diffusionlab.sampling.base import sample
from diffusionlab.sampling.schedules import uniform_schedule
from diffusionlab.training.denoising import DenoisingLoss
from diffusionlab.training.schedules import uniform_time_with_eps

# Define process and model
process = FlowMatching(data_shape=(2,))
model = DiffusionMLP(
    data_dim=2, hidden_dim=128, num_layers=4,
    time_emb_dim=64, cond_dim=64, prediction_kind="v_t",
    key=jax.random.key(0),
)

# Compute loss on a batch (x_0 stands in for a batch of real 2D data)
x_0 = jax.random.normal(jax.random.key(3), (256, 2))
loss_obj = DenoisingLoss(process=process)
key = jax.random.key(1)
key_t, key_fwd = jax.random.split(key)
t = uniform_time_with_eps(key_t, batch=256)
x_t, aux = process.forward(key_fwd, x_0, t)
out = loss_obj(predictor=model, x_0=x_0, x_t=x_t, t=t, aux=aux, cond=None)

# Sample from the model (after training; the loop itself is omitted here)
sampler = EulerSampler(process=process)
schedule = uniform_schedule(num_steps=100, batch_size=512)
samples = sample(key=jax.random.key(2), sampler=sampler, predictor=model,
                 batch_size=512, time_schedule=schedule)
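
The snippet above computes one loss value; the training loop itself is omitted. A minimal gradient-step sketch is below. It assumes the model is an Equinox-style module PyTree (suggested by the key= constructor argument) and that the loss object returns a scalar, with optax and equinox as assumed dependencies; none of this is confirmed API.

import equinox as eqx
import optax

optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_jit
def train_step(model, opt_state, key, x_0):
    # Draw times, corrupt the batch, and differentiate the denoising loss.
    key_t, key_fwd = jax.random.split(key)
    t = uniform_time_with_eps(key_t, batch=x_0.shape[0])
    x_t, aux = process.forward(key_fwd, x_0, t)

    def loss_fn(m):
        return loss_obj(predictor=m, x_0=x_0, x_t=x_t, t=t, aux=aux, cond=None)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss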

Examples

The examples/ directory contains complete, runnable scripts:

Example                Description
two_moons.py           Flow matching on 2D synthetic data with trajectory visualization
mnist_dit.py           Class-conditional DiT on MNIST with classifier-free guidance
mnist_multimodal.py    Joint image + label generation with mixed continuous/discrete diffusion

Run any example with:

uv run python examples/two_moons.py
uv run python examples/mnist_dit.py
uv run python examples/mnist_multimodal.py

The MNIST examples download data automatically on first run.

Project structure

diffusionlab/
├── processes/              # Forward corruption processes
│   ├── base.py             #   CorruptionProcess, MultimodalCorruptionProcess
│   ├── interpolation.py    #   InterpolationProcess (abstract)
│   ├── gaussian.py         #   VPCosine, VPLinear, VE, EDM, FlowMatching
│   ├── discrete.py         #   MaskLinear, MaskCosine, UniformLinear, UniformCosine
│   └── simplicial.py       #   SimplicialLinear, SimplicialCosine
├── sampling/               # Reverse-process samplers
│   ├── base.py             #   Sampler, MultimodalSampler, sample, sample_trajectory
│   ├── continuous.py       #   EulerSampler, DDMSampler
│   ├── discrete.py         #   DiscreteAncestralSampler
│   ├── guidance.py         #   CFG (classifier-free guidance)
│   └── schedules.py        #   uniform_schedule, edm_schedule
├── training/               # Training losses and utilities
│   ├── base.py             #   Loss, MultimodalLoss, mse_loss, ce_loss, kl_loss
│   ├── denoising.py        #   DenoisingLoss
│   ├── distillation.py     #   DistillationLoss
│   ├── consistency.py      #   ConsistencyLoss
│   ├── weighting.py        #   SNR-based weight functions
│   └── schedules.py        #   Time sampling strategies
├── models/                 # Reference neural network architectures
│   ├── base.py             #   DiffusionModel (abstract)
│   ├── components.py       #   SinusoidalEmbedding, FiLM, AdaLNZero
│   ├── mlp.py              #   DiffusionMLP, ConditionedDiffusionMLP
│   └── dit.py              #   DiT, ConditionedDiT
├── typing.py               # Prediction dataclass, type aliases
└── utils/
    └── tree_ops.py         # PyTree utilities (bcast_right, tree_map_with_key)

Development

Prerequisites

  • Python >= 3.13
  • uv (recommended) or pip

Setup

git clone https://github.com/druvpai/DiffusionLab.git
cd DiffusionLab
uv sync

Running tests

# Run the full test suite
uv run pytest tests/

# Run a specific test file
uv run pytest tests/processes/test_gaussian.py

# Run tests matching a pattern
uv run pytest tests/ -k "test_forward"

Linting and formatting

uv run ruff format diffusionlab/ tests/
uv run ruff check diffusionlab/ tests/

Type checking

uv run ty check diffusionlab/

Contributing

Contributions are welcome! Here's how to get started:

  1. Fork the repository and clone your fork
  2. Create a branch for your feature or fix:
    git checkout -b my-feature
  3. Install development dependencies:
    uv sync
  4. Make your changes — keep them focused and minimal
  5. Add tests for any new functionality in tests/
  6. Run the checks before submitting:
    uv run ruff format diffusionlab/ tests/
    uv run ruff check diffusionlab/ tests/
    uv run pytest tests/
  7. Open a pull request with a clear description of your changes

Citation

If you use DiffusionLab in your research, please consider citing it:

@Misc{pai25diffusionlab,
    author = {Pai, Druv},
    title = {DiffusionLab},
    howpublished = {\url{https://github.com/DruvPai/DiffusionLab}},
    year = {2025}
}

Many thanks!

License

MIT License. See LICENSE for details.

Acknowledgements

This package was greatly inspired by Hackable Diffusion, but the implementations and some core primitives differ substantially, which (in my opinion) makes DiffusionLab more flexible and extensible.
