diff --git a/docs/make.jl b/docs/make.jl index 7ee874140..5689acb3d 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -46,6 +46,8 @@ makedocs(; "vnt/implementation.md", "vnt/arraylikeblocks.md", ], + "Model evaluation" => "flow.md", + "Storing values" => "values.md", ], checkdocs=:exports, doctest=false, diff --git a/docs/src/api.md b/docs/src/api.md index 661417777..518189185 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -343,7 +343,6 @@ AbstractVarInfo ```@docs VarInfo -DynamicPPL.TransformedValue DynamicPPL.setindex_with_dist!! ``` @@ -437,6 +436,10 @@ DynamicPPL.StaticTransformation ```@docs DynamicPPL.transformation +DynamicPPL.LinkAll +DynamicPPL.UnlinkAll +DynamicPPL.LinkSome +DynamicPPL.UnlinkSome DynamicPPL.link DynamicPPL.invlink DynamicPPL.link!! @@ -537,6 +540,24 @@ init get_param_eltype ``` +The function [`DynamicPPL.init`](@ref) should return an `AbstractTransformedValue`. +There are three subtypes currently available: + +```@docs +DynamicPPL.AbstractTransformedValue +DynamicPPL.VectorValue +DynamicPPL.LinkedVectorValue +DynamicPPL.UntransformedValue +``` + +The interface for working with transformed values consists of: + +```@docs +DynamicPPL.get_transform +DynamicPPL.get_internal_value +DynamicPPL.set_internal_value +``` + ### Converting VarInfos to/from chains It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis. diff --git a/docs/src/flow.md b/docs/src/flow.md new file mode 100644 index 000000000..d72d49a3b --- /dev/null +++ b/docs/src/flow.md @@ -0,0 +1,121 @@ +# How data flows through a model + +Having discussed initialisation strategies and accumulators, we can now put all the pieces together to show how data enters a model, is used to perform computations, and how the results are extracted. + +**The summary is: initialisation strategies are responsible for telling the model what values to use for its parameters, whereas accumulators act as containers for aggregated outputs.** + +Thus, there is a clear separation between the *inputs* to the model, and the *outputs* of the model. + +!!! note + + While `VarInfo` and `DefaultContext` still exist, this is mostly a historical remnant. `DefaultContext` means that the inputs should come from the values of the provided `VarInfo`, and the outputs are stored in the accumulators of the provided `VarInfo`. However, this can easily be refactored such that the values are provided directly as an initialisation strategy. See [this issue](https://github.com/TuringLang/DynamicPPL.jl/issues/1184) for more details. + +There are three stages to every tilde-statement: + + 1. Initialisation: get an `AbstractTransformedValue` from the initialisation strategy. + + 2. Computation: figure out the untransformed (raw) value; compute the log-Jacobian if necessary. + 3. Accumulation: pass all the relevant information to the accumulators, which individually decide what to do with it. + +In fact this (more or less) directly translates into three lines of code: see e.g. the method for `tilde_assume!!` in `src/onlyaccs.jl`, which (as of the time of writing) looks like: + +```julia +function DynamicPPL.tilde_assume!!(ctx::InitContext, dist, vn, template, vi) + # 1. Initialisation + tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy) + + # 2. Computation + # (Though see also the warning in the computation section below.) + x, inv_logjac = Bijectors.with_logabsdet_jacobian( + DynamicPPL.get_transform(tval), DynamicPPL.get_internal_value(tval) + ) + + # 3. Accumulation + vi = DynamicPPL.accumulate_assume!!(vi, x, tval, -inv_logjac, vn, dist, template) + return x, vi +end +``` + +For `tilde_observe!!`, the code is very similar, but even easier: the value can be read directly from the data provided to the model, so there is no need for an initialisation step. +Since the value is already untransformed, we can skip the second step. +Finally, accumulators must behave differently: e.g. incrementing the likelihood instead of the prior. +That is accomplished by calling `accumulate_observe!!` instead of `accumulate_assume!!`. + +In the following sections, we stick to the three sections of `tilde_assume!!`. + +## Initialisation + +```julia +tval = DynamicPPL.init(ctx.rng, vn, dist, ctx.strategy) +``` + +The initialisation step is handled by the `init` function, which dispatches on the initialisation strategy. +For example, if `ctx.strategy` is `InitFromPrior()`, then `init()` samples a value from the distribution `dist`. + +!!! note + + For `DefaultContext`, this is replaced by looking for the value stored inside `vi`. As described above, this can be refactored in the near future. + +This step, in general, does not return just the raw value (like `rand(dist)`). +It returns an [`DynamicPPL.AbstractTransformedValue`](@ref), which represents a value that _may_ have been transformed. +In the case of `InitFromPrior()`, the value is of course not transformed; we return a [`DynamicPPL.UntransformedValue`](@ref) wrapping the sampled value. + +However, consider the case where we are using parameters stored inside a `VarInfo`: the value may have been stored either as a vectorised form, or as a linked vectorised form. +In this case, `init()` will return either a [`DynamicPPL.VectorValue`](@ref) or a [`DynamicPPL.LinkedVectorValue`](@ref). + +The reason why we return this wrapped value is because sometimes we don't want to eagerly perform the transformation. +Consider the case where we have an accumulator that attempts to store linked values (this is done precisely when linking a VarInfo: the linked values are stored in an accumulator, which then becomes the basis of the linked VarInfo). +In this case, if we eagerly perform the inverse link transformation, we would have to link it again inside the accumulator, which is inefficient! + +The `AbstractTransformedValue` is passed straight through and is used by both the computation and accumulation steps. + +## Computation + +```julia +x, inv_logjac = Bijectors.with_logabsdet_jacobian( + DynamicPPL.get_transform(tval), DynamicPPL.get_internal_value(tval) +) +``` + +At *some* point, we do need to perform the transformation to get the actual raw value. +This is because DynamicPPL promises in the model that the variables on the left-hand side of the tilde are actual raw values. + +```julia +@model function f() + x ~ dist + # Here, `x` _must_ be the actual raw value. + @show x +end +``` + +Thus, regardless of what we are accumulating, we will have to unwrap the transformed value provided by `init()`. +We also need to account for the log-Jacobian of the transformation, if any. + +!!! note + + In principle, if the log-Jacobian is not of interest to any of the accumulators, we _could_ skip computing it here. + However, that is not easy to determine in practice. + We also cannot defer the log-Jacobian computation to the accumulator, since if multiple accumulators need the log-Jacobian, we would end up computing it multiple times. + The current situation of computing it once here is the most sensible compromise (for now). + + One could envision a future where accumulators declare upfront (via their type) whether they need the log-Jacobian or not. We could then skip computing it if no accumulator needs it. + +!!! warning + + If you look at the source code for that method, it is more complicated than the above! + Have we lied? + It turns out that there is a subtlety here: the transformation obtained from `DynamicPPL.get_transform(tval)` may in fact be incorrect. + + Consider the case where a transform is dependent on the value itself (e.g., a variable whose support depends on another variable). + In this case, setting new values into a VarInfo (via `unflatten!!`) may cause the cached transformations to be invalid. + Where possible, it is better to re-obtain the transformation from `dist`, which is always up-to-date since it is obtained from model execution. + +## Accumulation + +```julia +vi = DynamicPPL.accumulate_assume!!(vi, x, tval, -inv_logjac, vn, dist, template) +``` + +This step is where most of the interesting action happens. + +[...] diff --git a/docs/src/values.md b/docs/src/values.md new file mode 100644 index 000000000..bb288522b --- /dev/null +++ b/docs/src/values.md @@ -0,0 +1,121 @@ +# Storing values + +## The role of VarInfo + +As described in the [model evaluation documentation page](./flow.md), each tilde-statement is split up into three parts: + + 1. Initialisation; + 2. Computation; and + 3. Accumulation. + +Unfortunately, not everything in DynamicPPL follows this clean structure yet. +In particular, there is a struct, called `VarInfo`, which has a dual role in both initialisation and accumulation: + +```julia +struct VarInfo{linked,V<:VarNamedTuple,A<:AccumulatorTuple} + values::V + accs::A +end +``` + +The `values` field stores either [`LinkedVectorValue`](@ref)s or [`VectorValue`](@ref)s. +The `link` type parameter can either be `true` or `false`, which indicates that _all values stored_ are linked or unlinked, respectively; or it can be `nothing`, which indicates that it is not known whether the values are linked or unlinked, and must be checked on a case-by-case basis. + +Here is an example: + +```@example 1 +using DynamicPPL, Distributions + +@model function dirichlet() + x = zeros(3) + return x[1:3] ~ Dirichlet(ones(3)) +end +dirichlet_model = dirichlet() +vi = VarInfo(dirichlet_model) +vi +``` + +In `VarInfo`, it is mandatory to store `LinkedVectorValue`s or `VectorValue`s as `ArrayLikeBlock`s (see the [Array-like blocks](@ref) documentation for information on this). +The reason is because, if the value is linked, it may have a different size than the number of indices in the `VarName`. +This means that when retrieving the keys, we obtain each block as a single key: + +```@example 1 +keys(vi.values) +``` + +## Towards a new framework + +In a `VarInfo`, the `accs` field is responsible for the accumulation step, just like an ordinary `AccumulatorTuple`. + +However, `values` serves a dual purpose: it is sometimes used for initialisation (when the model's leaf context is `DefaultContext`, the `AbstractTransformedValue` to be used in the computation step is read from it) and it is sometimes also used for accumulation (when linking a VarInfo, we will potentially store a new `AbstractTransformedValue` in it). + +The path to removing `VarInfo` is essentially to separate these two roles: + + 1. The initialisation role of `varinfo.values` can be taken over by an initialisation strategy that wraps it. + Recall that the only role of an initialisation strategy is to provide an `AbstractTransformedValue` via [`DynamicPPL.init`](@ref). + This can be trivially done by indexing into the `VarNamedTuple` stored in the strategy. + + 2. The accumulation role of `varinfo.values` can be taken over by a new accumulator, which we call `TransformedValueAccumulator`. + The nomenclature here is not especially precise: it does not store `AbstractTransformedValue`s per se, but only two subtypes of it, `LinkedVectorValue` and `VectorValue`. + +`TransformedValueAccumulator` is implemented inside `src/accs/transformed_value.jl`, and additionally includes a link strategy as a parameter: the link strategy is responsible for deciding which values should be stored as `LinkedVectorValue`s and which as `VectorValue`s. + +!!! note + + Decoupling the initialisation from the accumulation also means that we can pair different initialisation strategies with a `TransformedValueAccumulator`. + Previously, to link a VarInfo, you would need to first generate an unlinked VarInfo and then link it. + Now, you can directly create a linked VarInfo (i.e., accumulate `LinkedVectorValue`s) by sampling from the prior (i.e., initialise with `InitFromPrior`). + +## ValuesAsInModelAccumulator + +Earlier we said that `TransformedValueAccumulator` stores only two subtypes of `AbstractTransformedValue`: `LinkedVectorValue` and `VectorValue`. + +It is often also useful to store the raw values, i.e., [`UntransformedValue`](@ref)s; but additionally, since `UntransformedValue`s must always correspond exactly to the indices they are assigned to, we can unwrap them and do not need to store them as array-like blocks. + +This is the role of `ValuesAsInModelAccumulator`. + +!!! note + + The name is a historical artefact, and can definitely be improved. Suggestions are welcome! + +```@example 1 +oavi = DynamicPPL.OnlyAccsVarInfo() +oavi = DynamicPPL.setaccs!!(oavi, (DynamicPPL.ValuesAsInModelAccumulator(false),)) +_, oavi = DynamicPPL.init!!(dirichlet_model, oavi) +raw_vals = DynamicPPL.getacc(oavi, Val(:ValuesAsInModel)).values +``` + +Note that when we unwrap `UntransformedValue`s, we also lose the block structure that was present in the model. +That means that in `ValuesAsInModelAccumulator`, there is no longer any notion that `x[1:3]` was set together, so the keys correspond to the individual indices. + +```@example 1 +keys(raw_vals) +``` + +In particular, the outputs of `ValuesAsInModelAccumulator` are used for chain construction. +This is why indices are split up in chains. + +!!! note + + If you have an entire vector stored as a top-level symbol, e.g. `x ~ Dirichlet(ones(3))`, it will not be broken up (as long as you use FlexiChains). + +## Why do we still need to store `TransformedValue`s? + +Given that `ValuesAsInModelAccumulator` exists, one may wonder why we still need to store `TransformedValue`s at all, i.e. what the purpose of `TransformedValueAccumulator` is. + +Currently, the only remaining reason for transformed values is the fact that we may sometimes need to perform [`DynamicPPL.unflatten!!`](@ref) on a `VarInfo`, to insert new values into it from a vector. + +```@example 1 +vi = VarInfo(dirichlet_model) +vi[@varname(x[1:3])] +``` + +```@example 1 +vi = DynamicPPL.unflatten!!(vi, [0.2, 0.5, 0.3]) +vi[@varname(x[1:3])] +``` + +If we do not store the vectorised form of the values, we will not know how many values to read from the input vector for each key. + +Removing upstream usage of `unflatten!!` would allow us to completely get rid of `TransformedValueAccumulator` and only ever use `ValuesAsInModelAccumulator`. +See [this DynamicPPL issue](https://github.com/TuringLang/DynamicPPL.jl/issues/836) for more information. diff --git a/docs/src/vnt/arraylikeblocks.md b/docs/src/vnt/arraylikeblocks.md index 5bc494a61..90282dbb8 100644 --- a/docs/src/vnt/arraylikeblocks.md +++ b/docs/src/vnt/arraylikeblocks.md @@ -121,6 +121,7 @@ Some examples follow. In `VarInfo`, we need to be able to store either linked or unlinked values (in general, `AbstractTransformedValue`s). These are always vectorised values, and the linked and unlinked vectors may have different sizes (this is indeed the case for Dirichlet distributions). +This means that we have to collectively assign multiple indices in the `VarNamedTuple` to a single vector, which may or may not have the same size as the indices. ```@example 1 @model function dirichlet() @@ -132,33 +133,13 @@ vi = VarInfo(dirichlet_model) vi.values ``` -Thus, in the actual `VarInfo` we do not have a notion of what `x[1]` is. +This means that in the actual `VarInfo` we do not have a notion of what `x[1]` is: -**Note**: this is in contrast to `ValuesAsInModelAccumulator`, where we do store raw values: - -```@example 1 -oavi = DynamicPPL.OnlyAccsVarInfo() -oavi = DynamicPPL.setaccs!!(oavi, (DynamicPPL.ValuesAsInModelAccumulator(false),)) -_, oavi = DynamicPPL.init!!(dirichlet_model, oavi) -raw_vals = DynamicPPL.getacc(oavi, Val(:ValuesAsInModel)).values -``` - -This distinction is important to understand when working with downstream code that uses `VarInfo` and its outputs. -In particular, when constructing a chain, we use the raw values from `ValuesAsInModelAccumulator`, not the linked/unlinked values from `VarInfo`. - -There is also a difference between the keys. -Because the `VarInfo` stores array-like blocks, the keys correspond to the entire blocks: - -```@example 1 -keys(vi.values) +```@repl 1 +vi[@varname(x[1])] ``` -On the other hand, in `ValuesAsInModelAccumulator`, there is no longer any notion that `x[1:3]` was set together, so the keys correspond to the individual indices. -This is why indices are split up in chains: - -```@example 1 -keys(raw_vals) -``` +See the [documentation on storing values](@ref "Storing values") for more details. ### Prior distributions diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index af8753dc1..36b516dd6 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -30,7 +30,7 @@ end """ AbstractMCMC.from_samples( ::Type{MCMCChains.Chains}, - params_and_stats::AbstractMatrix{<:ParamsWithStats} + params_and_stats::AbstractMatrix{<:DynamicPPL.ParamsWithStats} ) Convert an array of `DynamicPPL.ParamsWithStats` to an `MCMCChains.Chains` object. diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 6cbc00edc..511fc1e55 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -75,10 +75,6 @@ export AbstractVarInfo, accloglikelihood!!, is_transformed, set_transformed!!, - link, - link!!, - invlink, - invlink!!, values_as, # VarName (reexport from AbstractPPL) VarName, @@ -122,8 +118,24 @@ export AbstractVarInfo, InitFromPrior, InitFromUniform, InitFromParams, - init, get_param_eltype, + init, + # Transformed values + VectorValue, + LinkedVectorValue, + UntransformedValue, + get_transform, + get_internal_value, + set_internal_value, + # Linking + LinkAll, + UnlinkAll, + LinkSome, + UnlinkSome, + link, + link!!, + invlink, + invlink!!, # Pseudo distributions NamedDist, NoDist, @@ -204,10 +216,13 @@ include("contexts/prefix.jl") include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") include("varname.jl") +include("transformed_values.jl") include("distribution_wrappers.jl") include("submodel.jl") include("accumulators.jl") include("accs/default.jl") +include("accs/vnt.jl") +include("accs/transformed_values.jl") include("accs/priors.jl") include("accs/values.jl") include("accs/pointwise_logdensities.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 0c0b919b8..db162e77d 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -289,13 +289,13 @@ function getacc(vi::AbstractVarInfo, accname::Symbol) end """ - accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right, template) + accumulate_assume!!(vi::AbstractVarInfo, val, tval, logjac, vn, right, template) Update all the accumulators of `vi` by calling `accumulate_assume!!` on them. """ -function accumulate_assume!!(vi::AbstractVarInfo, val, logjac, vn, right, template) +function accumulate_assume!!(vi::AbstractVarInfo, val, tval, logjac, vn, right, template) return map_accumulators!!( - acc -> accumulate_assume!!(acc, val, logjac, vn, right, template), vi + acc -> accumulate_assume!!(acc, val, tval, logjac, vn, right, template), vi ) end @@ -467,6 +467,17 @@ See also: [`getindex(vi::AbstractVarInfo, vn::VarName, dist::Distribution)`](@re """ function getindex_internal end +""" + get_transformed_value(vi::AbstractVarInfo, vn::VarName) + +Return the actual `AbstractTransformedValue` stored in `vi` for variable `vn`. + +This differs from `getindex_internal`, which obtains the `AbstractTransformedValue` and then +directly returns `get_internal_value(tval)`; and `getindex` which returns +`get_transform(tval)(get_internal_value(tval))`. +""" +function get_transformed_value end + @doc """ empty!!(vi::AbstractVarInfo) diff --git a/src/accs/default.jl b/src/accs/default.jl index 54d01901f..8d561f34e 100644 --- a/src/accs/default.jl +++ b/src/accs/default.jl @@ -91,7 +91,9 @@ logp(acc::LogPriorAccumulator) = acc.logp accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior -function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right, template) +function accumulate_assume!!( + acc::LogPriorAccumulator, val, tval, logjac, vn, right, template +) return acclogp(acc, logpdf(right, val)) end accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc @@ -135,7 +137,9 @@ logp(acc::LogJacobianAccumulator) = acc.logjac accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian -function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right, template) +function accumulate_assume!!( + acc::LogJacobianAccumulator, val, tval, logjac, vn, right, template +) return acclogp(acc, logjac) end accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc @@ -157,7 +161,11 @@ logp(acc::LogLikelihoodAccumulator) = acc.logp accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood -accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right, template) = acc +function accumulate_assume!!( + acc::LogLikelihoodAccumulator, val, tval, logjac, vn, right, template +) + return acc +end function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) # Note that it's important to use the loglikelihood function here, not logpdf, because # they handle vectors differently: diff --git a/src/accs/pointwise_logdensities.jl b/src/accs/pointwise_logdensities.jl index 970ab4886..d079a3ce8 100644 --- a/src/accs/pointwise_logdensities.jl +++ b/src/accs/pointwise_logdensities.jl @@ -48,7 +48,7 @@ function combine( end function accumulate_assume!!( - acc::PointwiseLogProbAccumulator{whichlogprob}, val, logjac, vn, right, template + acc::PointwiseLogProbAccumulator{whichlogprob}, val, tval, logjac, vn, right, template ) where {whichlogprob} if whichlogprob == :both || whichlogprob == :prior acc.logps[vn] = logpdf(right, val) diff --git a/src/accs/priors.jl b/src/accs/priors.jl index a34df8e51..d4d425aef 100644 --- a/src/accs/priors.jl +++ b/src/accs/priors.jl @@ -23,7 +23,7 @@ function combine(acc1::PriorDistributionAccumulator, acc2::PriorDistributionAccu end function accumulate_assume!!( - acc::PriorDistributionAccumulator, val, logjac, vn, dist, template + acc::PriorDistributionAccumulator, val, tval, logjac, vn, dist, template ) Accessors.@reset acc.priors = DynamicPPL.templated_setindex!!( acc.priors, dist, vn, template diff --git a/src/accs/transformed_values.jl b/src/accs/transformed_values.jl new file mode 100644 index 000000000..4a0ab25c1 --- /dev/null +++ b/src/accs/transformed_values.jl @@ -0,0 +1,166 @@ +""" + abstract type AbstractLinkStrategy end + +An abstract type for strategies specifying which variables to link or unlink. + +Current subtypes are [`LinkAll`](@ref), [`UnlinkAll`](@ref), [`LinkSome`](@ref), and +[`UnlinkSome`](@ref). + +!!! warning + Even though the subtypes listed above are public, this abstract type is not part of + DynamicPPL's public API and end users should not subtype this. (There should really not + be any reason to!) + +For subtypes of `AbstractLinkStrategy`, the only method that needs to be overloaded is +`DynamicPPL.generate_linked_value`. Note that this is also an internal function. +""" +abstract type AbstractLinkStrategy end + +""" + generate_linked_value(linker::AbstractLinkStrategy, vn::VarName) + +Determine whether a variable with name `vn` should be linked according to the +`linker` strategy. +""" +function generate_linked_value end + +""" + UnlinkAll() <: AbstractLinkStrategy + +Indicate that all variables should be unlinked. +""" +struct UnlinkAll <: AbstractLinkStrategy end +generate_linked_value(::UnlinkAll, ::VarName, ::AbstractTransformedValue) = false + +""" + LinkAll() <: AbstractLinkStrategy + +Indicate that all variables should be linked. +""" +struct LinkAll <: AbstractLinkStrategy end +generate_linked_value(::LinkAll, ::VarName, ::AbstractTransformedValue) = true + +""" + LinkSome(vns) <: AbstractLinkStrategy + +Indicate that the variables in `vns` must be linked. The link statuses of other variables +are preserved. `vns` should be some iterable collection of `VarName`s, although there is no +strict type requirement. +""" +struct LinkSome{V} <: AbstractLinkStrategy + vns::V +end +generate_linked_value(::LinkSome, ::VarName, ::LinkedVectorValue) = true +function generate_linked_value( + linker::LinkSome, vn::VarName, ::Union{VectorValue,UntransformedValue} +) + return any(linker_vn -> subsumes(linker_vn, vn), linker.vns) +end + +""" + UnlinkSome(vns}) <: AbstractLinkStrategy + +Indicate that the variables in `vns` must not Be linked. The link statuses of other +variables are preserved. `vns` should be some iterable collection of `VarName`s, although +there is no strict type requirement. +""" +struct UnlinkSome{V} <: AbstractLinkStrategy + vns::V +end +function generate_linked_value( + ::UnlinkSome, ::VarName, ::Union{VectorValue,UntransformedValue} +) + return false +end +function generate_linked_value(linker::UnlinkSome, vn::VarName, ::LinkedVectorValue) + return !any(linker_vn -> subsumes(linker_vn, vn), linker.vns) +end + +# A transformed value accumulator is just a `VNTAccumulator` that collects +# `AbstractTransformedValues`. It has a double role: +# +# (1) It stores transformed values; which makes it similar in principle to VarInfo itself; +# (2) It also converts those transformed values between linked (`LinkedVectorValue`) and +# unlinked (`VectorValue`) according to an `AbstractLinkStrategy`. In the process it +# also keeps track of the log-Jacobian adjustments that need to be made. +# +# For example, if the strategy is `LinkAll()`, then all transformed values stored will be +# linked. That gives us a way to essentially 'link' a VarInfo, although the idea here is +# different: instead of reaching into the VarInfo and modifying its contents, we instead +# run the model and accumulate transformed values that are linked. +# +# When executing a model with a VarInfo, the `tval` passed through will always be either +# a `VectorValue` or a `LinkedVectorValue` (never an `UntransformedValue`), since that +# is what VarInfo stores. So you might ask: why do we need to handle `UntransformedValue`s? +# +# The answer is that initialisation strategies like `InitFromPrior()` will generate +# `UntransformedValue`s, which can be passed through all the way into these accumulators. +# What this means is that we can _immediately_ generate a linked VarInfo by sampling from +# the prior, without having to first create an unlinked VarInfo and then link it! See the +# `VarInfo` constructors in `src/varinfo.jl` for examples. +# +# This accumulator is used in the implementation of `link!!` and `invlink!!`; however, we +# can't define them in this file as we haven't defined the `VarInfo` struct yet. See +# `src/varinfo.jl` for the definitions. + +const LINK_ACCNAME = :LinkAccumulator +mutable struct Link!{V<:AbstractLinkStrategy} + strategy::V + logjac::LogProbType + Link!(vns::V) where {V} = new{V}(vns, zero(LogProbType)) +end + +function (linker::Link!)(val::Any, tval::LinkedVectorValue, logjac::Any, vn::Any, dist::Any) + original_val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () + return if generate_linked_value(linker.strategy, vn, tval) + # No need to do anything. + tval + else + # tval contains a linked value, we need to invlink it. + # Note that logjac of from_linked_vec_transform will already be included in + # the `logjac` argument (!) so we only need to add the logjac of to_vec_transform + # here, which in principle should be zero, but... + f = to_vec_transform(dist) + new_val, vect_logjac = with_logabsdet_jacobian(f, val) + # In this case, the LogJacobianAccumulator will have counted logjac. We want to + # cancel out that contribution here since we are removing the linking. + linker.logjac += vect_logjac - logjac + VectorValue(new_val, inverse(f), original_val_size) + end +end + +function (linker::Link!)(val::Any, tval::VectorValue, logjac::Any, vn::Any, dist::Any) + # Note that we don't need to care about the logjac passed in, since + # LogJacobianAccumulator takes care of _that_. + original_val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () + return if generate_linked_value(linker.strategy, vn, tval) + # tval contains an unlinked value, we need to generate a new linked value. + f = to_linked_vec_transform(dist) + new_val, link_logjac = with_logabsdet_jacobian(f, val) + linker.logjac += link_logjac + LinkedVectorValue(new_val, inverse(f), original_val_size) + else + # No need to do anything. + tval + end +end + +function (linker::Link!)( + val::Any, tval::UntransformedValue, logjac::Any, vn::Any, dist::Any +) + original_val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () + # Inside here we can just use `val` directly since that is the same thing as unwrapping + # `tval`. + return if generate_linked_value(linker.strategy, vn, tval) + f = to_linked_vec_transform(dist) + new_val, link_logjac = with_logabsdet_jacobian(f, val) + linker.logjac += link_logjac + LinkedVectorValue(new_val, inverse(f), original_val_size) + else + f = to_vec_transform(dist) + # logjac should really be zero, but well. Just check, I guess. + new_val, logjac = with_logabsdet_jacobian(f, val) + linker.logjac += logjac + VectorValue(new_val, inverse(f), original_val_size) + end +end diff --git a/src/accs/values.jl b/src/accs/values.jl index b9843255b..60408e3b3 100644 --- a/src/accs/values.jl +++ b/src/accs/values.jl @@ -64,7 +64,7 @@ function is_extracting_values(vi::AbstractVarInfo) end function accumulate_assume!!( - acc::ValuesAsInModelAccumulator, val, logjac, vn::VarName, right, template + acc::ValuesAsInModelAccumulator, val, tval, logjac, vn::VarName, right, template ) return push!!(acc, vn, val, template) end diff --git a/src/accs/vnt.jl b/src/accs/vnt.jl new file mode 100644 index 000000000..f8b8e0cac --- /dev/null +++ b/src/accs/vnt.jl @@ -0,0 +1,56 @@ +""" + VNTAccumulator{AccName}(f::F, values::VarNamedTuple=VarNamedTuple()) where {AccName,F} + +A generic accumulator that applies a function `f` to values seen during model execution +and stores the results in a `VarNamedTuple`. + +`AccName` is the name of the accumulator, and is exposed to allow users to define and use +multiple forms of `VNTAccumulator` within the same set of accumulators. In theory, each +`VNTAccumulator` with the same function `f` should use the same accumulator name. This is +not enforced. + +The function `f` should have the signature: + + f(val, tval, logjac, vn, dist) -> value_to_store + +where `val`, `tval`, `logjac`, `vn`, and `dist` have their usual meanings in +accumulate_assume!! (see its docstring for more details). +""" +struct VNTAccumulator{AccName,F,VNT<:VarNamedTuple} <: AbstractAccumulator + f::F + values::VNT +end +function VNTAccumulator{AccName}( + f::F, values::VarNamedTuple=VarNamedTuple() +) where {AccName,F} + return VNTAccumulator{AccName,F,typeof(values)}(f, values) +end + +function Base.copy(acc::VNTAccumulator{AccName}) where {AccName} + return VNTAccumulator{AccName}(acc.f, copy(acc.values)) +end + +accumulator_name(::VNTAccumulator{AccName}) where {AccName} = AccName + +function _zero(acc::VNTAccumulator{AccName}) where {AccName} + return VNTAccumulator{AccName}(acc.f, empty(acc.values)) +end +reset(acc::VNTAccumulator{AccName}) where {AccName} = _zero(acc) +split(acc::VNTAccumulator{AccName}) where {AccName} = _zero(acc) +function combine( + acc1::VNTAccumulator{AccName}, acc2::VNTAccumulator{AccName} +) where {AccName} + if acc1.f != acc2.f + throw(ArgumentError("Cannot combine VNTAccumulators with different functions")) + end + return VNTAccumulator{AccName}(acc2.f, merge(acc1.values, acc2.values)) +end + +function accumulate_assume!!( + acc::VNTAccumulator{AccName}, val, tval, logjac, vn, dist, template +) where {AccName} + new_val = acc.f(val, tval, logjac, vn, dist) + new_values = DynamicPPL.templated_setindex!!(acc.values, new_val, vn, template) + return VNTAccumulator{AccName}(acc.f, new_values) +end +accumulate_observe!!(acc::VNTAccumulator, right, left, vn) = acc diff --git a/src/accumulators.jl b/src/accumulators.jl index 1c8c032d2..4af4000ad 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -12,7 +12,7 @@ seen so far. An accumulator type `T <: AbstractAccumulator` must implement the following methods: - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` - `accumulate_observe!!(acc::T, dist, val, vn)` -- `accumulate_assume!!(acc::T, val, logjac, vn, dist, template)` +- `accumulate_assume!!(acc::T, val, tval, logjac, vn, dist, template)` - `reset(acc::T)` - `Base.copy(acc::T)` @@ -20,6 +20,10 @@ In these functions: - `val` is the new value of the random variable sampled from a distribution (always in the original unlinked space), or the value on the left-hand side of an observe statement. +- `tval` is the original `AbstractTransformedValue` that was obtained from the + initialisation strategy. This is passed through unchanged to `accumulate_assume!!` since + it can be reused for some accumulators (e.g. when storing linked values, if the linked + value was already provided, it is faster to reuse it than to re-link `val`). - `dist` is the distribution on the RHS of the tilde statement. - `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`. @@ -62,7 +66,7 @@ See also: [`accumulate_assume!!`](@ref) function accumulate_observe!! end """ - accumulate_assume!!(acc::AbstractAccumulator, val, logjac, vn, right, template) + accumulate_assume!!(acc::AbstractAccumulator, val, tval, logjac, vn, right, template) Update `acc` in a `tilde_assume!!` call. Returns the updated `acc`. @@ -227,6 +231,17 @@ function getacc(at::AccumulatorTuple, ::Val{accname}) where {accname} return at[accname] end +""" + deleteacc(at::AccumulatorTuple, ::Val{accname}) + +Delete the accumulator with name `accname` from `at`. Returns a new `AccumulatorTuple`. +""" +function deleteacc( + accs::AccumulatorTuple{N,<:NamedTuple{names}}, ::Val{T} +) where {N,names,T} + return AccumulatorTuple(NamedTuple{filter(x -> x != T, names)}(accs.nt)) +end + function Base.map(func::Function, at::AccumulatorTuple) return AccumulatorTuple(map(func, at.nt)) end diff --git a/src/bijector.jl b/src/bijector.jl index 218262f1a..49a9e117b 100644 --- a/src/bijector.jl +++ b/src/bijector.jl @@ -26,7 +26,9 @@ function combine(acc1::BijectorAccumulator, acc2::BijectorAccumulator) ) end -function accumulate_assume!!(acc::BijectorAccumulator, val, logjac, vn, right, template) +function accumulate_assume!!( + acc::BijectorAccumulator, val, tval, logjac, vn, right, template +) bijector = _compose_no_identity( to_linked_vec_transform(right), from_vec_transform(right) ) diff --git a/src/contexts/default.jl b/src/contexts/default.jl index 5a6ca1095..6c6f50813 100644 --- a/src/contexts/default.jl +++ b/src/contexts/default.jl @@ -33,10 +33,18 @@ with `vn` from `vi`, If `vi` does not contain an appropriate value then this wil function tilde_assume!!( ::DefaultContext, right::Distribution, vn::VarName, template::Any, vi::AbstractVarInfo ) - y = getindex_internal(vi, vn) - f = from_maybe_linked_internal_transform(vi, vn, right) - x, inv_logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, right, template) + tval = get_transformed_value(vi, vn) + trf = if tval isa LinkedVectorValue + # Note that we can't rely on the stored transform being correct (e.g. if new values + # were placed in `vi` via `unflatten!!`, so we regenerate the transforms. + from_linked_vec_transform(right) + elseif tval isa VectorValue + from_vec_transform(right) + else + error("Expected transformed value to be a VectorValue or LinkedVectorValue") + end + x, inv_logjac = with_logabsdet_jacobian(trf, get_internal_value(tval)) + vi = accumulate_assume!!(vi, x, tval, -inv_logjac, vn, right, template) return x, vi end diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 5968fa8a2..af16c30ce 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -14,14 +14,13 @@ abstract type AbstractInitStrategy end Generate a new value for a random variable with the given distribution. -This function must return a tuple `(x, trf)`, where - -- `x` is the generated value -- `trf` is a function that transforms the generated value back to the unlinked space. If the - value is already in unlinked space, then this should be `DynamicPPL.typed_identity`. You - can also use `Base.identity`, but if you use this, you **must** be confident that - `zero(eltype(x))` will **never** error. See the docstring of `typed_identity` for more - information. +This function must return an `AbstractTransformedValue`. + +If `strategy` provides values that are already untransformed (e.g., a Float64 within (0, 1) +for `dist::Beta`, then you should return an `UntransformedValue`. + +Otherwise, often there are cases where this will return either a `VectorValue` or a +`LinkedVectorValue`, for example, if the strategy is reading from an existing `VarInfo`. """ function init end @@ -76,7 +75,7 @@ Obtain new values by sampling from the prior distribution. """ struct InitFromPrior <: AbstractInitStrategy end function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, ::InitFromPrior) - return rand(rng, dist), typed_identity + return UntransformedValue(rand(rng, dist)) end """ @@ -116,7 +115,11 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro if x isa Array{<:Any,0} x = x[] end - return x, typed_identity + # NOTE: We don't return `LinkedVectorValue(y, ...)` here because we don't want the + # logjac of this transform to be included when evaluating the model! The fact that + # b_inv(y) has a non-trivial logjacobian is just an artefact of how the sampling is done + # and has nothing to do with the model. + return UntransformedValue(x) end """ @@ -171,20 +174,24 @@ InitFromParams(params) = InitFromParams(params, InitFromPrior()) function init( rng::Random.AbstractRNG, vn::VarName, dist::Distribution, p::InitFromParams{P} ) where {P<:Union{AbstractDict{<:VarName},NamedTuple,VarNamedTuple}} - # TODO(penelopeysm): It would be nice to do a check to make sure that all - # of the parameters in `p.params` were actually used, and either warn or - # error if they aren't. This is actually quite non-trivial though because - # the structure of Dicts in particular can have arbitrary nesting. return if hasvalue(p.params, vn, dist) x = getvalue(p.params, vn, dist) if x === missing p.fallback === nothing && error("A `missing` value was provided for the variable `$(vn)`.") init(rng, vn, dist, p.fallback) + elseif x isa VectorValue + # In this case, we can't trust the transform stored in x because the _value_ + # in x may have been changed via unflatten!! without the transform being + # updated. Therefore, we always recompute the transform here. + VectorValue(x.val, from_vec_transform(dist), x.size) + elseif x isa LinkedVectorValue + # Same as above. + LinkedVectorValue(x.val, from_linked_vec_transform(dist), x.size) + elseif x isa UntransformedValue + x else - # TODO(penelopeysm): Since x is user-supplied, maybe we could also - # check here that the type / size of x matches the dist? - x, typed_identity + UntransformedValue(x) end else p.fallback === nothing && error("No value was provided for the variable `$(vn)`.") @@ -197,6 +204,47 @@ function get_param_eltype( return infer_nested_eltype(typeof(strategy.params)) end +""" +Like InitFromParams, but it is always assumed that the VNT contains _exactly_ the +correct set of variables, and that indexing into them will always return _exactly_ +the values for those variables. + +The main difference is that InitFromParams will call hasvalue(p.params, vn, dist) +rather than just hasvalue(p.params, vn), which can be substantially slower. + +TODO(penelopeysm): Get rid of MCMCChains and never call the three-value argument again. +Seriously. It's just nuts that I have to do these workarounds because of a package that +isn't even DynamicPPL. +""" +struct InitFromParamsUnsafe{P<:VarNamedTuple} <: AbstractInitStrategy + params::P +end +function init( + ::Random.AbstractRNG, + vn::VarName, + dist::Distribution, + p::InitFromParamsUnsafe{<:VarNamedTuple}, +) + return if haskey(p.params, vn) + x = p.params[vn] + if x isa VectorValue + # In this case, we can't trust the transform stored in x because the _value_ + # in x may have been changed via unflatten!! without the transform being + # updated. Therefore, we always recompute the transform here. + VectorValue(x.val, from_vec_transform(dist), x.size) + elseif x isa LinkedVectorValue + # Same as above. + LinkedVectorValue(x.val, from_linked_vec_transform(dist), x.size) + elseif x isa UntransformedValue + x + else + UntransformedValue(x) + end + else + error("No value was provided for the variable `$(vn)`.") + end +end + """ RangeAndLinked @@ -271,12 +319,19 @@ function init( # case we use the stored link status), or `true` / `false`, which # indicates that all variables are linked / unlinked. linked = isnothing(T) ? range_and_linked.is_linked : T - transform = if linked - from_linked_vec_transform(dist) + return if linked + LinkedVectorValue( + view(vr.vect, range_and_linked.range), + from_linked_vec_transform(dist), + range_and_linked.original_size, + ) else - from_vec_transform(dist) + VectorValue( + view(vr.vect, range_and_linked.range), + from_vec_transform(dist), + range_and_linked.original_size, + ) end - return (@view vr.vect[range_and_linked.range]), transform end function get_param_eltype(strategy::InitFromParams{<:VectorWithRanges}) return eltype(strategy.params.vect) @@ -309,11 +364,16 @@ end function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, template::Any, vi::AbstractVarInfo ) - val, transform = init(ctx.rng, vn, dist, ctx.strategy) - x, init_logjac = with_logabsdet_jacobian(transform, val) - vi, logjac = setindex_with_dist!!(vi, x, dist, vn, template) + tval = init(ctx.rng, vn, dist, ctx.strategy) + x, init_logjac = with_logabsdet_jacobian(get_transform(tval), get_internal_value(tval)) + # TODO(penelopeysm): This could be inefficient if `tval` is already linked and + # `setindex_with_dist!!` tells it to create a new linked value again. In particular, + # this is inefficient if we use `InitFromParams` that provides linked values. The answer + # to this is to stop using setindex_with_dist!! and just use the TransformedValue + # accumulator. + vi, logjac, _ = setindex_with_dist!!(vi, x, dist, vn, template) # `accumulate_assume!!` wants untransformed values as the second argument. - vi = accumulate_assume!!(vi, x, init_logjac + logjac, vn, dist, template) + vi = accumulate_assume!!(vi, x, tval, init_logjac + logjac, vn, dist, template) # We always return the untransformed value here, as that will determine # what the lhs of the tilde-statement is set to. return x, vi diff --git a/src/debug_utils.jl b/src/debug_utils.jl index 94c4e9ef5..568b81afb 100644 --- a/src/debug_utils.jl +++ b/src/debug_utils.jl @@ -242,7 +242,7 @@ _has_nans(x) = isnan(x) _has_nans(::Missing) = false function DynamicPPL.accumulate_assume!!( - acc::DebugAccumulator, val, _logjac, vn::VarName, right::Distribution, template + acc::DebugAccumulator, val, tval, logjac, vn::VarName, right::Distribution, template ) record_varname!(acc, vn, right) stmt = AssumeStmt(; varname=vn, right=right, value=val) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 835ddbaac..5127e5bd0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -435,7 +435,7 @@ function get_ranges_and_linked(vi::VarInfo) val = tv.val range = offset:(offset + length(val) - 1) offset += length(val) - ral = RangeAndLinked(range, is_transformed(tv), tv.size) + ral = RangeAndLinked(range, tv isa LinkedVectorValue, tv.size) template = vi.values.data[AbstractPPL.getsym(vn)] vnt = templated_setindex!!(vnt, ral, vn, template) return vnt, offset diff --git a/src/onlyaccs.jl b/src/onlyaccs.jl index 42edcac7b..dbd99c943 100644 --- a/src/onlyaccs.jl +++ b/src/onlyaccs.jl @@ -36,8 +36,16 @@ function tilde_assume!!( ) # For OnlyAccsVarInfo, since we don't need to write into the VarInfo, we can # cut out a lot of the code above. - val, transform = init(ctx.rng, vn, dist, ctx.strategy) - x, inv_logjac = with_logabsdet_jacobian(transform, val) - vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist, template) + tval = init(ctx.rng, vn, dist, ctx.strategy) + # Prefer to use the transform from the distribution. + transform = if tval isa LinkedVectorValue + DynamicPPL.from_linked_vec_transform(dist) + elseif tval isa VectorValue + DynamicPPL.from_vec_transform(dist) + else + DynamicPPL.get_transform(tval) + end + x, inv_logjac = with_logabsdet_jacobian(transform, DynamicPPL.get_internal_value(tval)) + vi = accumulate_assume!!(vi, x, tval, -inv_logjac, vn, dist, template) return x, vi end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c769d144c..dd0e518cf 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -102,8 +102,8 @@ end function setindex_with_dist!!( vi::ThreadSafeVarInfo, val, dist::Distribution, vn::VarName, template ) - vi_inner, logjac = setindex_with_dist!!(vi.varinfo, val, dist, vn, template) - return Accessors.@set(vi.varinfo = vi_inner), logjac + vi_inner, logjac, tval = setindex_with_dist!!(vi.varinfo, val, dist, vn, template) + return Accessors.@set(vi.varinfo = vi_inner), logjac, tval end function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) @@ -141,6 +141,9 @@ function is_transformed(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) end getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) +function get_transformed_value(vi::ThreadSafeVarInfo, vn::VarName) + return get_transformed_value(vi.varinfo, vn) +end function unflatten!!(vi::ThreadSafeVarInfo, x::AbstractVector) return Accessors.@set vi.varinfo = unflatten!!(vi.varinfo, x) diff --git a/src/transformed_values.jl b/src/transformed_values.jl new file mode 100644 index 000000000..b1d7e92e6 --- /dev/null +++ b/src/transformed_values.jl @@ -0,0 +1,199 @@ +# TODO(mhauru) The policy of vectorising all values was set when the old VarInfo type was +# using a Vector as the internal storage in all cases. We should revisit this, and allow +# values to be stored "raw", since VarNamedTuple supports it. +# +# NOTE(penelopeysm): The main problem with unvectorising values is that when calling +# `unflatten!!`, it is not clear how many elements to take from the vector. In general +# we would need to know the distribution to get this data, which is fine if we are +# executing the model, but `unflatten!!` does not have that information. As long as we +# depend on the behaviour of `unflatten!!` somewhere, we cannot get rid of it. + +# TODO(mhauru) Related to the above, I think we should reconsider whether we should store +# transformations at all. We rarely use them, since they may be dynamic in a model. +# tilde_assume!! rather gets the transformation from the current distribution encountered +# during model execution. However, this would change the interface quite a lot, so I want to +# finish implementing VarInfo using VNT (mostly) respecting the old interface first. +# +# NOTE(penelopeysm) The above is in principle doable right now. the main issue with removing +# the transform is that we cannot get `varinfo[vn]` any more. It is arguable whether this +# method is really needed. On one hand, it is a pretty useful way of seeing the current +# value of a variable in the VarInfo. On the other hand, it is not guaranteed to be correct +# (because `unflatten!` might change the required transform); so one could argue that the +# question of "what is the true value" is generally unanswerable, and we should not expose a +# method that pretends to know the answer.. I would lean towards removing it, but doing so +# would require a fair amount of changes in the test suite, so it will have to wait for a +# time when fewer big PRs are ongoing. + +""" + AbstractTransformedValue + +An abstract type for values that enter the DynamicPPL tilde-pipeline. + +These values are generated by an [`AbstractInitStrategy`](@ref): the function +[`DynamicPPL.init`](@ref) should return an `AbstractTransformedValue`. + +Each `AbstractTransformedValue` contains some version of the actual variable's value, +together with a transformation that can be used to convert the internal value back to the +original space. + +Current subtypes are [`VectorValue`](@ref), [`LinkedVectorValue`](@ref), and +[`UntransformedValue`](@ref). DynamicPPL's [`VarInfo`](@ref) type stores either +`VectorValue`s or `LinkedVectorValue`s internally, depending on the link status of the +`VarInfo`. + +!!! warning + Even though the subtypes listed above are public, this abstract type is not itself part + of the public API and should not be subtyped by end users. Much of DynamicPPL's model + evaluation methods depends on these subtypes having predictable behaviour, i.e., their + transforms should always be `from_linked_vec_transform(dist)`, + `from_vec_transform(dist)`, or their inverse. If you create a new subtype of + `AbstractTransformedValue` and use it, DynamicPPL will not know how to handle it and may + either error or silently give incorrect results. + + In principle, it should be possible to subtype this and allow for custom transformations + to be used (not just the 'default' ones). However, this is not currently implemented. + +Subtypes of this should implement the following functions: + +- `DynamicPPL.get_transform(tv::AbstractTransformedValue)`: Get the transformation that + converts the internal value back to the original space. + +- `DynamicPPL.get_internal_value(tv::AbstractTransformedValue)`: Get the internal value + stored in `tv`. + +- `DynamicPPL.set_internal_value(tv::AbstractTransformedValue, new_val)`: Create a new + `AbstractTransformedValue` with the same transformation as `tv`, but with internal value + `new_val`. + +- `DynamicPPL.VarNamedTuples.vnt_size(tv::AbstractTransformedValue)`: Get the size of the + original value before transformation. +""" +abstract type AbstractTransformedValue end + +""" + get_transform(tv::AbstractTransformedValue) + +Get the transformation that converts the internal value back to the raw value. + +!!! warning + If the distribution associated with the variable has changed since this + `AbstractTransformedValue` was created, this transform may be inaccurate. This can + happen e.g. if `unflatten!!` has been called on a VarInfo containing this. + + Consequently, when the distribution on the right-hand side of a tilde-statement is + available, you should always prefer regenerating the transform from that distribution + rather than using this function. +""" +function get_transform end + +""" + get_internal_value(tv::AbstractTransformedValue) + +Get the internal value stored in `tv`. +""" +function get_internal_value end + +""" + set_internal_value(tv::AbstractTransformedValue, new_val) + +Create a new `AbstractTransformedValue` with the same transformation as `tv`, but with +internal value `new_val`. +""" +function set_internal_value end + +""" + VectorValue{V<:AbstractVector,T,S} + +A transformed value that stores its internal value as a vectorised form. This is what +VarInfo sees as an "unlinked value". + +These values can be generated when using `InitFromParams` with a VarInfo's internal values. +""" +struct VectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue + "The internal (vectorised) value." + val::V + """The unvectorisation transform required to convert `val` back to the original space. + + Note that this transform is cached and thus may be inaccurate if `unflatten!!` is called + on the VarInfo containing this `VectorValue`. This transform is only ever used when + calling `varinfo[vn]` to get the original value back; in all other cases, where model + evaluation occurs, the correct transform is determined from the distribution associated + with the variable.""" + transform::T + """The size of the original value before transformation. This is needed when a + `TransformedValue` is stored as a block in an array.""" + size::S + function VectorValue(val::V, tfm::T, size::S) where {V<:AbstractVector,T,S} + return new{V,T,S}(val, tfm, size) + end +end + +""" + LinkedVectorValue{V<:AbstractVector,T,S} + +A transformed value that stores its internal value as a linked andvectorised form. This is +what VarInfo sees as a "linked value". + +These values can be generated when using `InitFromParams` with a VarInfo's internal values. +""" +struct LinkedVectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue + "The internal (linked + vectorised) value." + val::V + """The unlinking transform required to convert `val` back to the original space. + + Note that this transform is cached and thus may be inaccurate if `unflatten!!` is called + on the VarInfo containing this `VectorValue`. This transform is only ever used when + calling `varinfo[vn]` to get the original value back; in all other cases, where model + evaluation occurs, the correct transform is determined from the distribution associated + with the variable.""" + transform::T + """The size of the original value before transformation. This is needed when a + `TransformedValue` is stored as a block in an array.""" + size::S + function LinkedVectorValue(val::V, tfm::T, size::S) where {V<:AbstractVector,T,S} + return new{V,T,S}(val, tfm, size) + end +end + +for T in (:VectorValue, :LinkedVectorValue) + @eval begin + function Base.:(==)(tv1::$T, tv2::$T) + return (tv1.val == tv2.val) & + (tv1.transform == tv2.transform) & + (tv1.size == tv2.size) + end + function Base.isequal(tv1::$T, tv2::$T) + return isequal(tv1.val, tv2.val) && + isequal(tv1.transform, tv2.transform) && + isequal(tv1.size, tv2.size) + end + VarNamedTuples.vnt_size(tv::$T) = tv.size + + get_transform(tv::$T) = tv.transform + get_internal_value(tv::$T) = tv.val + + function set_internal_value(tv::$T, new_val) + return $T(new_val, tv.transform, tv.size) + end + end +end + +""" + UntransformedValue{V} + +A raw, untransformed, value. + +These values can be generated from initialisation strategies such as `InitFromPrior`, +`InitFromUniform`, and `InitFromParams` on a standard container type. +""" +struct UntransformedValue{V} <: AbstractTransformedValue + "The value." + val::V + UntransformedValue(val::V) where {V} = new{V}(val) +end +VarNamedTuples.vnt_size(tv::UntransformedValue) = vnt_size(tv.val) +Base.:(==)(tv1::UntransformedValue, tv2::UntransformedValue) = tv1.val == tv2.val +Base.isequal(tv1::UntransformedValue, tv2::UntransformedValue) = isequal(tv1.val, tv2.val) +get_transform(::UntransformedValue) = typed_identity +get_internal_value(tv::UntransformedValue) = tv.val +set_internal_value(::UntransformedValue, new_val) = UntransformedValue(new_val) diff --git a/src/varinfo.jl b/src/varinfo.jl index c80361687..0c5ea1dca 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -7,14 +7,13 @@ The `Linked` type parameter is either `true` or `false` to mark that all variabl `VarInfo` are linked, or `nothing` to indicate that some variables may be linked and some not, and a runtime check is needed. -`VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, -and a tuple of accumulators. The only really noteworthy thing about it is that it stores -the values of variables vectorised as instances of `TransformedValue`. That is, it stores -each value as a vector and a transformation to be applied to that vector to get the actual -value. It also stores whether the transformation is such that it guarantees all real vectors -to be valid internal representations of the variable (i.e., whether the variable has been -linked), as well as the size of the actual post-transformation value. These are all fields -of [`TransformedValue`](@ref). +`VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, and +a tuple of accumulators. The only really noteworthy thing about it is that it stores the +values of variables vectorised as instances of [`AbstractTransformedValue`](@ref). That is, +it stores each value as a special vector with a flag indicating whether it is just a +vectorised value ([`VectorValue`](@ref)), or whether it is also linked +([`LinkedVectorValue`](@ref)). It also stores the size of the actual post-transformation +value. These are all accessible via [`AbstractTransformedValue`](@ref). Note that `setindex!!` and `getindex` on `VarInfo` take and return values in the support of the original distribution. To get access to the internal vectorised values, use @@ -24,8 +23,8 @@ There's also a `VarInfo`-specific function [`setindex_with_dist!!`](@ref), which variable's value with a transformation based on the statistical distribution this value is a sample for. -For more details on the internal storage, see documentation of [`TransformedValue`](@ref) and -[`VarNamedTuple`](@ref). +For more details on the internal storage, see documentation of +[`AbstractTransformedValue`](@ref) and [`VarNamedTuple`](@ref). # Fields $(TYPEDFIELDS) @@ -42,51 +41,12 @@ struct VarInfo{Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInf end end -# TODO(mhauru) The policy of vectorising all values was set when the old VarInfo type was -# using a Vector as the internal storage in all cases. We should revisit this, and allow -# values to be stored "raw", since VarNamedTuple supports it. - -# TODO(mhauru) Related to the above, I think we should reconsider whether we should store -# transformations at all. We rarely use them, since they may be dynamic in a model. -# tilde_assume!! rather gets the transformation from the current distribution encountered -# during model execution. However, this would change the interface quite a lot, so I want to -# finish implementing VarInfo using VNT (mostly) respecting the old interface first. - -# TODO(mhauru) We are considering removing `transform` completely, and forcing people to use -# ValuesAsInModelAcc instead. If that is done, we may want to move the Linked type parameter -# to just be a bool field. It's currently a type parameter to make the type of `transform` -# easier to type infer, but if `transform` no longer exists, it might start to cause -# unnecessary type inconcreteness in the elements of PartialArray. -""" - TransformedValue{Linked,ValType,TransformType,SizeType} - -A struct for storing a variable's value in its internal (vectorised) form. - -The type parameter `Linked` is a `Bool` indicating whether the variable is linked, i.e. -whether the transformation maps all real vectors to valid values. -# Fields -$(TYPEDFIELDS) -""" -struct TransformedValue{Linked,ValType<:AbstractVector,TransformType,SizeType} - "The internal (vectorised) value." - val::ValType - """The transformation from internal (vectorised) to actual value. In other words, the - actual value of the variable being stored is `transform(val)`.""" - transform::TransformType - """The size of the actual value after transformation. This is needed when a - `TransformedValue` is stored as a block in an array.""" - size::SizeType - - function TransformedValue{Linked}( - val::ValType, transform::TransformType, size::SizeType - ) where {Linked,ValType,TransformType,SizeType} - return new{Linked,ValType,TransformType,SizeType}(val, transform, size) - end +function Base.:(==)(vi1::VarInfo, vi2::VarInfo) + return (vi1.values == vi2.values) & (vi1.accs == vi2.accs) +end +function Base.isequal(vi1::VarInfo, vi2::VarInfo) + return isequal(vi1.values, vi2.values) && isequal(vi1.accs, vi2.accs) end - -is_transformed(::TransformedValue{Linked}) where {Linked} = Linked - -VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size VarInfo() = VarInfo{false}(VarNamedTuple(), default_accumulators()) @@ -99,16 +59,96 @@ function VarInfo(values::Union{NamedTuple,AbstractDict}) return vi end -function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return VarInfo(Random.default_rng(), model, init_strategy) -end +""" + VarInfo( + [rng::AbstractRNG,] + model::Model, + link::AbstractLinkStrategy=UnlinkAll(), + init::AbstractInitStrategy=InitFromPrior() + ) -function VarInfo( +Create a fresh `VarInfo` for the given model by running the model and populating it +according to the given initialisation strategy. The resulting variables in the `VarInfo` can +be linked or unlinked according to the given linking strategy. + +# Arguments + +- `rng::AbstractRNG`: An optional random number generator to use for any stochastic + initialisation. If not provided, `Random.default_rng()` is used. +- `model::Model`: The model for which to create the `VarInfo`. +- `link::AbstractLinkStrategy`: An optional linking strategy (see `AbstractLinkStrategy`). + Defaults to `UnlinkAll()`, i.e., all variables are vectorised but not linked. +- `init::AbstractInitStrategy`: An optional initialisation strategy (see + [`AbstractInitStrategy`](@ref)). Defaults to `InitFromPrior()`, i.e., all variables are + initialised by sampling from their prior distributions. + +# Extended help + +## Performance characteristics of linked VarInfo + +This method allows the immediate generation of a linked VarInfo, which was not possible in +previous versions of DynamicPPL. It is guaranteed that `link!!(VarInfo(rng, model), model)` +(the old way of instantiating a linked `VarInfo`) is equivalent to `VarInfo(rng, model, +LinkAll())`. + +Depending on the model, each of these two methods may be more performant, although the +reasons for this are still somewhat unclear. Small models tend to do better with +instantiating an unlinked `VarInfo` and then linking it, while large models tend to do +better with directly instantiating a linked `VarInfo`. The hope is that this generally does +not impact usage since linking is not typically something done in performance-critical +sections of Turing.jl. If linking performance is critical, it is recommended to benchmark +both methods for the specific model in question. +""" +function DynamicPPL.VarInfo( + rng::Random.AbstractRNG, + model::Model, + ::Union{UnlinkAll,UnlinkSome}, + initstrat::AbstractInitStrategy, +) + # In this case, no variables are to be linked. We can optimise performance by directly + # calling init!! and not faffing about with accumulators. (This does lead to significant + # performance improvements for the typical use case of generating an unlinked VarInfo.) + return last(init!!(rng, model, VarInfo(), initstrat)) +end +function DynamicPPL.VarInfo( rng::Random.AbstractRNG, model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), + linkstrat::AbstractLinkStrategy, + initstrat::AbstractInitStrategy=InitFromPrior(), +) + linked_value_acc = VNTAccumulator{LINK_ACCNAME}(Link!(linkstrat)) + vi = OnlyAccsVarInfo((linked_value_acc, default_accumulators()...)) + vi = last(init!!(rng, model, vi, initstrat)) + # Extract the linked values and the change in logjac. + link_acc = getacc(vi, Val(LINK_ACCNAME)) + new_vi_is_linked = if linkstrat isa LinkAll + true + else + # TODO(penelopeysm): We can definitely do better here. The linking accumulator can + # keep track of whether any variables were linked or unlinked, and we can use that + # here. It won't be type-stable, but that's fine, right now it isn't either. + nothing + end + vi = VarInfo{new_vi_is_linked}( + link_acc.values, DynamicPPL.deleteacc(vi.accs, Val(LINK_ACCNAME)) + ) + vi = acclogjac!!(vi, link_acc.f.logjac) + return vi +end +function DynamicPPL.VarInfo( + model::Model, + linkstrat::AbstractLinkStrategy, + initstrat::AbstractInitStrategy=InitFromPrior(), +) + return DynamicPPL.VarInfo(Random.default_rng(), model, linkstrat, initstrat) +end +function VarInfo( + rng::Random.AbstractRNG, model::Model, initstrat::AbstractInitStrategy=InitFromPrior() ) - return last(init!!(rng, model, VarInfo(), init_strategy)) + return VarInfo(rng, model, UnlinkAll(), initstrat) +end +function VarInfo(model::Model, initstrat::AbstractInitStrategy=InitFromPrior()) + return VarInfo(Random.default_rng(), model, initstrat) end getaccs(vi::VarInfo) = vi.accs @@ -130,7 +170,12 @@ Base.keys(vi::VarInfo) = keys(vi.values) # Union{Vector{Union{}}, Vector{Float64}} (I suppose this is because it can't tell whether # the result will be empty or not...? Not sure). function Base.values(vi::VarInfo) - return mapreduce(p -> p.second.transform(p.second.val), push!, vi.values; init=Any[]) + return mapreduce( + p -> DynamicPPL.get_transform(p.second)(DynamicPPL.get_internal_value(p.second)), + push!, + vi.values; + init=Any[], + ) end function Base.show(io::IO, ::MIME"text/plain", vi::VarInfo{link}) where {link} @@ -150,9 +195,8 @@ end function Base.getindex(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) - return tv.transform(tv.val) + return get_transform(tv)(get_internal_value(tv)) end - function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) return [getindex(vi, vn) for vn in vns] end @@ -170,7 +214,7 @@ This does not change the transformation or linked status of the variable. """ function setindex_internal!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked} old_tv = getindex(vi.values, vn) - new_tv = TransformedValue{is_transformed(old_tv)}(val, old_tv.transform, old_tv.size) + new_tv = set_internal_value(old_tv, val) new_values = setindex!!(vi.values, new_tv, vn) return VarInfo{Linked}(new_values, vi.accs) end @@ -192,8 +236,10 @@ currently linked in `vi`, or doesn't exist in `vi` but all other variables in `v linked, the linking transformation is used; otherwise, the standard vector transformation is used. -Returns the modified `vi` together with the log absolute determinant of the Jacobian of the -transformation applied. +Returns three things: + - the modified `vi`, + - the log absolute determinant of the Jacobian of the transformation applied, + - the `AbstractTransformedValue` used to store the value. """ function setindex_with_dist!!( vi::VarInfo{Linked}, val, dist::Distribution, vn::VarName, template @@ -204,17 +250,21 @@ function setindex_with_dist!!( Linked end transform = if link - from_linked_vec_transform(dist) + to_linked_vec_transform(dist) else - from_vec_transform(dist) + to_vec_transform(dist) end - transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val) + transformed_val, logjac = with_logabsdet_jacobian(transform, val) # All values for which `size` is not defined are assumed to be scalars. val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () - tv = TransformedValue{link}(transformed_val, transform, val_size) + tv = if link + LinkedVectorValue(transformed_val, inverse(transform), val_size) + else + VectorValue(transformed_val, inverse(transform), val_size) + end new_linked = Linked == link ? Linked : nothing vi = VarInfo{new_linked}(templated_setindex!!(vi.values, tv, vn, template), vi.accs) - return vi, logjac + return vi, logjac, tv end # TODO(mhauru) The below is somewhat unsafe or incomplete: For instance, from_vec_transform @@ -230,10 +280,11 @@ The transformation for `vn` is reset to be the standard vector transformation fo the type of `val` and linking status is set to false. """ function BangBang.setindex!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked} + # TODO(penelopeysm) This function is BS, really should get rid of it asap new_linked = Linked == false ? false : nothing transform = from_vec_transform(val) transformed_val = inverse(transform)(val) - tv = TransformedValue{false}(transformed_val, transform, size(val)) + tv = VectorValue(transformed_val, transform, size(val)) return VarInfo{new_linked}(setindex!!(vi.values, tv, vn), vi.accs) end @@ -242,11 +293,16 @@ end Set the linked status of variable `vn` in `vi` to `linked`. -This does not change the value or transformation of the variable. +Note that this function is potentially unsafe as it does not change the value or +transformation of the variable! """ function set_transformed!!(vi::VarInfo{Linked}, linked::Bool, vn::VarName) where {Linked} old_tv = getindex(vi.values, vn) - new_tv = TransformedValue{linked}(old_tv.val, old_tv.transform, old_tv.size) + new_tv = if linked + LinkedVectorValue(old_tv.val, old_tv.transform, old_tv.size) + else + VectorValue(old_tv.val, old_tv.transform, old_tv.size) + end new_values = setindex!!(vi.values, new_tv, vn) new_linked = Linked == linked ? Linked : nothing return VarInfo{new_linked}(new_values, vi.accs) @@ -267,8 +323,9 @@ end set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false) function set_transformed!!(vi::VarInfo, linked::Bool) + ctor = linked ? LinkedVectorValue : VectorValue new_values = map_values!!(vi.values) do tv - TransformedValue{linked}(tv.val, tv.transform, tv.size) + ctor(tv.val, tv.transform, tv.size) end return VarInfo{linked}(new_values, vi.accs) end @@ -282,9 +339,16 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).val # TODO(mhauru) The below should be removed together with unflatten!!. getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) +""" + get_transformed_value(vi::VarInfo, vn::VarName) + +Get the entire `AbstractTransformedValue` for variable `vn` in `vi`. +""" +get_transformed_value(vi::VarInfo, vn::VarName) = getindex(vi.values, vn) + function is_transformed(vi::VarInfo{Linked}, vn::VarName) where {Linked} return if Linked === nothing - is_transformed(getindex(vi.values, vn)) + getindex(vi.values, vn) isa LinkedVectorValue else Linked end @@ -299,87 +363,43 @@ function from_linked_internal_transform(::VarInfo, ::VarName, dist::Distribution end function from_internal_transform(vi::VarInfo, vn::VarName) - return getindex(vi.values, vn).transform + return DynamicPPL.get_transform(getindex(vi.values, vn)) end function from_linked_internal_transform(vi::VarInfo, vn::VarName) if !is_transformed(vi, vn) error("Variable $vn is not linked; cannot get linked transformation.") end - return getindex(vi.values, vn).transform -end - -""" - _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link isa Bool} - -The internal function that implements both link!! and invlink!!. - -The last argument controls whether linking (true) or invlinking (false) is performed. If -`vns` is `nothing`, all variables in `vi` are transformed; otherwise, only the variables -in `vns` are transformed. Existing variables already in the desired state are left -unchanged. -""" -function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link} - @assert link isa Bool - # Note that extract_priors causes a model execution. In the past with the Metadata-based - # VarInfo we rather derived the transformations from the distributions stored in the - # VarInfo itself. However, that is not fail-safe with dynamic models, and would require - # storing the distributions in TransformedValue (which we could start doing). Instead we - # use extract_priors to get the current, correct transformations. This logic is very - # similar to what DynamicTransformation used to do, and we might replace this with a - # context that transforms each variable in turn during the execution. - dists = extract_priors(model, vi) - cumulative_logjac = zero(LogProbType) - new_values = map_pairs!!(vi.values) do pair - vn, tv = pair - if vns !== nothing && !any(x -> subsumes(x, vn), vns) - # Not one of the target variables. - return tv - end - if is_transformed(tv) == link - # Already in the desired state. - return tv - end - dist = getindex(dists, vn)::Distribution - vec_transform = from_vec_transform(dist) - link_transform = from_linked_vec_transform(dist) - current_transform, new_transform = if link - (vec_transform, link_transform) - else - (link_transform, vec_transform) - end - val_untransformed, logjac1 = with_logabsdet_jacobian(current_transform, tv.val) - val_new, logjac2 = with_logabsdet_jacobian( - inverse(new_transform), val_untransformed - ) - # !is_transformed(tv) is the same as `link`, but might be easier for type inference. - new_tv = TransformedValue{!is_transformed(tv)}(val_new, new_transform, tv.size) - cumulative_logjac += logjac1 + logjac2 - return new_tv - end - vi_linked = if vns === nothing - link - else - nothing + return DynamicPPL.get_transform(getindex(vi.values, vn)) +end + +# TODO(penelopeysm): In principle, `link` can be statically determined from the type of +# `linker`. However, I'm not sure if doing that could mess with type stability. +function _link_or_invlink!!( + orig_vi::VarInfo, linker::AbstractLinkStrategy, model::Model, ::Val{link} +) where {link} + linked_value_acc = VNTAccumulator{LINK_ACCNAME}(Link!(linker)) + new_vi = OnlyAccsVarInfo((linked_value_acc,)) + new_vi = last(init!!(model, new_vi, InitFromParamsUnsafe(orig_vi.values))) + link_acc = getacc(new_vi, Val(LINK_ACCNAME)) + new_vi = VarInfo{link}(link_acc.values, orig_vi.accs) + if hasacc(new_vi, Val(:LogJacobian)) + new_vi = acclogjac!!(new_vi, link_acc.f.logjac) end - vi = VarInfo{vi_linked}(new_values, vi.accs) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, cumulative_logjac) - end - return vi + return new_vi end function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) - return _link_or_invlink!!(vi, vns, model, Val(true)) -end -function link!!(::DynamicTransformation, vi::VarInfo, model::Model) - return _link_or_invlink!!(vi, nothing, model, Val(true)) + return _link_or_invlink!!(vi, LinkSome(Set(vns)), model, Val(nothing)) end function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) - return _link_or_invlink!!(vi, vns, model, Val(false)) + return _link_or_invlink!!(vi, UnlinkSome(Set(vns)), model, Val(nothing)) +end +function link!!(::DynamicTransformation, vi::VarInfo, model::Model) + return _link_or_invlink!!(vi, LinkAll(), model, Val(true)) end function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) - return _link_or_invlink!!(vi, nothing, model, Val(false)) + return _link_or_invlink!!(vi, UnlinkAll(), model, Val(false)) end function link!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) @@ -422,11 +442,16 @@ function values_as(vi::VarInfo, ::Type{Vector}) end function values_as(vi::VarInfo, ::Type{T}) where {T<:AbstractDict} - return mapfoldl(identity, function (cumulant, pair) - vn, tv = pair - val = tv.transform(tv.val) - return setindex!!(cumulant, val, vn) - end, vi.values; init=T()) + return mapfoldl( + identity, + function (cumulant, pair) + vn, tv = pair + val = DynamicPPL.get_transform(tv)(DynamicPPL.get_internal_value(tv)) + return setindex!!(cumulant, val, vn) + end, + vi.values; + init=T(), + ) end # TODO(mhauru) I really dislike this sort of conversion to Symbols, but it's the current @@ -438,7 +463,7 @@ function values_as(vi::VarInfo, ::Type{NamedTuple}) identity, function (cumulant, pair) vn, tv = pair - val = tv.transform(tv.val) + val = DynamicPPL.get_transform(tv)(DynamicPPL.get_internal_value(tv)) return setindex!!(cumulant, val, Symbol(vn)) end, vi.values; @@ -458,12 +483,16 @@ mutable struct VectorChunkIterator!{T<:AbstractVector} vec::T index::Int end -function (vci::VectorChunkIterator!)(tv::TransformedValue{Linked}) where {Linked} - old_val = tv.val - len = length(old_val) - new_val = @view vci.vec[(vci.index):(vci.index + len - 1)] - vci.index += len - return TransformedValue{Linked}(new_val, tv.transform, tv.size) +for T in (:VectorValue, :LinkedVectorValue) + @eval begin + function (vci::VectorChunkIterator!)(tv::$T) + old_val = tv.val + len = length(old_val) + new_val = @view vci.vec[(vci.index):(vci.index + len - 1)] + vci.index += len + return $T(new_val, tv.transform, tv.size) + end + end end function unflatten!!(vi::VarInfo{Linked}, vec::AbstractVector) where {Linked} vci = VectorChunkIterator!(vec, 1) diff --git a/test/accumulators.jl b/test/accumulators.jl index 846173981..9d1d295ca 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -67,18 +67,19 @@ using DynamicPPL: @testset "accumulate_assume" begin val = 2.0 + tval = DynamicPPL.UntransformedValue(nothing) logjac = pi vn = @varname(x) dist = Normal() template = nothing @test accumulate_assume!!( - LogPriorAccumulator(1.0), val, logjac, vn, dist, template + LogPriorAccumulator(1.0), val, tval, logjac, vn, dist, template ) == LogPriorAccumulator(1.0 + logpdf(dist, val)) @test accumulate_assume!!( - LogJacobianAccumulator(2.0), val, logjac, vn, dist, template + LogJacobianAccumulator(2.0), val, tval, logjac, vn, dist, template ) == LogJacobianAccumulator(2.0 + logjac) @test accumulate_assume!!( - LogLikelihoodAccumulator(1.0), val, logjac, vn, dist, template + LogLikelihoodAccumulator(1.0), val, tval, logjac, vn, dist, template ) == LogLikelihoodAccumulator(1.0) end diff --git a/test/contexts.jl b/test/contexts.jl index 7f9e666c5..f2cfffe5c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -464,10 +464,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # start by generating some rubbish values vi = deepcopy(empty_vi) old_x, old_y = 100000.00, [300000.00, 500000.00] - vi, _ = DynamicPPL.setindex_with_dist!!( + vi, _, _ = DynamicPPL.setindex_with_dist!!( vi, old_x, Normal(), @varname(x), DynamicPPL.NoTemplate() ) - vi, _ = DynamicPPL.setindex_with_dist!!( + vi, _, _ = DynamicPPL.setindex_with_dist!!( vi, old_y, MvNormal(fill(old_x, 2), I), diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 445270ef8..c580a1591 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -1,6 +1,7 @@ module DynamicPPLMCMCChainsExtTests -using DynamicPPL, Distributions, MCMCChains, Test, AbstractMCMC +using DynamicPPL, Distributions, MCMCChains, Test +using AbstractMCMC: AbstractMCMC @testset "DynamicPPLMCMCChainsExt" begin @testset "from_samples" begin diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 02d3e0ce4..6ad93a75c 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -120,7 +120,7 @@ end struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR DynamicPPL.accumulate_assume!!( - ::ErrorAccumulator, ::Any, ::Any, ::VarName, ::Distribution, ::Any + ::ErrorAccumulator, ::Any, ::Any, ::Any, ::VarName, ::Distribution, ::Any ) = throw(ErrorAccumulatorException()) DynamicPPL.accumulate_observe!!( ::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing} diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index fde807dda..e15339022 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -103,7 +103,10 @@ end model_m_only = m_only() chain_m_only = AbstractMCMC.from_samples( MCMCChains.Chains, - hcat([ParamsWithStats(VarInfo(model_m_only), model_m_only) for _ in 1:50]), + hcat([ + DynamicPPL.ParamsWithStats(VarInfo(model_m_only), model_m_only) for + _ in 1:50 + ]), ) # Define a model that needs both `m` and `s`. diff --git a/test/test_util.jl b/test/test_util.jl index 8f402ad8f..c92f770e2 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -44,7 +44,10 @@ Construct an MCMCChains.Chains object by sampling from the prior of `model` for function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) vi = VarInfo(model) vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.ValuesAsInModelAccumulator(false),)) - ps = hcat([ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi))) for _ in 1:n_iters]) + ps = hcat([ + DynamicPPL.ParamsWithStats(last(DynamicPPL.init!!(rng, model, vi))) for + _ in 1:n_iters + ]) return AbstractMCMC.from_samples(MCMCChains.Chains, ps) end function make_chain_from_prior(model::Model, n_iters::Int) diff --git a/test/varinfo.jl b/test/varinfo.jl index 323050165..f5e061117 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -13,6 +13,8 @@ function check_metadata_type_equal( return check_metadata_type_equal(v1.varinfo, v2.varinfo) end +using Random: Xoshiro + @testset "varinfo.jl" begin @testset "Base" begin # Test Base functions: @@ -275,10 +277,12 @@ end _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) vals = values(vi) - all_transformed(vi) = - mapreduce(p -> is_transformed(p.second), &, vi.values; init=true) - any_transformed(vi) = - mapreduce(p -> is_transformed(p.second), |, vi.values; init=false) + all_transformed(vi) = mapreduce( + p -> p.second isa DynamicPPL.LinkedVectorValue, &, vi.values; init=true + ) + any_transformed(vi) = mapreduce( + p -> p.second isa DynamicPPL.LinkedVectorValue, |, vi.values; init=false + ) @test !any_transformed(vi) @@ -312,6 +316,55 @@ end end end + @testset "instantiation with link strategy" begin + @model function f() + x ~ Beta(2, 2) + return y ~ LogNormal(0, 1) + end + + function test_link_strategy( + link_strategy::DynamicPPL.AbstractLinkStrategy, + model::DynamicPPL.Model, + expected_linked_vns::Set{<:VarName}, + ) + # Test that the variables are linked according to the link strategy + vi = VarInfo(Xoshiro(468), model, link_strategy) + for vn in keys(vi) + if vn in expected_linked_vns + @test DynamicPPL.get_transformed_value(vi, vn) isa + DynamicPPL.LinkedVectorValue + else + @test DynamicPPL.get_transformed_value(vi, vn) isa + DynamicPPL.VectorValue + end + end + # Test that initialising directly is the same as linking later (if rng is the + # same) + if link_strategy isa LinkAll + vi2 = VarInfo(Xoshiro(468), model) + vi2 = DynamicPPL.link!!(vi2, model) + @test vi == vi2 + end + if link_strategy isa LinkSome + vi2 = VarInfo(Xoshiro(468), model) + vi2 = DynamicPPL.link!!(vi2, link_strategy.vns, model) + @test vi == vi2 + end + end + + model = f() + test_link_strategy(LinkAll(), model, Set([@varname(x), @varname(y)])) + test_link_strategy(LinkSome((@varname(x),)), model, Set([@varname(x)])) + test_link_strategy(LinkSome((@varname(y),)), model, Set([@varname(y)])) + test_link_strategy( + LinkSome((@varname(x), @varname(y))), model, Set([@varname(x), @varname(y)]) + ) + test_link_strategy(UnlinkAll(), model, Set{VarName}()) + test_link_strategy(UnlinkSome((@varname(x),)), model, Set{VarName}()) + test_link_strategy(UnlinkSome((@varname(y),)), model, Set{VarName}()) + test_link_strategy(UnlinkSome((@varname(x), @varname(y))), model, Set{VarName}()) + end + @testset "logp evaluation on linked varinfo" begin @model demo_constrained() = x ~ truncated(Normal(); lower=0) model = demo_constrained()