2 changes: 2 additions & 0 deletions docs/make.jl
@@ -46,6 +46,8 @@ makedocs(;
"vnt/implementation.md",
"vnt/arraylikeblocks.md",
],
"Model evaluation" => "flow.md",
"Storing values" => "values.md",
],
checkdocs=:exports,
doctest=false,
23 changes: 22 additions & 1 deletion docs/src/api.md
@@ -343,7 +343,6 @@ AbstractVarInfo

```@docs
VarInfo
DynamicPPL.TransformedValue
DynamicPPL.setindex_with_dist!!
```

@@ -437,6 +436,10 @@ DynamicPPL.StaticTransformation

```@docs
DynamicPPL.transformation
DynamicPPL.LinkAll
DynamicPPL.UnlinkAll
DynamicPPL.LinkSome
DynamicPPL.UnlinkSome
DynamicPPL.link
DynamicPPL.invlink
DynamicPPL.link!!
@@ -537,6 +540,24 @@ init
get_param_eltype
```

The function [`DynamicPPL.init`](@ref) should return an `AbstractTransformedValue`.
There are three subtypes currently available:

```@docs
DynamicPPL.AbstractTransformedValue
DynamicPPL.VectorValue
DynamicPPL.LinkedVectorValue
DynamicPPL.UntransformedValue
```

The interface for working with transformed values consists of:

```@docs
DynamicPPL.get_transform
DynamicPPL.get_internal_value
DynamicPPL.set_internal_value
```

### Converting VarInfos to/from chains

It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis.
121 changes: 121 additions & 0 deletions docs/src/flow.md
@@ -0,0 +1,121 @@
# How data flows through a model

Having discussed initialisation strategies and accumulators, we can now put all the pieces together to show how data enters a model, how it is used to perform computations, and how the results are extracted.

**The summary is: initialisation strategies are responsible for telling the model what values to use for its parameters, whereas accumulators act as containers for aggregated outputs.**

Thus, there is a clear separation between the *inputs* to the model, and the *outputs* of the model.

!!! note

    `VarInfo` and `DefaultContext` still exist, but mostly for historical reasons. `DefaultContext` means that the inputs come from the values of the provided `VarInfo`, and the outputs are stored in the accumulators of the provided `VarInfo`. However, this can easily be refactored such that the values are provided directly as an initialisation strategy. See [this issue](https://github.com/TuringLang/DynamicPPL.jl/issues/1184) for more details.

There are three stages to every tilde-statement:

1. Initialisation: get an `AbstractTransformedValue` from the initialisation strategy.
2. Computation: figure out the untransformed (raw) value; compute the log-Jacobian if necessary.
3. Accumulation: pass all the relevant information to the accumulators, which individually decide what to do with it.

In fact this (more or less) directly translates into three lines of code: see e.g. the method for `tilde_assume!!` in `src/onlyaccs.jl`, which (as of the time of writing) looks like:

```julia
function DynamicPPL.tilde_assume!!(ctx::InitContext, dist, vn, template, vi)
    # 1. Initialisation
    tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy)

    # 2. Computation
    # (Though see also the warning in the computation section below.)
    x, inv_logjac = Bijectors.with_logabsdet_jacobian(
        DynamicPPL.get_transform(tval), DynamicPPL.get_internal_value(tval)
    )

    # 3. Accumulation
    vi = DynamicPPL.accumulate_assume!!(vi, x, tval, -inv_logjac, vn, dist, template)
    return x, vi
end
```

For `tilde_observe!!`, the code is very similar, but even simpler: the value can be read directly from the data provided to the model, so there is no need for an initialisation step.
Since the value is already untransformed, we can also skip the second step.
Finally, accumulators must behave differently for observations, e.g. by incrementing the likelihood instead of the prior.
That is accomplished by calling `accumulate_observe!!` instead of `accumulate_assume!!`.

In the following sections, we walk through the three stages of `tilde_assume!!` in turn.

## Initialisation

```julia
tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy)
```

The initialisation step is handled by the `init` function, which dispatches on the initialisation strategy.
For example, if `ctx.strategy` is `InitFromPrior()`, then `init()` samples a value from the distribution `dist`.

!!! note

    For `DefaultContext`, this step is replaced by looking up the value stored inside `vi`. As described above, this can be refactored in the near future.

This step, in general, does not return just the raw value (like `rand(dist)`).
It returns a [`DynamicPPL.AbstractTransformedValue`](@ref), which represents a value that _may_ have been transformed.
In the case of `InitFromPrior()`, the value is of course not transformed; we return a [`DynamicPPL.UntransformedValue`](@ref) wrapping the sampled value.

However, consider the case where we are using parameters stored inside a `VarInfo`: the value may have been stored either in vectorised form or in linked vectorised form.
In this case, `init()` will return either a [`DynamicPPL.VectorValue`](@ref) or a [`DynamicPPL.LinkedVectorValue`](@ref).

We return this wrapped value because sometimes we do not want to perform the transformation eagerly.
Consider the case where we have an accumulator that stores linked values (this is precisely what happens when linking a VarInfo: the linked values are stored in an accumulator, which then becomes the basis of the linked VarInfo).
In this case, if we eagerly performed the inverse link transformation, we would have to link the value again inside the accumulator, which is inefficient!

The `AbstractTransformedValue` is passed straight through and is used by both the computation and accumulation steps.
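
The benefit of deferring the transformation can be illustrated with a toy sketch. The types and functions below (`ToyTransformedValue`, `ToyUntransformed`, `ToyLinked`, `raw_value`, `linked_value`) are hypothetical stand-ins written for this example only; they are not DynamicPPL's actual types or API.

```julia
# Toy stand-ins for illustration only; these are NOT DynamicPPL's types.
abstract type ToyTransformedValue end

# A raw value that needs no transformation.
struct ToyUntransformed{T} <: ToyTransformedValue
    val::T
end

# A value stored in unconstrained ("linked") space, together with the
# function that maps it back to the raw value.
struct ToyLinked{T,F} <: ToyTransformedValue
    val::T
    inv_link::F
end

# The raw value is only computed when somebody asks for it.
raw_value(tv::ToyUntransformed) = tv.val
raw_value(tv::ToyLinked) = tv.inv_link(tv.val)

# A consumer that wants the linked value can take it directly, with no
# inverse-then-forward round trip.
linked_value(tv::ToyLinked) = tv.val

tv = ToyLinked(0.0, exp)  # a positive variable stored as its logarithm
raw_value(tv)             # 1.0
linked_value(tv)          # 0.0
```

Because the wrapper carries both the stored value and the means of undoing the transformation, each consumer can pick whichever representation it needs.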

## Computation

```julia
x, inv_logjac = Bijectors.with_logabsdet_jacobian(
    DynamicPPL.get_transform(tval), DynamicPPL.get_internal_value(tval)
)
```

At *some* point, we do need to perform the transformation to get the actual raw value.
This is because DynamicPPL promises that, inside the model, the variables on the left-hand side of a tilde are actual raw values.

```julia
@model function f()
    x ~ dist
    # Here, `x` _must_ be the actual raw value.
    @show x
end
```

Thus, regardless of what we are accumulating, we will have to unwrap the transformed value provided by `init()`.
We also need to account for the log-Jacobian of the transformation, if any.
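
To spell out the sign convention in the snippet above, here is a short derivation. The symbol `f` for the link transform is our own labelling, and we assume that `get_transform(tval)` plays the role of its inverse (it maps the internal value `y` to the raw value `x`). Under that assumption, `inv_logjac` is the left-hand side of the first line below, and the `-inv_logjac` passed on to the accumulators is the log-Jacobian of the forward transform evaluated at `x`. The second line checks the identity for the scalar case `f = log`.

```math
\begin{aligned}
\log\left|\det J_{f^{-1}}(y)\right| &= -\log\left|\det J_{f}(x)\right|, \qquad y = f(x), \\
f(x) = \log x: \quad \log\left|\frac{\mathrm{d}}{\mathrm{d}y} e^{y}\right| &= y = \log x = -\log\left|\frac{\mathrm{d}}{\mathrm{d}x} \log x\right|.
\end{aligned}
```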

!!! note

    In principle, if the log-Jacobian is not of interest to any of the accumulators, we _could_ skip computing it here.
    However, that is not easy to determine in practice.
    We also cannot defer the log-Jacobian computation to the accumulator, since if multiple accumulators need the log-Jacobian, we would end up computing it multiple times.
    The current situation of computing it once here is the most sensible compromise (for now).

    One could envision a future where accumulators declare upfront (via their type) whether they need the log-Jacobian or not. We could then skip computing it if no accumulator needs it.

!!! warning

    If you look at the source code for that method, it is more complicated than the above!
    Have we lied?
    It turns out that there is a subtlety here: the transformation obtained from `DynamicPPL.get_transform(tval)` may in fact be incorrect.

    Consider the case where a transform depends on the value itself (e.g., a variable whose support depends on another variable).
    In this case, setting new values into a VarInfo (via `unflatten!!`) may cause the cached transformations to become invalid.
    Where possible, it is better to re-obtain the transformation from `dist`, which is always up-to-date since it is obtained from model execution.

## Accumulation

```julia
vi = DynamicPPL.accumulate_assume!!(vi, x, tval, -inv_logjac, vn, dist, template)
```

This step is where most of the interesting action happens.

[...]
121 changes: 121 additions & 0 deletions docs/src/values.md
@@ -0,0 +1,121 @@
# Storing values

## The role of VarInfo

As described in the [model evaluation documentation page](./flow.md), each tilde-statement is split up into three parts:

1. Initialisation;
2. Computation; and
3. Accumulation.

Unfortunately, not everything in DynamicPPL follows this clean structure yet.
In particular, there is a struct, called `VarInfo`, which has a dual role in both initialisation and accumulation:

```julia
struct VarInfo{linked,V<:VarNamedTuple,A<:AccumulatorTuple}
    values::V
    accs::A
end
```

The `values` field stores either [`LinkedVectorValue`](@ref)s or [`VectorValue`](@ref)s.
The `linked` type parameter can be `true` or `false`, indicating that _all_ stored values are linked or unlinked, respectively. It can also be `nothing`, which indicates that it is not known whether the values are linked, so each value must be checked individually.

Here is an example:

```@example 1
using DynamicPPL, Distributions

@model function dirichlet()
    x = zeros(3)
    return x[1:3] ~ Dirichlet(ones(3))
end
dirichlet_model = dirichlet()
vi = VarInfo(dirichlet_model)
vi
```

In `VarInfo`, `LinkedVectorValue`s and `VectorValue`s must be stored as `ArrayLikeBlock`s (see the [Array-like blocks](@ref) documentation for information on this).
This is because, if the value is linked, its length may differ from the number of indices in the `VarName`.
This means that when retrieving the keys, we obtain each block as a single key:

```@example 1
keys(vi.values)
```
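
The size mismatch between a linked vector and the indices it covers can be seen with a self-contained sketch. It uses the additive log-ratio transform purely for illustration; this is not necessarily the transform that Bijectors.jl applies to `Dirichlet`, and the helper names `to_unconstrained` and `to_simplex` are made up for this example.

```julia
# Illustrative only: the additive log-ratio transform, which is NOT
# necessarily the transform that Bijectors.jl uses for `Dirichlet`.
# It maps a point on the simplex with K entries to an unconstrained
# vector of length K - 1.
to_unconstrained(x) = log.(x[1:(end - 1)] ./ x[end])

# Inverse: append a zero, exponentiate, then normalise (a softmax).
function to_simplex(y)
    z = exp.(vcat(y, 0.0))
    return z ./ sum(z)
end

x = [0.2, 0.5, 0.3]      # three indices, x[1:3]
y = to_unconstrained(x)  # only two numbers need to be stored
length(x), length(y)     # (3, 2)
to_simplex(y) ≈ x        # true
```

With two stored numbers covering three indices, there is no sensible way to say which stored entry "is" `x[1]`, which is why the whole range is kept as one block under a single key.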

## Towards a new framework

In a `VarInfo`, the `accs` field is responsible for the accumulation step, just like an ordinary `AccumulatorTuple`.

However, `values` serves a dual purpose: it is sometimes used for initialisation (when the model's leaf context is `DefaultContext`, the `AbstractTransformedValue` to be used in the computation step is read from it) and it is sometimes also used for accumulation (when linking a VarInfo, we will potentially store a new `AbstractTransformedValue` in it).

The path to removing `VarInfo` is essentially to separate these two roles:

1. The initialisation role of `varinfo.values` can be taken over by an initialisation strategy that wraps it.
   Recall that the only role of an initialisation strategy is to provide an `AbstractTransformedValue` via [`DynamicPPL.init`](@ref).
   This can be trivially done by indexing into the `VarNamedTuple` stored in the strategy.

2. The accumulation role of `varinfo.values` can be taken over by a new accumulator, which we call `TransformedValueAccumulator`.
   The nomenclature here is not especially precise: it does not store `AbstractTransformedValue`s per se, but only two subtypes of it, `LinkedVectorValue` and `VectorValue`.

`TransformedValueAccumulator` is implemented inside `src/accs/transformed_value.jl`, and additionally includes a link strategy as a parameter: the link strategy is responsible for deciding which values should be stored as `LinkedVectorValue`s and which as `VectorValue`s.

!!! note

    Decoupling the initialisation from the accumulation also means that we can pair different initialisation strategies with a `TransformedValueAccumulator`.
    Previously, to obtain a linked VarInfo, you had to first generate an unlinked VarInfo and then link it.
    Now, you can directly create a linked VarInfo (i.e., accumulate `LinkedVectorValue`s) by sampling from the prior (i.e., initialise with `InitFromPrior`).

## ValuesAsInModelAccumulator

Earlier we said that `TransformedValueAccumulator` stores only two subtypes of `AbstractTransformedValue`: `LinkedVectorValue` and `VectorValue`.

It is often also useful to store the raw values, i.e., [`UntransformedValue`](@ref)s. Moreover, since `UntransformedValue`s must always correspond exactly to the indices they are assigned to, we can unwrap them and do not need to store them as array-like blocks.

This is the role of `ValuesAsInModelAccumulator`.

!!! note

    The name is a historical artefact, and can definitely be improved. Suggestions are welcome!

```@example 1
oavi = DynamicPPL.OnlyAccsVarInfo()
oavi = DynamicPPL.setaccs!!(oavi, (DynamicPPL.ValuesAsInModelAccumulator(false),))
_, oavi = DynamicPPL.init!!(dirichlet_model, oavi)
raw_vals = DynamicPPL.getacc(oavi, Val(:ValuesAsInModel)).values
```

Note that when we unwrap `UntransformedValue`s, we also lose the block structure that was present in the model.
That means that in `ValuesAsInModelAccumulator`, there is no longer any notion that `x[1:3]` was set together, so the keys correspond to the individual indices.

```@example 1
keys(raw_vals)
```

In particular, the outputs of `ValuesAsInModelAccumulator` are used for chain construction.
This is why indices are split up in chains.

!!! note

    If you have an entire vector stored as a top-level symbol, e.g. `x ~ Dirichlet(ones(3))`, it will not be broken up (as long as you use FlexiChains).

## Why do we still need to store `TransformedValue`s?

Given that `ValuesAsInModelAccumulator` exists, one may wonder why we still need to store `TransformedValue`s at all, i.e. what the purpose of `TransformedValueAccumulator` is.

Currently, the only remaining reason to store transformed values is that we sometimes need to call [`DynamicPPL.unflatten!!`](@ref) on a `VarInfo` to insert new values into it from a vector.

```@example 1
vi = VarInfo(dirichlet_model)
vi[@varname(x[1:3])]
```

```@example 1
vi = DynamicPPL.unflatten!!(vi, [0.2, 0.5, 0.3])
vi[@varname(x[1:3])]
```

If we do not store the vectorised form of the values, we will not know how many values to read from the input vector for each key.
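
A minimal sketch of this bookkeeping, assuming a simplified storage layout: `toy_unflatten` and the vector-of-pairs container are hypothetical and do not reflect how DynamicPPL implements `unflatten!!`. The point is only that the length of each previously stored vector determines how many entries to consume from the flat input.

```julia
# Hypothetical sketch, not DynamicPPL's implementation: each key remembers
# the vector it currently stores, and that vector's length tells us how
# many entries of the flat input belong to the key.
function toy_unflatten(
    stored::Vector{Pair{Symbol,Vector{Float64}}}, flat::Vector{Float64}
)
    result = Pair{Symbol,Vector{Float64}}[]
    offset = 0
    for (key, old) in stored
        n = length(old)
        push!(result, key => flat[(offset + 1):(offset + n)])
        offset += n
    end
    # Every entry of the flat vector must have been consumed.
    offset == length(flat) || throw(ArgumentError("wrong number of values"))
    return result
end

stored = [:a => [0.0, 0.0], :b => [0.0]]
toy_unflatten(stored, [1.0, 2.0, 3.0])  # [:a => [1.0, 2.0], :b => [3.0]]
```

If only unwrapped raw values were kept, the stored lengths could differ from the lengths expected in the flat vector (for example when the flat vector holds linked values), and this slicing would no longer be well defined.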

Removing upstream usage of `unflatten!!` would allow us to completely get rid of `TransformedValueAccumulator` and only ever use `ValuesAsInModelAccumulator`.
See [this DynamicPPL issue](https://github.com/TuringLang/DynamicPPL.jl/issues/836) for more information.
29 changes: 5 additions & 24 deletions docs/src/vnt/arraylikeblocks.md
@@ -121,6 +121,7 @@ Some examples follow.

In `VarInfo`, we need to be able to store either linked or unlinked values (in general, `AbstractTransformedValue`s).
These are always vectorised values, and the linked and unlinked vectors may have different sizes (this is indeed the case for Dirichlet distributions).
This means that we have to collectively assign multiple indices in the `VarNamedTuple` to a single vector, which may or may not have the same size as the indices.

```@example 1
@model function dirichlet()
@@ -132,33 +133,13 @@ vi = VarInfo(dirichlet_model)
vi.values
```

Thus, in the actual `VarInfo` we do not have a notion of what `x[1]` is.
This means that in the actual `VarInfo` we do not have a notion of what `x[1]` is:

**Note**: this is in contrast to `ValuesAsInModelAccumulator`, where we do store raw values:

```@example 1
oavi = DynamicPPL.OnlyAccsVarInfo()
oavi = DynamicPPL.setaccs!!(oavi, (DynamicPPL.ValuesAsInModelAccumulator(false),))
_, oavi = DynamicPPL.init!!(dirichlet_model, oavi)
raw_vals = DynamicPPL.getacc(oavi, Val(:ValuesAsInModel)).values
```

This distinction is important to understand when working with downstream code that uses `VarInfo` and its outputs.
In particular, when constructing a chain, we use the raw values from `ValuesAsInModelAccumulator`, not the linked/unlinked values from `VarInfo`.

There is also a difference between the keys.
Because the `VarInfo` stores array-like blocks, the keys correspond to the entire blocks:

```@example 1
keys(vi.values)
```@repl 1
vi[@varname(x[1])]
```

On the other hand, in `ValuesAsInModelAccumulator`, there is no longer any notion that `x[1:3]` was set together, so the keys correspond to the individual indices.
This is why indices are split up in chains:

```@example 1
keys(raw_vals)
```
See the [documentation on storing values](@ref "Storing values") for more details.

### Prior distributions

2 changes: 1 addition & 1 deletion ext/DynamicPPLMCMCChainsExt.jl
@@ -30,7 +30,7 @@ end
"""
AbstractMCMC.from_samples(
::Type{MCMCChains.Chains},
params_and_stats::AbstractMatrix{<:ParamsWithStats}
params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats}
)

Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object.