Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/accs/values.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ vi = VarInfo(dirichlet_model)
vi
```

In `VarInfo`, it is mandatory to store `LinkedVectorValue`s or `VectorValue`s as `ArrayLikeBlock`s (see the [Array-like blocks](@ref) documentation for information on this).
In `VarInfo`, it is mandatory to store `LinkedVectorValue`s or `VectorValue`s as `ArrayLikeBlock`s (see the [Array-like blocks](@ref array-like-blocks) documentation for information on this).
The reason is because, if the value is linked, it may have a different size than the number of indices in the `VarName`.
This means that when retrieving the keys, we obtain each block as a single key:

Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ For more details on `VarNamedTuple`, see the Internals section of our documentat
```@docs
DynamicPPL.VarNamedTuples.VarNamedTuple
DynamicPPL.VarNamedTuples.@vnt
DynamicPPL.VarNamedTuples.vnt_size
DynamicPPL.VarNamedTuples.apply!!
DynamicPPL.VarNamedTuples.densify!!
DynamicPPL.VarNamedTuples.map_pairs!!
Expand Down
26 changes: 1 addition & 25 deletions docs/src/vnt/arraylikeblocks.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Array-like blocks
# [Array-like blocks](@id array-like-blocks)

In a number of VNT use cases, it is necessary to associate multiple indices in a `VarNamedTuple` with an object that is not necessarily the same number of elements.

Expand Down Expand Up @@ -87,30 +87,6 @@ Furthermore, if you set a value into any of the indices covered by the block, th
vnt = DynamicPPL.templated_setindex!!(vnt, Normal(), @varname(x[2]), x)
```

## Size checks

Currently, when setting any object `val` as an `ArrayLikeBlock`, there is a size check: we make sure that the range of indices being set to has the same size as `DynamicPPL.VarNamedTuples.vnt_size(val)`.
By default, `vnt_size(x)` returns `Base.size(x)`.

```@example 1
DynamicPPL.VarNamedTuples.vnt_size(Dirichlet(ones(3)))
```

This is what allows us to set a `Dirichlet` distribution to three indices.
However, trying to set the same distribution to two indices will fail:

```@repl 1
vnt = DynamicPPL.templated_setindex!!(
VarNamedTuple(), Dirichlet(ones(3)), @varname(x[1:2]), zeros(5)
)
```

!!! note

In principle, these checks can be removed since if `Dirichlet(ones(3))` is set as the prior of `x[1:2]`, then model evaluation will error anyway.
Furthermore, if at any point we need to know the size of the block, we can always retrieve it via `size(view(parent_array, alb.ix...; alb.kw...))`.
However, the checks are still here for now.

## Which parts of DynamicPPL use array-like blocks?

Simply put, anywhere where we don't store raw values.
Expand Down
2 changes: 1 addition & 1 deletion src/accumulators/pointwise_logdensities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ function (plp::PointwiseLogProb{Prior,Likelihood})(
return DoNotAccumulate()
end
end
const POINTWISE_ACCNAME = :PointwiseLogProbAccumulator
const POINTWISE_ACCNAME = :PointwiseLogProb

# Not exported
function get_pointwise_logprobs(varinfo::AbstractVarInfo)
Expand Down
15 changes: 7 additions & 8 deletions src/accumulators/vector_values.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
const VECTORVAL_ACCNAME = :VectorValue
_get_vector_tval(val, tval::Union{VectorValue,LinkedVectorValue}, logjac, vn, dist) = tval
function _get_vector_tval(val, ::UntransformedValue, logjac, vn, dist)
original_val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : ()
f = to_vec_transform(dist)
new_val, logjac = with_logabsdet_jacobian(f, val)
@assert iszero(logjac) # otherwise we're in trouble...
return VectorValue(new_val, inverse(f), original_val_size)
return VectorValue(new_val, inverse(f))
end

# This is equivalent to `varinfo.values` where `varinfo isa VarInfo`
Expand Down Expand Up @@ -34,12 +33,12 @@ julia> using DynamicPPL

julia> # In a real setting the other fields would be filled in with meaningful values.
vnt = @vnt begin
x := VectorValue([1.0, 2.0], nothing, nothing)
y := LinkedVectorValue([3.0], nothing, nothing)
x := VectorValue([1.0, 2.0], nothing)
y := LinkedVectorValue([3.0], nothing)
end
VarNamedTuple
├─ x => VectorValue{Vector{Float64}, Nothing, Nothing}([1.0, 2.0], nothing, nothing)
└─ y => LinkedVectorValue{Vector{Float64}, Nothing, Nothing}([3.0], nothing, nothing)
├─ x => VectorValue{Vector{Float64}, Nothing}([1.0, 2.0], nothing)
└─ y => LinkedVectorValue{Vector{Float64}, Nothing}([3.0], nothing)

julia> internal_values_as_vector(vnt)
3-element Vector{Float64}:
Expand Down Expand Up @@ -70,8 +69,8 @@ julia> # note InitFromParams provides parameters in untransformed space
julia> # but because we specified LinkAll(), the vectorised values are transformed
vector_vals = get_vector_values(accs)
VarNamedTuple
├─ x => LinkedVectorValue{Vector{Float64}, ComposedFunction{typeof(identity), typeof(identity)}, Tuple{Int64}}([1.0, 2.0], identity ∘ identity, (2,))
└─ y => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}, Tuple{}}([0.0], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0)) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ())), ())
├─ x => LinkedVectorValue{Vector{Float64}, ComposedFunction{typeof(identity), typeof(identity)}}([1.0, 2.0], identity ∘ identity)
└─ y => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}}([0.0], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0)) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ())))

julia> # we can extract the internal values as a single vector
internal_values_as_vector(vector_vals)
Expand Down
1 change: 1 addition & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ function tilde_observe!!(
::Distribution,
::Any,
::Union{VarName,Nothing},
::Any,
::AbstractVarInfo,
)
return error("tilde_observe!! not implemented for context of type $(typeof(context))")
Expand Down
25 changes: 10 additions & 15 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro
# sample.
real_sz = prod(sz)
y = u.lower .+ ((u.upper - u.lower) .* rand(rng, real_sz))
return LinkedVectorValue(y, from_linked_vec_transform(dist), get_size_for_vnt(dist))
return LinkedVectorValue(y, from_linked_vec_transform(dist))
end

"""
Expand Down Expand Up @@ -188,10 +188,10 @@ function init(
# In this case, we can't trust the transform stored in x because the _value_
# in x may have been changed via unflatten!! without the transform being
# updated. Therefore, we always recompute the transform here.
VectorValue(x.val, from_vec_transform(dist), x.size)
VectorValue(x.val, from_vec_transform(dist))
elseif x isa LinkedVectorValue
# Same as above.
LinkedVectorValue(x.val, from_linked_vec_transform(dist), x.size)
LinkedVectorValue(x.val, from_linked_vec_transform(dist))
elseif x isa UntransformedValue
x
else
Expand Down Expand Up @@ -230,10 +230,10 @@ function init(
# In this case, we can't trust the transform stored in x because the _value_
# in x may have been changed via unflatten!! without the transform being
# updated. Therefore, we always recompute the transform here.
VectorValue(x.val, from_vec_transform(dist), x.size)
VectorValue(x.val, from_vec_transform(dist))
elseif x isa LinkedVectorValue
# Same as above.
LinkedVectorValue(x.val, from_linked_vec_transform(dist), x.size)
LinkedVectorValue(x.val, from_linked_vec_transform(dist))
elseif x isa UntransformedValue
x
else
Expand Down Expand Up @@ -270,17 +270,13 @@ an unlinked value.

$(TYPEDFIELDS)
"""
struct RangeAndLinked{T<:Tuple}
struct RangeAndLinked
# indices that the variable corresponds to in the vectorised parameter
range::UnitRange{Int}
# whether the variable is linked or unlinked
is_linked::Bool
# original size of the variable before vectorisation
original_size::T
end

VarNamedTuples.vnt_size(ral::RangeAndLinked) = ral.original_size

"""
InitFromVector(
vect::AbstractVector{<:Real},
Expand Down Expand Up @@ -328,17 +324,16 @@ end
function init(::Random.AbstractRNG, vn::VarName, dist::Distribution, ifv::InitFromVector)
range_and_linked = _get_range_and_linked(ifv, vn)
vect = view(ifv.vect, range_and_linked.range)
sz = range_and_linked.original_size
# This block here is why we store transform_strategy inside the InitFromVector, as it
# allows for type stability.
return if ifv.transform_strategy isa LinkAll
LinkedVectorValue(vect, from_linked_vec_transform(dist), sz)
LinkedVectorValue(vect, from_linked_vec_transform(dist))
elseif ifv.transform_strategy isa UnlinkAll
VectorValue(vect, from_vec_transform(dist), sz)
VectorValue(vect, from_vec_transform(dist))
elseif range_and_linked.is_linked
LinkedVectorValue(vect, from_linked_vec_transform(dist), sz)
LinkedVectorValue(vect, from_linked_vec_transform(dist))
else
VectorValue(vect, from_vec_transform(dist), sz)
VectorValue(vect, from_vec_transform(dist))
end
end
function get_param_eltype(strategy::InitFromVector)
Expand Down
2 changes: 1 addition & 1 deletion src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ function get_ranges_and_linked(vnt::VarNamedTuple)
val = tv.val
range = offset:(offset + length(val) - 1)
offset += length(val)
ral = RangeAndLinked(range, tv isa LinkedVectorValue, tv.size)
ral = RangeAndLinked(range, tv isa LinkedVectorValue)
template = vnt.data[AbstractPPL.getsym(vn)]
ranges_vnt = templated_setindex!!(ranges_vnt, ral, vn, template)
return ranges_vnt, offset
Expand Down
45 changes: 14 additions & 31 deletions src/transformed_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ Subtypes of this should implement the following functions:
- `DynamicPPL.set_internal_value(tv::AbstractTransformedValue, new_val)`: Create a new
`AbstractTransformedValue` with the same transformation as `tv`, but with internal value
`new_val`.

- `DynamicPPL.VarNamedTuples.vnt_size(tv::AbstractTransformedValue)`: Get the size of the
original value before transformation.
"""
abstract type AbstractTransformedValue end

Expand Down Expand Up @@ -102,14 +99,14 @@ internal value `new_val`.
function set_internal_value end

"""
VectorValue{V<:AbstractVector,T,S}
VectorValue{V<:AbstractVector,T}

A transformed value that stores its internal value as a vectorised form. This is what
VarInfo sees as an "unlinked value".

These values can be generated when using `InitFromParams` with a VarInfo's internal values.
"""
struct VectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue
struct VectorValue{V<:AbstractVector,T} <: AbstractTransformedValue
"The internal (vectorised) value."
val::V
"""The unvectorisation transform required to convert `val` back to the original space.
Expand All @@ -120,23 +117,20 @@ struct VectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue
evaluation occurs, the correct transform is determined from the distribution associated
with the variable."""
transform::T
"""The size of the original value before transformation. This is needed when a
`TransformedValue` is stored as a block in an array."""
size::S
function VectorValue(val::V, tfm::T, size::S) where {V<:AbstractVector,T,S}
return new{V,T,S}(val, tfm, size)
function VectorValue(val::V, tfm::T) where {V<:AbstractVector,T}
return new{V,T}(val, tfm)
end
end

"""
LinkedVectorValue{V<:AbstractVector,T,S}
LinkedVectorValue{V<:AbstractVector,T}

A transformed value that stores its internal value as a linked andvectorised form. This is
A transformed value that stores its internal value as a linked and vectorised form. This is
what VarInfo sees as a "linked value".

These values can be generated when using `InitFromParams` with a VarInfo's internal values.
"""
struct LinkedVectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue
struct LinkedVectorValue{V<:AbstractVector,T} <: AbstractTransformedValue
"The internal (linked + vectorised) value."
val::V
"""The unlinking transform required to convert `val` back to the original space.
Expand All @@ -147,33 +141,25 @@ struct LinkedVectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue
evaluation occurs, the correct transform is determined from the distribution associated
with the variable."""
transform::T
"""The size of the original value before transformation. This is needed when a
`TransformedValue` is stored as a block in an array."""
size::S
function LinkedVectorValue(val::V, tfm::T, size::S) where {V<:AbstractVector,T,S}
return new{V,T,S}(val, tfm, size)
function LinkedVectorValue(val::V, tfm::T) where {V<:AbstractVector,T}
return new{V,T}(val, tfm)
end
end

for T in (:VectorValue, :LinkedVectorValue)
@eval begin
function Base.:(==)(tv1::$T, tv2::$T)
return (tv1.val == tv2.val) &
(tv1.transform == tv2.transform) &
(tv1.size == tv2.size)
return (tv1.val == tv2.val) & (tv1.transform == tv2.transform)
end
function Base.isequal(tv1::$T, tv2::$T)
return isequal(tv1.val, tv2.val) &&
isequal(tv1.transform, tv2.transform) &&
isequal(tv1.size, tv2.size)
return isequal(tv1.val, tv2.val) && isequal(tv1.transform, tv2.transform)
end
VarNamedTuples.vnt_size(tv::$T) = tv.size

get_transform(tv::$T) = tv.transform
get_internal_value(tv::$T) = tv.val

function set_internal_value(tv::$T, new_val)
return $T(new_val, tv.transform, tv.size)
return $T(new_val, tv.transform)
end
end
end
Expand All @@ -191,15 +177,12 @@ struct UntransformedValue{V} <: AbstractTransformedValue
val::V
UntransformedValue(val::V) where {V} = new{V}(val)
end
VarNamedTuples.vnt_size(tv::UntransformedValue) = vnt_size(tv.val)
Base.:(==)(tv1::UntransformedValue, tv2::UntransformedValue) = tv1.val == tv2.val
Base.isequal(tv1::UntransformedValue, tv2::UntransformedValue) = isequal(tv1.val, tv2.val)
get_transform(::UntransformedValue) = typed_identity
get_internal_value(tv::UntransformedValue) = tv.val
set_internal_value(::UntransformedValue, new_val) = UntransformedValue(new_val)

get_size_for_vnt(val) = hasmethod(size, Tuple{typeof(val)}) ? size(val) : ()

"""
abstract type AbstractTransform end

Expand Down Expand Up @@ -396,7 +379,7 @@ function apply_transform_strategy(
flink = DynamicPPL.to_linked_vec_transform(dist)
linked_value, logjac = with_logabsdet_jacobian(flink, raw_value)
finvlink = DynamicPPL.from_linked_vec_transform(dist)
linked_tv = LinkedVectorValue(linked_value, finvlink, get_size_for_vnt(raw_value))
linked_tv = LinkedVectorValue(linked_value, finvlink)
(raw_value, linked_tv, logjac)
elseif target isa Unlink
# No need to transform further
Expand All @@ -419,7 +402,7 @@ function apply_transform_strategy(
flink = DynamicPPL.to_linked_vec_transform(dist)
linked_value, logjac = with_logabsdet_jacobian(flink, raw_value)
finvlink = DynamicPPL.from_linked_vec_transform(dist)
linked_tv = LinkedVectorValue(linked_value, finvlink, get_size_for_vnt(raw_value))
linked_tv = LinkedVectorValue(linked_value, finvlink)
(raw_value, linked_tv, logjac)
elseif target isa Unlink
# No need to transform further
Expand Down
11 changes: 5 additions & 6 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,7 @@ function setindex_with_dist!!(
vi::VarInfo, tval::UntransformedValue, dist::Distribution, vn::VarName, template
)
raw_value = DynamicPPL.get_internal_value(tval)
sz = hasmethod(size, (typeof(raw_value),)) ? size(raw_value) : ()
tval = VectorValue(to_vec_transform(dist)(raw_value), from_vec_transform(dist), sz)
tval = VectorValue(to_vec_transform(dist)(raw_value), from_vec_transform(dist))
return setindex_with_dist!!(vi, tval, dist, vn, template)
end

Expand All @@ -330,9 +329,9 @@ transformation of the variable!
function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName)
old_tv = getindex(vi.values, vn)
new_tv = if linked
LinkedVectorValue(old_tv.val, old_tv.transform, old_tv.size)
LinkedVectorValue(old_tv.val, old_tv.transform)
else
VectorValue(old_tv.val, old_tv.transform, old_tv.size)
VectorValue(old_tv.val, old_tv.transform)
end
new_values = setindex!!(vi.values, new_tv, vn)
new_transform_strategy = update_transform_strategy(
Expand All @@ -358,7 +357,7 @@ set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false
function set_transformed!!(vi::VarInfo, linked::Bool)
ctor = linked ? LinkedVectorValue : VectorValue
new_values = map_values!!(vi.values) do tv
ctor(tv.val, tv.transform, tv.size)
ctor(tv.val, tv.transform)
end
new_transform_strategy = linked ? LinkAll() : UnlinkAll()
return VarInfo(new_transform_strategy, new_values, vi.accs)
Expand Down Expand Up @@ -528,7 +527,7 @@ for T in (:VectorValue, :LinkedVectorValue)
len = length(old_val)
new_val = @view vci.vec[(vci.index):(vci.index + len - 1)]
vci.index += len
return $T(new_val, tv.transform, tv.size)
return $T(new_val, tv.transform)
end
end
end
Expand Down
1 change: 0 additions & 1 deletion src/varnamedtuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using BangBang
using DynamicPPL: DynamicPPL

export VarNamedTuple,
vnt_size,
map_pairs!!,
map_values!!,
apply!!,
Expand Down
Loading