
Poor inference for CRP DP GMM via SMC, PG, IS #10

@luiarthur

I simulated some data and fit the following model to it, but the inference is poor: the results don't match the simulation truth. In particular, when my data has 4 almost equally sized clusters, the posterior tends to contain 4 to 5 clusters, but usually only 1 or 2 of them are big (dominating).

My first question: am I abusing the API? In my DP Gaussian mixture over location and scale, I use a base measure (H) that is Normal × InverseGamma (two independent distributions) for the location (mu) and scale (sigma). Am I doing this correctly? (It runs, but I suspect I'm using it in a way that isn't intended.)
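For reference, here is what drawing one (mu, sigma) pair from that product base measure amounts to. This is a minimal standard-library sketch, not Turing's `arraydist`. The helper name `draw_base_measure` is hypothetical, and the sketch assumes Distributions.jl's shape-scale parameterisation of `InverseGamma(2, 0.05)`.

```julia
using Random, Statistics

# Draw (mu, sigma) from the product base measure H = Normal(0, 3) x InverseGamma(2, 0.05)
# using only the standard library (no Distributions.jl).
# A Gamma(2, 1) variate is the sum of two independent Exp(1) variates, and
# if G ~ Gamma(a, 1) then b / G ~ InverseGamma(a, b) (shape a, scale b).
function draw_base_measure(rng::AbstractRNG)
    mu = 3.0 * randn(rng)                          # location ~ Normal(0, 3)
    sigma = 0.05 / (randexp(rng) + randexp(rng))   # scale ~ InverseGamma(2, 0.05)
    return (mu, sigma)
end

rng = MersenneTwister(0)
draws = [draw_base_measure(rng) for _ in 1:100_000]
mus = first.(draws)      # independent location draws
sigmas = last.(draws)    # independent scale draws
```

Because the shape parameter is 2, InverseGamma(2, 0.05) has infinite variance. The sample median of `sigmas` (about 0.05 / 1.678 ≈ 0.030) is therefore a steadier summary than the sample mean.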

My second question: if the model is implemented correctly, what might be causing the poor inference? Admittedly, I'm not familiar with SMC/PG. Does increasing the number of particles generally lead to better inference?
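On the particle question, a common generic diagnostic for SMC and IS is the effective sample size (ESS) of the final particle weights. When a handful of particles carry almost all the weight, the ESS is close to 1. The sketch below assumes you can extract unnormalised log-weights from the sampler. `effective_sample_size` is a hypothetical helper, not a Turing function.

```julia
# Effective sample size of a set of weighted particles, computed from
# unnormalised log-weights: ESS = (sum w)^2 / sum(w^2).
function effective_sample_size(logw::AbstractVector{<:Real})
    m = maximum(logw)           # subtract the max before exponentiating, for stability
    w = exp.(logw .- m)         # rescaled weights; the ratio below is unaffected
    return sum(w)^2 / sum(abs2, w)
end
```

With equal weights, `effective_sample_size(zeros(10))` returns 10. When one particle dominates, `effective_sample_size([0.0, -1000.0, -1000.0])` returns 1.0.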

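using Turing
using Turing.RandomMeasures  # DirichletProcess, ChineseRestaurantProcess
using Random                 # Random.seed!

# DP GMM model under CRP construction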
@model dp_gmm_crp(y) = begin
    nobs = length(y)
    
    alpha ~ Gamma(1, 0.1)  # shape 1, scale 0.1, so prior mean = 1 * 0.1 = 0.1
    rpm = DirichletProcess(alpha)
    
    # Base measure.
    H = arraydist([Normal(0, 3), InverseGamma(2, 0.05)])  # is this OK?
    
    # Latent assignment.
    z = tzeros(Int, nobs)
    
    # Locations and scales of infinitely many clusters.
    mu_sigma = TArray(Vector{Float64}, 0)  # is this OK?
    
    for i in 1:nobs
        # Number of clusters.
        K = maximum(z)
        n = Vector{Int}([sum(z .== k) for k in 1:K])
        
        # Sample cluster label.
        z[i] ~ ChineseRestaurantProcess(rpm,  n)
        
        # Create a new cluster.
        if z[i] > K
            push!(mu_sigma, [0.0, 0.1])  # is this OK?
            mu_sigma[z[i]] ~ H  # is this OK?
        end
        
        # Sampling distribution.
        mu, sigma = mu_sigma[z[i]]  # is this OK?
        y[i] ~ Normal(mu, sigma)
    end
end;
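To sanity-check the prior implied by the loop above, you can simulate the CRP assignment step on its own. This is a minimal standard-library sketch of the usual CRP predictive rule: join cluster k with probability proportional to n_k, or open a new cluster with probability proportional to alpha. `sample_crp` and `expected_num_clusters` are hypothetical helpers, not part of Turing's `ChineseRestaurantProcess` API.

```julia
using Random

# Simulate cluster labels for nobs observations under a CRP with concentration alpha.
# Observation i joins existing cluster k with probability n_k / (i - 1 + alpha)
# and opens a new cluster with probability alpha / (i - 1 + alpha).
function sample_crp(rng::AbstractRNG, alpha::Real, nobs::Int)
    z = zeros(Int, nobs)
    counts = Int[]                        # counts[k] = current size of cluster k
    for i in 1:nobs
        weights = vcat(counts, alpha)     # existing clusters first, new cluster last
        cw = cumsum(weights)
        u = rand(rng) * cw[end]           # uniform draw scaled to the total weight
        k = findfirst(>=(u), cw)          # index of the chosen cluster
        if k > length(counts)
            push!(counts, 1)              # open a new cluster
        else
            counts[k] += 1                # join an existing cluster
        end
        z[i] = k
    end
    return z, counts
end

# Prior expected number of clusters after nobs draws: sum over i of alpha / (alpha + i - 1).
expected_num_clusters(alpha::Real, nobs::Int) = sum(alpha / (alpha + i - 1) for i in 1:nobs)

rng = MersenneTwister(0)
z, counts = sample_crp(rng, 1.0, 200)
```

Averaging `length(counts)` over many draws should approach `expected_num_clusters(alpha, nobs)`. This shows how many clusters a given `alpha` favours a priori.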
# Set random seed for reproducibility
Random.seed!(0);

# Sample from posterior
@time chain = begin
    burn = 2000  # NOTE: burn-in samples are also returned; discard them manually.
    n_samples = 1000
    iterations = burn + n_samples

    sample(dp_gmm_crp(y), SMC(), iterations)
    # sample(dp_gmm_crp(y), IS(), iterations)
    # sample(dp_gmm_crp(y), Gibbs(PG(5, :z), PG(5, :mu_sigma)), iterations)
end;

Here is the complete notebook.