From 34ad663e52a64282c3d6c6068e55e42a464f1521 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 18 Dec 2025 12:29:29 +0000 Subject: [PATCH 01/56] Fix and improve map!! and apply!! --- src/varnamedtuple.jl | 121 +++++++++++++++++++++++++++++++++++++----- test/varnamedtuple.jl | 83 ++++++++++++++++++++++++++++- 2 files changed, 190 insertions(+), 14 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index bb1f4a14b..c5cb2c681 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,7 +8,7 @@ using BangBang using Accessors using ..DynamicPPL: _compose_no_identity -export VarNamedTuple +export VarNamedTuple, map!!, apply!! # We define our own getindex, setindex!!, and haskey functions, which we use to # get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be @@ -19,12 +19,33 @@ export VarNamedTuple # 2. We would want `haskey` to fall back onto `checkbounds` when called on Base.Arrays. function _getindex end function _haskey end + +""" + _setindex!!(collection, value, key; allow_new=Val(true)) + +Like `setindex!!`, but special-cased for `VarNamedTuple` and `PartialArray` to recurse +into nested structures. + +The `allow_new` keywword argument is a performance optimisation: If it is set to +`Val(false)`, the function can assume that the key being set already exists in `collection`. +This allows skipping some code paths, which may have a minor benefit at runtime, but more +importantly, allows for better constant propagation and type stability at compile time. + +`allow_new` being set to `Val(false)` does _not_ guarantee that no new keys will be added. +It only gives the implementation of `_setindex!!` the permission to assume that the key +already exists. Setting it to `Val(false)` should be done only when the caller is sure that +the key already exists, anything else is a bug in the caller. + +Most methods of _setindex!! ignore the `allow_new` keyword argument, as they have no use for +it. 
See the method for setting values in a `VarNamedTuple` with a `ComposedFunction` for +when it is useful. +""" function _setindex!! end _getindex(arr::AbstractArray, optic::IndexLens) = getindex(arr, optic.indices...) _haskey(arr::AbstractArray, optic::IndexLens) = _haskey(arr, optic.indices) _haskey(arr::AbstractArray, inds) = checkbounds(Bool, arr, inds...) -function _setindex!!(arr::AbstractArray, value, optic::IndexLens) +function _setindex!!(arr::AbstractArray, value, optic::IndexLens; allow_new=Val(true)) return setindex!!(arr, value, optic.indices...) end @@ -451,7 +472,7 @@ end _getindex(pa::PartialArray, optic::IndexLens) = _getindex(pa, optic.indices...) _haskey(pa::PartialArray, optic::IndexLens) = _haskey(pa, optic.indices) -function _setindex!!(pa::PartialArray, value, optic::IndexLens) +function _setindex!!(pa::PartialArray, value, optic::IndexLens; allow_new=Val(true)) return _setindex!!(pa, value, optic.indices...) end @@ -1006,11 +1027,13 @@ _haskey(vnt::VarNamedTuple, ::PropertyLens{S}) where {S} = haskey(vnt.data, S) _haskey(vnt::VarNamedTuple, ::typeof(identity)) = true _haskey(::VarNamedTuple, ::IndexLens) = false -function _setindex!!(vnt::VarNamedTuple, value, name::VarName) - return _setindex!!(vnt, value, _varname_to_lens(name)) +function _setindex!!(vnt::VarNamedTuple, value, name::VarName; allow_new=Val(true)) + return _setindex!!(vnt, value, _varname_to_lens(name); allow_new=allow_new) end -function _setindex!!(vnt::VarNamedTuple, value, ::PropertyLens{S}) where {S} +function _setindex!!( + vnt::VarNamedTuple, value, ::PropertyLens{S}; allow_new=Val(true) +) where {S} # I would like for this to just read # return VarNamedTuple(_setindex!!(vnt.data, value, S)) # but that seems to be type unstable. Why? Shouldn't it obviously be the same as the @@ -1041,13 +1064,13 @@ Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) return Expr(:block, exs...) end -# TODO(mhauru) The below remains unfinished an undertested. 
I think it's incorrect for more -# complex VarNames. It is unexported though. """ apply!!(func, vnt::VarNamedTuple, name::VarName) Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`. +Like `map!!`, but only for a single `VarName`. + ```jldoctest julia> using DynamicPPL: VarNamedTuple, setindex!! @@ -1069,9 +1092,71 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) end subdata = _getindex(vnt, name) new_subdata = func(subdata) - return _setindex!!(vnt, new_subdata, name) + # The allow_new=Val(true) is a performance optimisation: Since we've already checked + # that the key exists, we know that no new fields will be created. + return _setindex!!(vnt, new_subdata, name; allow_new=Val(false)) +end + +""" + _map_recursive!!(func, x) + +Call `func` on `x`, except if `x` is a `VarNamedTuple` or `PartialArray`, in which case +call `_map_recursive!!` recursively on all their elements.. + +This is the internal implementation of `map!!`, but because it has a method defined for +literally every type in existence, we hide it behind the interface of the more +discriminating `map!!`. It makes the implementation a bit simpler, compared to checking +element types within `map!!` itself. +""" +_map_recursive!!(func, x) = func(x) + +function _map_recursive!!(func, pa::PartialArray) + # Ask the compiler to infer the return type of applying func to eltype(pa). + new_et = Core.Compiler.return_type(x -> _map_recursive!!(func, x), Tuple{eltype(pa)}) + new_data = if new_et <: eltype(pa) + pa.data + else + similar(pa.data, new_et) + end + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] + new_data[i] = _map_recursive!!(func, pa.data[i]) + end + end + # The above type inference may be overly conservative, so we concretise the eltype. 
+ return _concretise_eltype!!(PartialArray(new_data, pa.mask)) +end + +function _map_recursive!!(func, alb::ArrayLikeBlock) + new_block = _map_recursive!!(func, alb.block) + if size(new_block) != size(alb.block) + throw( + DimensionMismatch( + "map!! can't change the size of an ArrayLikeBlock. Tried to change from" * + "$(size(alb.block)) to $(size(new_block)).", + ), + ) + end + return ArrayLikeBlock(new_block, alb.inds) +end + +@generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}) where {Names} + exs = Expr[] + for name in Names + push!(exs, :(_map_recursive!!(func, vnt.data.$name))) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end end +""" + map!!(func, vnt::VarNamedTuple) + +Apply `func` to all set elements of the `vnt`, in place if possible. +""" +map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) + function Base.keys(vnt::VarNamedTuple) result = VarName[] for sym in keys(vnt.data) @@ -1132,13 +1217,23 @@ function _getindex(x::VNT_OR_PA, optic::ComposedFunction) return _getindex(subdata, optic.outer) end -function _setindex!!(vnt::VNT_OR_PA, value, optic::ComposedFunction) +# The allow_new keyword argument is a performance optimisation that helps constant +# propagation and type inference by avoiding any possible dynamic dispatch calls to +# `make_leaf`. It should only be set to `Val(false) if we are sure that the key already +# exists, and thus there would be no need to call `make_leaf`. +function _setindex!!(vnt::VNT_OR_PA, value, optic::ComposedFunction; allow_new=Val(true)) sub = if _haskey(vnt, optic.inner) - _setindex!!(_getindex(vnt, optic.inner), value, optic.outer) - else + _setindex!!(_getindex(vnt, optic.inner), value, optic.outer; allow_new=allow_new) + elseif allow_new isa Val{true} make_leaf(value, optic.outer) + else + # If this branch is ever reached, then someone has used allow_new=Val(false) + # incorrectly. 
+ error(""" + _setindex was called with allow_new=Val(false) but the key does not exist. + This indicates a bug in DynamicPPL: Please file an issue on GitHub.""") end - return _setindex!!(vnt, sub, optic.inner) + return _setindex!!(vnt, sub, optic.inner; allow_new=allow_new) end function _haskey(vnt::VNT_OR_PA, optic::ComposedFunction) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 7b81708ed..25cd14e49 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,7 +3,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple -using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock +using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map!!, apply!! using AbstractPPL: VarName, concretize, prefix using BangBang: setindex!!, empty!! @@ -741,6 +741,87 @@ Base.size(st::SizedThing) = st.size @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end + + @testset "map!! and apply!!" begin + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a))) + vnt = @inferred(setindex!!(vnt, [2, 2], @varname(b[1:2]))) + vnt = @inferred(setindex!!(vnt, [3.0], @varname(c.d))) + vnt = @inferred(setindex!!(vnt, "a", @varname(e.f[3].g.h[2].i))) + # The below can't be type stable because the element type of `h` depends on whether + # we are setting `h[2].j` (which overwrites the earlier `h[2]`) or some other + # `h[index].j` (which would leave both `h[2].i` and `h[index].j` in the same array). 
+ vnt = setindex!!(vnt, 5.0, @varname(e.f[3].g.h[2].j)) + vnt = @inferred( + setindex!!(vnt, SizedThing((2, 2)), @varname(y.z[3, 2:3, 3, 2:3, 4])) + ) + test_invariants(vnt) + + struct AnotherSizedThing{T<:Tuple} + size::T + end + Base.size(st::AnotherSizedThing) = st.size + + function f(val) + if val isa Int + return val + 10 + elseif val isa AbstractVector{Int} + return val .+ 10 + elseif val isa Float64 + return val + 1.0 + elseif val isa AbstractVector{Float64} + return val .- 1.0 + elseif val isa String + return string(val, "b") + elseif val isa SizedThing + return AnotherSizedThing(size(val)) + else + error("Unexpected value type $(typeof(val))") + end + end + + vnt_mapped = @inferred(map!!(f, copy(vnt))) + test_invariants(vnt_mapped) + @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 + @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] + @test @inferred(getindex(vnt_mapped, @varname(c.d))) == [2.0] + @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].j))) == 6.0 + @test @inferred(getindex(vnt_mapped, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == + AnotherSizedThing((2, 2)) + + vnt_applied = @inferred(apply!!(f, vnt, @varname(a))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(a))) == 11 + @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [2, 2] + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(b[1:2]))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(a))) == 11 + @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [12, 12] + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(c.d))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].i))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" 
+ @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].j))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" + @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 6.0 + + # This can't be type stable because y.z might have many elements set, and we can't + # know at compile time that this sets the only one, thus allowing the element type + # to be AnotherSizedThing. + vnt_applied = apply!!(f, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == + AnotherSizedThing((2, 2)) + end end end From dc6291d9ec83bf6ddf214f51d28003d669c1d174 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 18 Dec 2025 15:29:30 +0000 Subject: [PATCH 02/56] mapreduce and nested PartialArrays --- src/varnamedtuple.jl | 83 +++++++++++++++++++++++++++++++------ test/varnamedtuple.jl | 95 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 156 insertions(+), 22 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index c5cb2c681..b829523d7 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -286,10 +286,19 @@ function Base.copy(pa::PartialArray) # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively # copy. 
pa_copy = PartialArray(copy(pa.data), copy(pa.mask)) - if VarNamedTuple <: eltype(pa) || eltype(pa) <: VarNamedTuple + et = eltype(pa) + if ( + VarNamedTuple <: et || + et <: VarNamedTuple || + PartialArray <: et || + et <: PartialArray + ) @inbounds for i in eachindex(pa.mask) - if pa.mask[i] && pa_copy.data[i] isa VarNamedTuple - pa_copy.data[i] = copy(pa.data[i]) + if pa.mask[i] + val = @inbounds pa_copy.data[i] + if val isa VarNamedTuple || val isa PartialArray + pa_copy.data[i] = copy(val) + end end end end @@ -754,6 +763,11 @@ function Base.keys(pa::PartialArray) sublens = _varname_to_lens(vn) push!(ks, _compose_no_identity(sublens, lens)) end + elseif val isa PartialArray + subkeys = keys(val) + for sublens in subkeys + push!(ks, _compose_no_identity(sublens, lens)) + end elseif val isa ArrayLikeBlock if !(val.inds in alb_inds_seen) push!(ks, IndexLens(Tuple(val.inds))) @@ -774,7 +788,7 @@ function Base.values(pa::PartialArray) continue end val = getindex(pa.data, ind) - if val isa VarNamedTuple + if val isa VarNamedTuple || val isa PartialArray subvalues = values(val) vs = push!!(vs, subvalues...) elseif val isa ArrayLikeBlock @@ -796,7 +810,7 @@ function Base.length(pa::PartialArray) continue end val = getindex(pa.data, ind) - if val isa VarNamedTuple + if val isa VarNamedTuple || val isa PartialArray len += length(val) else # Note we don't need to special case here for ArrayLikeBlocks. That's because @@ -1157,6 +1171,55 @@ Apply `func` to all set elements of the `vnt`, in place if possible. """ map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) +function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) + if init === nothing + throw( + NotImplementedError( + "mapreduce without init is not implemented for VarNamedTuple." 
+ ), + ) + end + return _mapreduce_recursive(f, op, vnt, init) +end + +_mapreduce_recursive(f, op, x, init) = op(init, f(x)) +_mapreduce_recursive(f, op, pa::ArrayLikeBlock, init) = op(init, f(pa.block)) + +@generated function _mapreduce_recursive( + f, op, vnt::VarNamedTuple{Names}, init +) where {Names} + exs = Expr[] + push!( + exs, + quote + result = init + end, + ) + for name in Names + push!(exs, :(result = _mapreduce_recursive(f, op, vnt.data.$name, result))) + end + push!(exs, :(return result)) + return Expr(:block, exs...) +end + +function _mapreduce_recursive(f, op, pa::PartialArray, init) + result = init + albs_seen = Set{ArrayLikeBlock}() + @inbounds for i in eachindex(pa.mask) + if pa.mask[i] + val = @inbounds pa.data[i] + if val isa ArrayLikeBlock + if val in albs_seen + continue + end + push!(albs_seen, val) + end + result = _mapreduce_recursive(f, op, pa.data[i], result) + end + end + return result +end + function Base.keys(vnt::VarNamedTuple) result = VarName[] for sym in keys(vnt.data) @@ -1179,10 +1242,7 @@ function Base.values(vnt::VarNamedTuple) result = Any[] for sym in keys(vnt.data) subdata = vnt.data[sym] - if subdata isa VarNamedTuple - subvalues = values(subdata) - append!(result, subvalues) - elseif subdata isa PartialArray + if subdata isa VarNamedTuple || subdata isa PartialArray subvalues = values(subdata) append!(result, subvalues) else @@ -1196,9 +1256,7 @@ function Base.length(vnt::VarNamedTuple) len = 0 for sym in keys(vnt.data) subdata = vnt.data[sym] - if subdata isa VarNamedTuple - len += length(subdata) - elseif subdata isa PartialArray + if subdata isa VarNamedTuple || subdata isa PartialArray len += length(subdata) else len += 1 @@ -1245,6 +1303,7 @@ Base.haskey(vnt::VarNamedTuple, vn::VarName) = _haskey(vnt, vn) # PartialArrays are an implementation detail of VarNamedTuple, and should never be the # return value of getindex. Thus, we automatically convert them to dense arrays if needed. 
+# TODO(mhauru) The below doesn't handle nested PartialArrays. Is that a problem? _dense_array_if_needed(pa::PartialArray) = _dense_array(pa) _dense_array_if_needed(x) = x Base.getindex(vnt::VarNamedTuple, vn::VarName) = _dense_array_if_needed(_getindex(vnt, vn)) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 25cd14e49..fe0d6e1b2 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -272,6 +272,17 @@ Base.size(st::SizedThing) = st.size @test haskey(vnt, vn) @test @inferred(getindex(vnt, vn)) == x test_invariants(vnt) + + # Indices on indices + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a[1][1]))) + @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 + vnt = @inferred(setindex!!(vnt, [1], @varname(b[1].c[1]))) + @test @inferred(getindex(vnt, @varname(b[1].c[1]))) == [1] + vnt = @inferred(setindex!!(vnt, [1], @varname(e[3, 2].f[2, 2][10, 10]))) + @test @inferred(getindex(vnt, @varname(e[3, 2].f[2, 2][10, 10]))) == [1] + vnt = @inferred(setindex!!(vnt, [1], @varname(g[3, 2][10, 10].h[2, 2]))) + @test @inferred(getindex(vnt, @varname(g[3, 2][10, 10].h[2, 2]))) == [1] end @testset "equality and hash" begin @@ -352,15 +363,33 @@ Base.size(st::SizedThing) = st.size expected_merge = setindex!!(expected_merge, fill(2, 4), @varname(e.a[8:11])) @test @inferred(merge(vnt1, vnt2)) == expected_merge + vnt1 = setindex!!(vnt1, 1, @varname(e.b[1][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[2][13])) + expected_merge = setindex!!(expected_merge, 1, @varname(e.b[1][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[2][13])) + vnt1 = setindex!!(vnt1, 1, @varname(e.b[3][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[3][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[3][13])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + vnt1 = setindex!!(vnt1, 1, @varname(e.b[4][13])) + vnt2 = setindex!!(vnt2, 2, @varname(e.b[4][14])) + expected_merge = setindex!!(expected_merge, 1, 
@varname(e.b[4][13])) + expected_merge = setindex!!(expected_merge, 2, @varname(e.b[4][14])) + @test @inferred(merge(vnt1, vnt2)) == expected_merge + vnt1 = setindex!!(vnt1, ["1", "1"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) vnt2 = setindex!!(vnt2, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4])) expected_merge = setindex!!( expected_merge, ["2", "2"], @varname(f.a[1].b.c[2, 2].d[1, 3:4]) ) - vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1])) - expected_merge = setindex!!(expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1])) + vnt1 = setindex!!(vnt1, :1, @varname(f.a[1].b.c[3, 2].d[1, 1][14, 13])) + vnt2 = setindex!!(vnt2, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13])) + expected_merge = setindex!!( + expected_merge, :1, @varname(f.a[1].b.c[3, 2].d[1, 1][14, 13]) + ) + expected_merge = setindex!!( + expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13]) + ) @test merge(vnt1, vnt2) == expected_merge # PartialArrays with different sizes. 
@@ -501,6 +530,35 @@ Base.size(st::SizedThing) = st.size 1.0, SizedThing((3, 1, 4)), ] + + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14])) + @test keys(vnt) == [ + @varname(a), + @varname(b), + @varname(c.x.y), + @varname(d[4]), + @varname(e.f[3, 3].g.h[2, 4, 1].i), + @varname(j[1]), + @varname(j[2]), + @varname(j[3]), + @varname(j[4]), + @varname(j[6]), + @varname(n[2].a), + @varname(o[2:4, 5:5, 11:14]), + @varname(p[2, 1][2:4, 5:5, 11:14]), + ] + @test values(vnt) == [ + 1.0, + [1, 15, 3], + [10], + -1.0, + 2.0, + fill(1.0, 4)..., + "a", + 1.0, + SizedThing((3, 1, 4)), + SizedThing((3, 1, 4)), + ] end @testset "length" begin @@ -534,6 +592,9 @@ Base.size(st::SizedThing) = st.size vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 4:6, 2, 1:2, 3])) @test @inferred(length(vnt)) == 14 + + vnt = setindex!!(vnt, [:a, :b], @varname(y[4][3][2][1:2])) + @test @inferred(length(vnt)) == 16 end @testset "empty" begin @@ -622,7 +683,7 @@ Base.size(st::SizedThing) = st.size VarNamedTuple(a = "s", b = [1, 2, 3], \ c = PartialArray{Symbol,1}((2,) => :dada))""" - vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3].f.g[1:2])) + vnt = setindex!!(vnt, [16.0, 17.0], @varname(d.e[3][2, 2].f.g[1:2])) io = IOBuffer() show(io, vnt) output = String(take!(io)) @@ -634,11 +695,13 @@ Base.size(st::SizedThing) = st.size VarNamedTuple(a = "s", b = [1, 2, 3], \ c = PartialArray{Symbol,1}((2,) => :dada), \ d = VarNamedTuple(\ - e = PartialArray{VarNamedTuple{(:f,), \ + e = PartialArray{PartialArray{VarNamedTuple{(:f,), \ + Tuple{VarNamedTuple{(:g,), \ + Tuple{PartialArray{Float64, 1}}}}}, 2},1}((3,) => \ + PartialArray{VarNamedTuple{(:f,), \ Tuple{VarNamedTuple{(:g,), \ - Tuple{PartialArray{Float64, 1}}}}},1}((3,) => \ - VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ - (2,) => 17.0),),)),))""" + Tuple{PartialArray{Float64, 1}}}}},2}((2, 2) => VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ + (2,) => 
17.0),),))),))""" end @testset "block variables" begin @@ -742,7 +805,7 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end - @testset "map!! and apply!!" begin + @testset "map!!, apply!!, and mapreduce" begin vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1, @varname(a))) vnt = @inferred(setindex!!(vnt, [2, 2], @varname(b[1:2]))) @@ -755,6 +818,7 @@ Base.size(st::SizedThing) = st.size vnt = @inferred( setindex!!(vnt, SizedThing((2, 2)), @varname(y.z[3, 2:3, 3, 2:3, 4])) ) + vnt = @inferred(setindex!!(vnt, "", @varname(w[4][3][2, 1]))) test_invariants(vnt) struct AnotherSizedThing{T<:Tuple} @@ -780,6 +844,12 @@ Base.size(st::SizedThing) = st.size end end + reduction = mapreduce(identity, vcat, vnt; init=Any[]) + @test reduction == vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") + reduction = mapreduce(f, vcat, vnt; init=Any[]) + @test reduction == + vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") + vnt_mapped = @inferred(map!!(f, copy(vnt))) test_invariants(vnt_mapped) @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 @@ -789,6 +859,7 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].j))) == 6.0 @test @inferred(getindex(vnt_mapped, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) + @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" vnt_applied = @inferred(apply!!(f, vnt, @varname(a))) test_invariants(vnt_applied) @@ -821,6 +892,10 @@ Base.size(st::SizedThing) = st.size test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) + + vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(w[4][3][2, 1]))) + test_invariants(vnt_applied) + @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" end end From 20ed5751f8d16f6499310a02a2b104168c4fb096 Mon Sep 17 00:00:00 2001 From: 
Markus Hauru Date: Thu, 18 Dec 2025 15:39:04 +0000 Subject: [PATCH 03/56] Test invariants more --- test/varnamedtuple.jl | 37 +++++++++++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 6 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index fe0d6e1b2..face2dc42 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -27,7 +27,9 @@ function test_invariants(vnt::VarNamedTuple) @test !(v isa ArrayLikeBlock) @test !(v isa PartialArray) vnt2 = setindex!!(copy(vnt), v, k) - @test vnt == vnt2 + equality = (vnt == vnt2) + # The value may be `missing` if vnt itself has values that are missing. + @test equality === true || equality === missing @test isequal(vnt, vnt2) @test hash(vnt) == hash(vnt2) end @@ -36,24 +38,26 @@ function test_invariants(vnt::VarNamedTuple) # reconstructability-from-repr property, this will fail. Likewise if any element uses # in its repr print out types that are not in scope in this module, it will fail. vnt3 = eval(Meta.parse(repr(vnt))) - @test vnt == vnt3 + equality = (vnt == vnt3) + # The value may be `missing` if vnt itself has values that are missing. + @test equality === true || equality === missing @test isequal(vnt, vnt3) @test hash(vnt) == hash(vnt3) # Check that merge with an empty VarNamedTuple is a no-op. - @test merge(vnt, VarNamedTuple()) == vnt - @test merge(VarNamedTuple(), vnt) == vnt + @test isequal(merge(vnt, VarNamedTuple()), vnt) + @test isequal(merge(VarNamedTuple(), vnt), vnt) # Check that the VNT can be constructed back from its keys and values. 
vnt4 = VarNamedTuple() for (k, v) in zip(vnt_keys, vnt_values) vnt4 = setindex!!(vnt4, v, k) end - @test vnt == vnt4 + @test isequal(vnt, vnt4) # Check that vnt isempty only if it has no keys was_empty = isempty(vnt) @test was_empty == isempty(vnt_keys) @test was_empty == isempty(vnt_values) # Check that vnt can be emptied - @test empty(vnt) == VarNamedTuple() + @test empty(vnt) === VarNamedTuple() emptied_vnt = empty!!(copy(vnt)) @test isempty(emptied_vnt) @test isempty(keys(emptied_vnt)) @@ -312,6 +316,8 @@ Base.size(st::SizedThing) = st.size expected_isequal = expected_isequal & isequal(v1, v2) expected_doubleequal = expected_doubleequal & (v1 == v2) end + test_invariants(vnt1) + test_invariants(vnt2) @test isequal(vnt1, vnt2) == expected_isequal @test (vnt1 == vnt2) === expected_doubleequal if expected_isequal @@ -335,6 +341,8 @@ Base.size(st::SizedThing) = st.size expected_merge = setindex!!(expected_merge, 2, @varname(c)) expected_merge = setindex!!(expected_merge, 2.0, @varname(b)) @test @inferred(merge(vnt1, vnt2)) == expected_merge + test_invariants(vnt1) + test_invariants(vnt2) vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() @@ -391,6 +399,8 @@ Base.size(st::SizedThing) = st.size expected_merge, :2, @varname(f.a[1].b.c[4, 2].d[1, 1][14, 13]) ) @test merge(vnt1, vnt2) == expected_merge + test_invariants(vnt1) + test_invariants(vnt2) # PartialArrays with different sizes. 
vnt1 = VarNamedTuple() @@ -406,6 +416,8 @@ Base.size(st::SizedThing) = st.size @test @inferred(merge(vnt1, vnt2)) == expected_merge_12 expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1])) @test @inferred(merge(vnt2, vnt1)) == expected_merge_21 + test_invariants(vnt1) + test_invariants(vnt2) vnt1 = VarNamedTuple() vnt2 = VarNamedTuple() @@ -420,6 +432,8 @@ Base.size(st::SizedThing) = st.size @test merge(vnt1, vnt2) == expected_merge_12 expected_merge_21 = setindex!!(expected_merge_12, 1, @varname(a[1, 1])) @test merge(vnt2, vnt1) == expected_merge_21 + test_invariants(vnt1) + test_invariants(vnt2) end @testset "keys and values" begin @@ -559,6 +573,7 @@ Base.size(st::SizedThing) = st.size SizedThing((3, 1, 4)), SizedThing((3, 1, 4)), ] + test_invariants(vnt) end @testset "length" begin @@ -595,6 +610,7 @@ Base.size(st::SizedThing) = st.size vnt = setindex!!(vnt, [:a, :b], @varname(y[4][3][2][1:2])) @test @inferred(length(vnt)) == 16 + test_invariants(vnt) end @testset "empty" begin @@ -605,10 +621,12 @@ Base.size(st::SizedThing) = st.size @test @inferred(isempty(vnt)) == true vnt = setindex!!(vnt, 1.0, @varname(a)) @test @inferred(isempty(vnt)) == false + test_invariants(vnt) vnt = VarNamedTuple() vnt = setindex!!(vnt, [], @varname(a[1])) @test @inferred(isempty(vnt)) == false + test_invariants(vnt) # 2) empty!! 
keeps PartialArrays in place: vnt = VarNamedTuple() @@ -624,22 +642,26 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(a[2:4]))) == [1, 2, 3] @test haskey(vnt, @varname(a[2:4])) @test !haskey(vnt, @varname(a[1])) + test_invariants(vnt) end @testset "densification" begin vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 1)) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 2]))) @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (1, 2)) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[2, 1]))) @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 1)) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[1, 1]))) @@ -650,10 +672,12 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(a.b[1].c))) == fill(1.0, (2, 2)) vnt = @inferred(setindex!!(vnt, 1.0, @varname(a.b[1].c[3, 3]))) @test_throws ArgumentError @inferred(getindex(vnt, @varname(a.b[1].c))) + test_invariants(vnt) vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, SizedThing((2,)), @varname(x[1:2]))) @test_throws ArgumentError @inferred(getindex(vnt, @varname(x))) + test_invariants(vnt) end @testset "printing" begin @@ -702,6 +726,7 @@ Base.size(st::SizedThing) = st.size Tuple{VarNamedTuple{(:g,), \ Tuple{PartialArray{Float64, 1}}}}},2}((2, 2) => VarNamedTuple(f = VarNamedTuple(g = PartialArray{Float64,1}((1,) => 16.0, \ (2,) => 17.0),),))),))""" + test_invariants(vnt) end @testset "block variables" begin From 477b715a12776b30973e3d4e3ed7d53d3183500f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 18 Dec 2025 16:52:54 +0000 Subject: [PATCH 
04/56] Work-in-progress VNTVarInfo --- src/DynamicPPL.jl | 5 +- src/chains.jl | 4 +- src/contexts/init.jl | 10 +- src/logdensityfunction.jl | 91 ++++++++------ src/simple_varinfo.jl | 18 +-- src/test_utils/varinfo.jl | 42 ++++--- src/utils.jl | 1 + src/vntvarinfo.jl | 247 +++++++++++++++++++++++++++++++++++++ test/compiler.jl | 4 +- test/logdensityfunction.jl | 8 +- test/test_util.jl | 31 ++--- 11 files changed, 368 insertions(+), 93 deletions(-) create mode 100644 src/vntvarinfo.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 25ca59018..5b831e100 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("varnamedtuple.jl") -using .VarNamedTuples: VarNamedTuple +using .VarNamedTuples: VarNamedTuple, map!!, apply!! include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") @@ -201,7 +201,8 @@ include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") -include("varinfo.jl") +# include("varinfo.jl") +include("vntvarinfo.jl") include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") diff --git a/src/chains.jl b/src/chains.jl index 71ca29a8f..dc3a91044 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -67,8 +67,8 @@ end # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to # convert it to a typed varinfo first, hence this method. 
# https://github.com/TuringLang/Turing.jl/issues/2604 -maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) -maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) +# maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) +# maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi """ diff --git a/src/contexts/init.jl b/src/contexts/init.jl index dd9e99421..5422c7c85 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -320,7 +320,9 @@ function tilde_assume!!( insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) val_to_insert, logjac = if insert_transformed_value # Calculate the forward logjac and sum them up. - y, fwd_logjac = with_logabsdet_jacobian(link_transform(dist), x) + lt = link_transform(dist) + y, fwd_logjac = with_logabsdet_jacobian(lt, x) + transform = _compose_no_identity(transform, lt) # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian # calculation wastes a lot of time going from linked vectorised -> unlinked -> # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. @@ -360,7 +362,11 @@ function tilde_assume!!( if in_varinfo vi = setindex!!(vi, val_to_insert, vn) else - vi = push!!(vi, vn, val_to_insert, dist) + vi = if vi isa VNTVarInfo + push!!(vi, vn, val_to_insert, inverse(transform)) + else + push!!(vi, vn, val_to_insert, dist) + end end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 89e2b5989..adcb319c8 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -13,7 +13,7 @@ using DynamicPPL: OnlyAccsVarInfo, RangeAndLinked, VectorWithRanges, - Metadata, + # Metadata, VarNamedVector, default_accumulators, float_type_with_fallback, @@ -310,45 +310,56 @@ representation, along with whether each variable is linked or unlinked. 
This function returns a VarNamedTuple mapping all VarNames to their corresponding `RangeAndLinked`. """ -function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} - all_ranges = VarNamedTuple() +function get_ranges_and_linked(vi::VNTVarInfo) offset = 1 - for sym in syms - md = varinfo.metadata[sym] - this_md_others, offset = get_ranges_and_linked_metadata(md, offset) - all_ranges = merge(all_ranges, this_md_others) + vnt = map!!(vi.values) do tv + val = tv.val + range = offset:(offset + length(val) - 1) + offset += length(val) + RangeAndLinked(range, tv.linked, size(val)) end - return all_ranges -end -function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) - all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) - return all_ranges -end -function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) - all_ranges = VarNamedTuple() - offset = start_offset - for (vn, idx) in md.idcs - is_linked = md.is_transformed[idx] - range = md.ranges[idx] .+ (start_offset - 1) - orig_size = varnamesize(vn) - all_ranges = BangBang.setindex!!( - all_ranges, RangeAndLinked(range, is_linked, orig_size), vn - ) - offset += length(range) - end - return all_ranges, offset -end -function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) - all_ranges = VarNamedTuple() - offset = start_offset - for (vn, idx) in vnv.varname_to_index - is_linked = vnv.is_unconstrained[idx] - range = vnv.ranges[idx] .+ (start_offset - 1) - orig_size = varnamesize(vn) - all_ranges = BangBang.setindex!!( - all_ranges, RangeAndLinked(range, is_linked, orig_size), vn - ) - offset += length(range) - end - return all_ranges, offset + return vnt end + +# function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} +# all_ranges = VarNamedTuple() +# offset = 1 +# for sym in syms +# md = varinfo.metadata[sym] +# this_md_others, offset = get_ranges_and_linked_metadata(md, offset) +# all_ranges = 
merge(all_ranges, this_md_others) +# end +# return all_ranges +# end +# function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) +# all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) +# return all_ranges +# end +# function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) +# all_ranges = VarNamedTuple() +# offset = start_offset +# for (vn, idx) in md.idcs +# is_linked = md.is_transformed[idx] +# range = md.ranges[idx] .+ (start_offset - 1) +# orig_size = varnamesize(vn) +# all_ranges = BangBang.setindex!!( +# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn +# ) +# offset += length(range) +# end +# return all_ranges, offset +# end +# function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) +# all_ranges = VarNamedTuple() +# offset = start_offset +# for (vn, idx) in vnv.varname_to_index +# is_linked = vnv.is_unconstrained[idx] +# range = vnv.ranges[idx] .+ (start_offset - 1) +# orig_size = varnamesize(vn) +# all_ranges = BangBang.setindex!!( +# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn +# ) +# offset += length(range) +# end +# return all_ranges, offset +# end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 9d3fb1925..4add65d6d 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -256,15 +256,15 @@ function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFro end # Constructor from `VarInfo`. 
-function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} - values = values_as(vi, D) - return SimpleVarInfo(values, copy(getaccs(vi))) -end -function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} - values = values_as(vi, D) - accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) - return SimpleVarInfo(values, accs) -end +# function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} +# values = values_as(vi, D) +# return SimpleVarInfo(values, copy(getaccs(vi))) +# end +# function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} +# values = values_as(vi, D) +# accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) +# return SimpleVarInfo(values, accs) +# end function untyped_simple_varinfo(model::Model) varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 6483b29e8..79b92ce13 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -33,26 +33,32 @@ of the varinfo instances. 
function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) - # VarInfo - vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) - vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) - vi_typed_metadata = DynamicPPL.typed_varinfo(model) - vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) + # # VarInfo + # vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) + # vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) + # vi_typed_metadata = DynamicPPL.typed_varinfo(model) + # vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) - # SimpleVarInfo - svi_typed = SimpleVarInfo(example_values) - svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) - svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) + # # SimpleVarInfo + # svi_typed = SimpleVarInfo(example_values) + # svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) + # svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - varinfos = map(( - vi_untyped_metadata, - vi_untyped_vnv, - vi_typed_metadata, - vi_typed_vnv, - svi_typed, - svi_untyped, - svi_vnv, - )) do vi + # varinfos = map(( + # vi_untyped_metadata, + # vi_untyped_vnv, + # vi_typed_metadata, + # vi_typed_vnv, + # svi_typed, + # svi_untyped, + # svi_vnv, + # )) do vi + # # Set them all to the same values and evaluate logp. + # vi = update_values!!(vi, example_values, varnames) + # last(DynamicPPL.evaluate!!(model, vi)) + # end + # + varinfos = map((DynamicPPL.typed_varinfo(model),)) do vi # Set them all to the same values and evaluate logp. 
vi = update_values!!(vi, example_values, varnames) last(DynamicPPL.evaluate!!(model, vi)) diff --git a/src/utils.jl b/src/utils.jl index fe2879182..ed9f3aa13 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -49,6 +49,7 @@ function typed_identity end @inline typed_identity(x) = x @inline Bijectors.with_logabsdet_jacobian(::typeof(typed_identity), x) = (x, zero(LogProbType)) +@inline Bijectors.inverse(::typeof(typed_identity)) = typed_identity """ @addlogprob!(ex) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl new file mode 100644 index 000000000..9b6ee2c7e --- /dev/null +++ b/src/vntvarinfo.jl @@ -0,0 +1,247 @@ +struct VNTVarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo + values::T + accs::Accs +end + +# TODO(mhauru) Make this renaming permanent. +const VarInfo = VNTVarInfo + +struct TransformedValue{ValType,TransformType} + val::ValType + linked::Bool + transform::TransformType +end + +VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators()) + +function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return VNTVarInfo(Random.default_rng(), model, init_strategy) +end + +function VNTVarInfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return last(init!!(rng, model, VNTVarInfo(), init_strategy)) +end + +getaccs(vi::VNTVarInfo) = vi.accs +setaccs!!(vi::VNTVarInfo, accs::AccumulatorTuple) = VNTVarInfo(vi.values, accs) + +transformation(::VNTVarInfo) = DynamicTransformation() + +Base.haskey(vi::VNTVarInfo, vn::VarName) = haskey(vi.values, vn) + +Base.length(vi::VNTVarInfo) = length(vi.values) + +function Base.getindex(vi::VNTVarInfo, vn::VarName) + tv = getindex(vi.values, vn) + return tv.transform(tv.val) +end + +Base.isempty(vi::VNTVarInfo) = isempty(vi.values) + +# TODO(mhauru) This should be called setindex_internal!!, but that's not the current +# convention. 
+function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName)
+    old_tv = getindex(vi.values, vn)
+    new_tv = TransformedValue(val, old_tv.linked, old_tv.transform)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+# TODO(mhauru) The arguments are in the wrong order, but this is the current convention.
+function BangBang.push!!(vi::VNTVarInfo, vn::VarName, val, transform=typed_identity)
+    new_tv = TransformedValue(val, false, transform)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+Base.keys(vi::VNTVarInfo) = keys(vi.values)
+
+function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName)
+    old_tv = getindex(vi.values, vn)
+    new_tv = TransformedValue(old_tv.val, linked, old_tv.transform)
+    new_values = setindex!!(vi.values, new_tv, vn)
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+function set_transformed!!(vi::VNTVarInfo, linked::Bool)
+    new_values = map!!(vi.values) do tv
+        TransformedValue(tv.val, linked, tv.transform)
+    end
+    return VNTVarInfo(new_values, vi.accs)
+end
+
+function getindex_internal(vi::VNTVarInfo, vn::VarName)
+    tv = getindex(vi.values, vn)
+    return tv.val
+end
+
+getindex_internal(vi::VNTVarInfo, ::Colon) = values_as(vi, Vector)
+
+function is_transformed(vi::VNTVarInfo, vn::VarName)
+    tv = getindex(vi.values, vn)
+    return tv.linked
+end
+
+# TODO(mhauru) Other VarInfos have something like this. Do we need it? 
+# function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) +# return from_vec_transform(dist) +# end + +function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) + return getindex(vi.values, vn).transform +end + +function from_linked_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) + return from_linked_vec_transform(dist) +end + +function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) + return getindex(vi.values, vn).transform +end + +function change_transform(tv::TransformedValue, new_transform, linked) + val_untransformed, logjac1 = with_logabsdet_jacobian(tv.transform, tv.val) + val_new, logjac2 = with_logabsdet_jacobian(inverse(new_transform), val_untransformed) + return TransformedValue(val_new, linked, new_transform), logjac1 + logjac2 +end + +function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) + dists = extract_priors(model, vi) + cumulative_logjac = zero(LogProbType) + new_values = vi.values + for vn in vns + new_values = apply!!(new_values, vn) do tv + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) + # TODO(mhauru) This is probably pretty inefficient. Do this better. Would like to use + # map!!, but it doesn't have access to the VarName. 
+ dists = extract_priors(model, vi) + cumulative_logjac = zero(LogProbType) + new_values = vi.values + vns = keys(vi) + for vn in vns + new_values = apply!!(vi.values, vn) do tv + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) + cumulative_logjac = zero(LogProbType) + new_values = vi.values + for vn in vns + new_values = apply!!(new_values, vn) do tv + transform = typed_identity + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) + # TODO(mhauru) This is probably pretty inefficient. Do this better. Would like to use + # map!!, but it doesn't have access to the VarName. + cumulative_logjac = zero(LogProbType) + new_values = vi.values + vns = keys(vi) + for vn in vns + new_values = apply!!(vi.values, vn) do tv + transform = typed_identity + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv + end + end + vi = VNTVarInfo(new_values, vi.accs) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, cumulative_logjac) + end + return vi +end + +# TODO(mhauru) I don't think this should return the internal values, but that's the current +# convention. +function values_as(vi::VNTVarInfo, ::Type{Vector}) + return mapreduce(tv -> tovec(tv.val), vcat, vi.values; init=Union{}[]) +end + +# TODO(mhauru) These two are now redundant, just conforming to the old interface +# temporarily. 
+function untyped_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return VNTVarInfo(rng, model, init_strategy) +end + +function typed_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return VNTVarInfo(rng, model, init_strategy) +end + +typed_varinfo(vi::VNTVarInfo) = vi + +function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return typed_varinfo(Random.default_rng(), model, init_strategy) +end + +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return untyped_varinfo(Random.default_rng(), model, init_strategy) +end + +function unflatten(vi::VNTVarInfo, vec::AbstractVector) + index = 1 + new_values = map!!(vi.values) do tv + # TODO(mhauru) This is quite crude, assuming that the value stored currently is + # an AbstractArray of some kind that has a size, and that reshape makes sense here. + # I may fix this later, but I'm also tempted to just get rid of unflatten entirely. + # This works for now for making most tests pass. + old_val = tv.val + len = length(old_val) + new_val = reshape(vec[index:(index + len - 1)], size(old_val)) + # If the old_val was a scalar then new_val is a 0-dimensional array. + # Convert it to a scalar. 
+ if !(old_val isa AbstractArray) && length(old_val) == 1 + new_val = new_val[1] + end + index += len + return TransformedValue(new_val, tv.linked, tv.transform) + end + return VNTVarInfo(new_values, vi.accs) +end diff --git a/test/compiler.jl b/test/compiler.jl index 9056f666a..5101bd602 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -236,9 +236,9 @@ module Issue537 end # https://github.com/TuringLang/Turing.jl/issues/1464#issuecomment-731153615 vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) + @test haskey(vi, @varname(x)) vi = VarInfo(gdemo(x)) - @test haskey(vi.metadata, :x) + @test haskey(vi, @varname(x)) # Non-array variables @model function testmodel_nonarray(x, y) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index f96e7bf27..153962c9e 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -18,10 +18,10 @@ using Mooncake: Mooncake @testset "LogDensityFunction: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS @testset "$varinfo_func" for varinfo_func in [ - DynamicPPL.untyped_varinfo, + # DynamicPPL.untyped_varinfo, DynamicPPL.typed_varinfo, - DynamicPPL.untyped_vector_varinfo, - DynamicPPL.typed_vector_varinfo, + # DynamicPPL.untyped_vector_varinfo, + # DynamicPPL.typed_vector_varinfo, ] unlinked_vi = varinfo_func(m) @testset "$islinked" for islinked in (false, true) @@ -38,7 +38,7 @@ using Mooncake: Mooncake # directly range_with_linked = ranges[vn] @test params[range_with_linked.range] == - DynamicPPL.getindex_internal(vi, vn) + DynamicPPL.tovec(DynamicPPL.getindex_internal(vi, vn)) # Check that the link status is correct @test range_with_linked.is_linked == islinked end diff --git a/test/test_util.jl b/test/test_util.jl index 94fdbd744..821b1e0db 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -16,28 +16,31 @@ Return string representing a short description of `vi`. 
function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -function short_varinfo_name(vi::DynamicPPL.NTVarInfo) - return if DynamicPPL.has_varnamedvector(vi) - "TypedVectorVarInfo" - else - "TypedVarInfo" - end -end -short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" -short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" +# function short_varinfo_name(vi::DynamicPPL.NTVarInfo) +# return if DynamicPPL.has_varnamedvector(vi) +# "TypedVectorVarInfo" +# else +# "TypedVarInfo" +# end +# end +# short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" +# short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" end function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) return "SimpleVarInfo{<:OrderedDict,<:Ref}" end -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) - return "SimpleVarInfo{<:VarNamedVector,<:Ref}" -end +# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) +# return "SimpleVarInfo{<:VarNamedVector,<:Ref}" +# end short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) - return "SimpleVarInfo{<:VarNamedVector}" +# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) +# return "SimpleVarInfo{<:VarNamedVector}" +# end +function short_varinfo_name(::DynamicPPL.VNTVarInfo) + return "VNTVarInfo" end # convenient functions for testing model.jl From 7aa601312b9e86e8139fcaa57c2e2c5782b0c42f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 18 Dec 2025 17:37:39 +0000 Subject: [PATCH 05/56] Fix a bug in link --- src/vntvarinfo.jl | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 9b6ee2c7e..184fbd201 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -137,7 +137,7 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) new_values = vi.values vns = keys(vi) for vn in vns - new_values = apply!!(vi.values, vn) do tv + new_values = apply!!(new_values, vn) do tv dist = getindex(dists, vn) transform = from_linked_vec_transform(dist) new_tv, logjac = change_transform(tv, transform, true) @@ -177,7 +177,7 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) new_values = vi.values vns = keys(vi) for vn in vns - new_values = apply!!(vi.values, vn) do tv + new_values = apply!!(new_values, vn) do tv transform = typed_identity new_tv, logjac = change_transform(tv, transform, false) cumulative_logjac += logjac From bdeeb4ab1df4a3eeb678907bbe8e72a4818cd90a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 Jan 2026 16:42:31 +0000 Subject: [PATCH 06/56] Update map!! to operate on pairs --- src/varnamedtuple.jl | 58 ++++++++++++++++++++++++++++++++++--------- test/varnamedtuple.jl | 23 +++++++++-------- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index b368ab8cd..faf2a298d 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1112,37 +1112,54 @@ function apply!!(func, vnt::VarNamedTuple, name::VarName) end """ - _map_recursive!!(func, x) + _map_recursive!!(func, x, vn) -Call `func` on `x`, except if `x` is a `VarNamedTuple` or `PartialArray`, in which case -call `_map_recursive!!` recursively on all their elements.. +Call `func` on `vn => x`, except if `x` is a `VarNamedTuple` or `PartialArray`, in which +case call `_map_recursive!!` recursively on all their elements, updating `vn` with the right +prefix. 
This is the internal implementation of `map!!`, but because it has a method defined for literally every type in existence, we hide it behind the interface of the more discriminating `map!!`. It makes the implementation a bit simpler, compared to checking element types within `map!!` itself. """ -_map_recursive!!(func, x) = func(x) - -function _map_recursive!!(func, pa::PartialArray) - # Ask the compiler to infer the return type of applying func to eltype(pa). - new_et = Core.Compiler.return_type(x -> _map_recursive!!(func, x), Tuple{eltype(pa)}) +_map_recursive!!(func, x, vn) = func(vn => x) + +# TODO(mhauru) The below is type unstable for some complex VarNames. My example case +# for which type stability fails is @varname(e.f[3].g.h[2].i). I don't understand this +# well, but I think it's just because constant propagation gives up at some point, and fails +# to go through the lines that figure out `new_et`. I could be wrong. I tried fixing this by +# lifting the first three lines of the function into a generated function, but that seems +# to run into trouble when trying to call Core.Compiler.return_type recursively on the same +# function. An earlier implementation of this function that only operated on the values, +# not on pairs of key => value, was type stable (presumably because it was a bit easier on +# constant propagation). +function _map_recursive!!(func, pa::PartialArray, vn) + # Ask the compiler to infer the return type of applying func recursively to eltype(pa). + index_type = IndexLens{NTuple{ndims(pa),Int}} + new_vn_type = Core.Compiler.return_type(∘, Tuple{index_type,typeof(vn)}) + new_et = Core.Compiler.return_type( + Tuple{typeof(_map_recursive!!),typeof(func),eltype(pa),new_vn_type} + ) new_data = if new_et <: eltype(pa) + # We can reuse the existing data array. pa.data else + # We need to allocate a new data array. 
similar(pa.data, new_et) end @inbounds for i in eachindex(pa.mask) if pa.mask[i] - new_data[i] = _map_recursive!!(func, pa.data[i]) + new_vn = IndexLens(Tuple(i)) ∘ vn + new_data[i] = _map_recursive!!(func, pa.data[i], new_vn) end end # The above type inference may be overly conservative, so we concretise the eltype. return _concretise_eltype!!(PartialArray(new_data, pa.mask)) end -function _map_recursive!!(func, alb::ArrayLikeBlock) - new_block = _map_recursive!!(func, alb.block) +function _map_recursive!!(func, alb::ArrayLikeBlock, vn) + new_block = _map_recursive!!(func, alb.block, vn) if size(new_block) != size(alb.block) throw( DimensionMismatch( @@ -1157,7 +1174,22 @@ end @generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}) where {Names} exs = Expr[] for name in Names - push!(exs, :(_map_recursive!!(func, vnt.data.$name))) + push!(exs, :(_map_recursive!!(func, vnt.data.$name, VarName{$(QuoteNode(name))}()))) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end +end + +@generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}, vn::T) where {Names,T} + exs = Expr[] + for name in Names + push!( + exs, + :(_map_recursive!!( + func, vnt.data.$name, AbstractPPL.prefix(vn, VarName{$(QuoteNode(name))}()) + )), + ) end return quote return VarNamedTuple(NamedTuple{Names}(($(exs...),))) @@ -1168,6 +1200,8 @@ end map!!(func, vnt::VarNamedTuple) Apply `func` to all set elements of the `vnt`, in place if possible. + +`func` should accept a pair of `VarName` and value, and return the new value to be set. 
""" map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index face2dc42..0b3076468 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -851,7 +851,7 @@ Base.size(st::SizedThing) = st.size end Base.size(st::AnotherSizedThing) = st.size - function f(val) + function f_val(val) if val isa Int return val + 10 elseif val isa AbstractVector{Int} @@ -869,13 +869,16 @@ Base.size(st::SizedThing) = st.size end end + f_pair(pair) = f_val(pair.second) + reduction = mapreduce(identity, vcat, vnt; init=Any[]) @test reduction == vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") - reduction = mapreduce(f, vcat, vnt; init=Any[]) + reduction = mapreduce(f_val, vcat, vnt; init=Any[]) @test reduction == vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") - vnt_mapped = @inferred(map!!(f, copy(vnt))) + # vnt_mapped = @inferred(map!!(f, copy(vnt))) + vnt_mapped = map!!(f_pair, copy(vnt)) test_invariants(vnt_mapped) @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] @@ -886,26 +889,26 @@ Base.size(st::SizedThing) = st.size AnotherSizedThing((2, 2)) @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" - vnt_applied = @inferred(apply!!(f, vnt, @varname(a))) + vnt_applied = @inferred(apply!!(f_val, vnt, @varname(a))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [2, 2] - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(b[1:2]))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(b[1:2]))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [12, 12] - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(c.d))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, 
@varname(c.d))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].i))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(e.f[3].g.h[2].j))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 6.0 @@ -913,12 +916,12 @@ Base.size(st::SizedThing) = st.size # This can't be type stable because y.z might have many elements set, and we can't # know at compile time that this sets the only one, thus allowing the element type # to be AnotherSizedThing. - vnt_applied = apply!!(f, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) + vnt_applied = apply!!(f_val, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) - vnt_applied = @inferred(apply!!(f, vnt_applied, @varname(w[4][3][2, 1]))) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(w[4][3][2, 1]))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" end From 5498d8279a9667aadba5173171b3edd3220ffe59 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 Jan 2026 17:09:48 +0000 Subject: [PATCH 07/56] Split map!! into map_pairs!! 
and map_values!!, fix some bugs --- src/varnamedtuple.jl | 33 ++++++++++++++++++++------------- test/varnamedtuple.jl | 33 +++++++++++++++++++++++++++++---- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index faf2a298d..2217d35ee 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,7 +8,7 @@ using BangBang using Accessors using ..DynamicPPL: _compose_no_identity -export VarNamedTuple, map!!, apply!! +export VarNamedTuple, map_pairs!!, map_values!!, apply!! # We define our own getindex, setindex!!, and haskey functions, which we use to # get/set/check values in VarNamedTuple and PartialArray. We do this because we want to be @@ -1083,7 +1083,7 @@ end Apply `func` to the subdata at `name` in `vnt`, and set the result back at `name`. -Like `map!!`, but only for a single `VarName`. +Like `map_values!!`, but only for a single `VarName`. ```jldoctest julia> using DynamicPPL: VarNamedTuple, setindex!! @@ -1118,10 +1118,10 @@ Call `func` on `vn => x`, except if `x` is a `VarNamedTuple` or `PartialArray`, case call `_map_recursive!!` recursively on all their elements, updating `vn` with the right prefix. -This is the internal implementation of `map!!`, but because it has a method defined for -literally every type in existence, we hide it behind the interface of the more -discriminating `map!!`. It makes the implementation a bit simpler, compared to checking -element types within `map!!` itself. +This is the internal implementation of `map_pairs!!`, but because it has a method defined +for literally every type in existence, we hide it behind the interface of the more +discriminating `map_pairs!!`. It makes the implementation a bit simpler, compared to +checking element types within `map_pairs!!` itself. """ _map_recursive!!(func, x, vn) = func(vn => x) @@ -1148,7 +1148,7 @@ function _map_recursive!!(func, pa::PartialArray, vn) # We need to allocate a new data array. 
similar(pa.data, new_et) end - @inbounds for i in eachindex(pa.mask) + @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] new_vn = IndexLens(Tuple(i)) ∘ vn new_data[i] = _map_recursive!!(func, pa.data[i], new_vn) @@ -1163,8 +1163,8 @@ function _map_recursive!!(func, alb::ArrayLikeBlock, vn) if size(new_block) != size(alb.block) throw( DimensionMismatch( - "map!! can't change the size of an ArrayLikeBlock. Tried to change from" * - "$(size(alb.block)) to $(size(new_block)).", + "map_pairs!! can't change the size of an ArrayLikeBlock. Tried to change " * + "from $(size(alb.block)) to $(size(new_block)).", ), ) end @@ -1187,7 +1187,7 @@ end push!( exs, :(_map_recursive!!( - func, vnt.data.$name, AbstractPPL.prefix(vn, VarName{$(QuoteNode(name))}()) + func, vnt.data.$name, AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn) )), ) end @@ -1197,13 +1197,20 @@ end end """ - map!!(func, vnt::VarNamedTuple) + map_pairs!!(func, vnt::VarNamedTuple) -Apply `func` to all set elements of the `vnt`, in place if possible. +Apply `func` to all key => value pairs of `vnt`, in place if possible. `func` should accept a pair of `VarName` and value, and return the new value to be set. """ -map!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) +map_pairs!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) + +""" + map_values!!(func, vnt::VarNamedTuple) + +Apply `func` to elements of `vnt`, in place if possible. 
+""" +map_values!!(func, vnt::VarNamedTuple) = map_pairs!!(pair -> func(pair.second), vnt) function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) if init === nothing diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 0b3076468..13436d39a 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,7 +3,8 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple -using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map!!, apply!! +using DynamicPPL.VarNamedTuples: + PartialArray, ArrayLikeBlock, map_pairs!!, map_values!!, apply!! using AbstractPPL: VarName, concretize, prefix using BangBang: setindex!!, empty!! @@ -64,6 +65,9 @@ function test_invariants(vnt::VarNamedTuple) @test isempty(values(emptied_vnt)) # Check that the copy protected the original vnt from being modified. @test isempty(vnt) == was_empty + # Check that map is a no-op when using identity functions. + @test isequal(map_pairs!!(pair -> pair.second, copy(vnt)), vnt) + @test isequal(map_values!!(identity, copy(vnt)), vnt) end """ A type that has a size but is not an Array. Used in ArrayLikeBlock tests.""" @@ -830,7 +834,7 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end - @testset "map!!, apply!!, and mapreduce" begin + @testset "map and friends" begin vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1, @varname(a))) vnt = @inferred(setindex!!(vnt, [2, 2], @varname(b[1:2]))) @@ -877,8 +881,11 @@ Base.size(st::SizedThing) = st.size @test reduction == vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") - # vnt_mapped = @inferred(map!!(f, copy(vnt))) - vnt_mapped = map!!(f_pair, copy(vnt)) + # TODO(mhauru) This should hopefully be type stable, but fails to be so because of + # some complex VarNames being too much for constant propagation. 
See comment in + # src/varnamedtuple.jl for more. + vnt_mapped = map_pairs!!(f_pair, copy(vnt)) + @test vnt_mapped == map_values!!(f_val, copy(vnt)) test_invariants(vnt_mapped) @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] @@ -924,6 +931,24 @@ Base.size(st::SizedThing) = st.size vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(w[4][3][2, 1]))) test_invariants(vnt_applied) @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" + + # map a function that maps every key => value pair to key => key. + # For this, use a simpler VarNamedTuple, because block variables don't work with + # this mapping function. It also allows us to check type stability. + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, 1, @varname(a))) + vnt = @inferred(setindex!!(vnt, 2, @varname(b[2]))) + vnt = @inferred(setindex!!(vnt, [3.0], @varname(c.d))) + vnt = @inferred(setindex!!(vnt, :oi, @varname(y.z[3, 2, 3, 2, 4]))) + vnt = @inferred(setindex!!(vnt, "", @varname(w[4][2, 1]))) + + get_key(pair) = pair.first + vnt_key_mapped = @inferred(map_pairs!!(get_key, copy(vnt))) + vnt_key_mapped_expected = VarNamedTuple() + for k in keys(vnt) + vnt_key_mapped_expected = setindex!!(vnt_key_mapped_expected, k, k) + end + @test vnt_key_mapped == vnt_key_mapped_expected end end From 81be716f37455a4121fbeadb174a20063164936b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 Jan 2026 17:31:01 +0000 Subject: [PATCH 08/56] Make mapreduce operate on pairs --- src/varnamedtuple.jl | 80 +++++++++++++++++++++++++++++++++++++------ test/varnamedtuple.jl | 18 ++++++++-- 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 2217d35ee..4b91fbb12 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1136,12 +1136,13 @@ _map_recursive!!(func, x, vn) = func(vn => x) # constant propagation). 
function _map_recursive!!(func, pa::PartialArray, vn)
     # Ask the compiler to infer the return type of applying func recursively to eltype(pa).
+    et = eltype(pa)
     index_type = IndexLens{NTuple{ndims(pa),Int}}
     new_vn_type = Core.Compiler.return_type(∘, Tuple{index_type,typeof(vn)})
     new_et = Core.Compiler.return_type(
-        Tuple{typeof(_map_recursive!!),typeof(func),eltype(pa),new_vn_type}
+        Tuple{typeof(_map_recursive!!),typeof(func),et,new_vn_type}
     )
-    new_data = if new_et <: eltype(pa)
+    new_data = if new_et <: et
         # We can reuse the existing data array.
         pa.data
     else
@@ -1150,7 +1151,13 @@ function _map_recursive!!(func, pa::PartialArray, vn)
     end
     @inbounds for i in CartesianIndices(pa.mask)
         if pa.mask[i]
-            new_vn = IndexLens(Tuple(i)) ∘ vn
+            val = pa.data[i]
+            # The first two checks on the below line are just a performance optimisation:
+            # They may short circuit at compile time.
+            is_alb =
+                (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock
+            ind = is_alb ? val.inds : Tuple(i)
+            new_vn = IndexLens(ind) ∘ vn
             new_data[i] = _map_recursive!!(func, pa.data[i], new_vn)
         end
     end
@@ -1212,6 +1219,16 @@ Apply `func` to elements of `vnt`, in place if possible.
 """
 map_values!!(func, vnt::VarNamedTuple) = map_pairs!!(pair -> func(pair.second), vnt)
 
+"""
+    mapreduce(f, op, vnt::VarNamedTuple; init)
+
+Apply `f` to all elements of `vnt`, and reduce the results using `op`, starting from `init`.
+
+`init` is a keyword argument to conform to the usual `mapreduce` interface in Base, but it
+is not optional.
+
+`f` and `op` should accept pairs of `VarName` and value. 
+""" function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) if init === nothing throw( @@ -1223,8 +1240,8 @@ function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) return _mapreduce_recursive(f, op, vnt, init) end -_mapreduce_recursive(f, op, x, init) = op(init, f(x)) -_mapreduce_recursive(f, op, pa::ArrayLikeBlock, init) = op(init, f(pa.block)) +_mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) +_mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) @generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, init @@ -1237,25 +1254,68 @@ _mapreduce_recursive(f, op, pa::ArrayLikeBlock, init) = op(init, f(pa.block)) end, ) for name in Names - push!(exs, :(result = _mapreduce_recursive(f, op, vnt.data.$name, result))) + push!( + exs, + :( + result = _mapreduce_recursive( + f, op, vnt.data.$name, VarName{$(QuoteNode(name))}(), result + ) + ), + ) + end + push!(exs, :(return result)) + return Expr(:block, exs...) +end + +@generated function _mapreduce_recursive( + f, op, vnt::VarNamedTuple{Names}, vn, init +) where {Names} + exs = Expr[] + push!( + exs, + quote + result = init + end, + ) + for name in Names + push!( + exs, + :( + result = _mapreduce_recursive( + f, + op, + vnt.data.$name, + AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn), + result, + ) + ), + ) end push!(exs, :(return result)) return Expr(:block, exs...) end -function _mapreduce_recursive(f, op, pa::PartialArray, init) +function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) result = init + et = eltype(pa) + albs_seen = Set{ArrayLikeBlock}() - @inbounds for i in eachindex(pa.mask) + @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] val = @inbounds pa.data[i] - if val isa ArrayLikeBlock + # The first two checks on the below line are just a performance optimisation: + # They may short circuit at compile time. 
+ is_alb = + (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock + if is_alb if val in albs_seen continue end push!(albs_seen, val) end - result = _mapreduce_recursive(f, op, pa.data[i], result) + ind = is_alb ? val.inds : Tuple(i) + new_vn = IndexLens(ind) ∘ vn + result = _mapreduce_recursive(f, op, pa.data[i], new_vn, result) end end return result diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 13436d39a..fe0417f2b 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -875,9 +875,21 @@ Base.size(st::SizedThing) = st.size f_pair(pair) = f_val(pair.second) - reduction = mapreduce(identity, vcat, vnt; init=Any[]) - @test reduction == vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") - reduction = mapreduce(f_val, vcat, vnt; init=Any[]) + val_reduction = mapreduce(pair -> pair.second, vcat, vnt; init=Any[]) + @test val_reduction == + vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") + key_reduction = mapreduce(pair -> pair.first, vcat, vnt; init=Any[]) + @test key_reduction == vcat( + @varname(a), + @varname(b[1]), + @varname(b[2]), + @varname(c.d), + @varname(e.f[3].g.h[2].i), + @varname(e.f[3].g.h[2].j), + @varname(y.z[3, 2:3, 3, 2:3, 4]), + @varname(w[4][3][2, 1]), + ) + reduction = mapreduce(f_pair, vcat, vnt; init=Any[]) @test reduction == vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") From 37f4adfb66f70c2d4b3226c1714d4aa15043c89f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 8 Jan 2026 19:08:54 +0000 Subject: [PATCH 09/56] Implement keys and values using mapreduce --- src/varnamedtuple.jl | 145 +++++++------------------------------------ 1 file changed, 21 insertions(+), 124 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 4b91fbb12..d165ca3a5 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -378,6 +378,21 @@ function BangBang.empty!!(pa::PartialArray) return pa end +# Length could be defined as a special 
case of mapreduce, but it's harder to keep it type +# stable that way: If the element type is abstract, we end up calling _mapreduce_recursive +# on an abstract type, which makes the type of the cumulant Any. +function Base.length(pa::PartialArray) + len = 0 + @inbounds for i in eachindex(pa.mask) + if !pa.mask[i] + continue + end + val = pa.data[i] + len += val isa VarNamedTuple || val isa PartialArray ? length(val) : 1 + end + return len +end + """ _concretise_eltype!!(pa::PartialArray) @@ -745,83 +760,6 @@ function _merge_recursive(pa1::PartialArray, pa2::PartialArray) end end -function Base.keys(pa::PartialArray) - # TODO(mhauru) Should this rather be Union{}[]? It would make this very type unstable - # and cause more allocations, but would result in more concrete element types. Same - # question for Base.keys on VNT and Base.values. - ks = Any[] - alb_inds_seen = Set{Tuple}() - for ind in CartesianIndices(pa.mask) - @inbounds if !pa.mask[ind] - continue - end - lens = IndexLens(Tuple(ind)) - val = getindex(pa.data, lens.indices...) - if val isa VarNamedTuple - subkeys = keys(val) - for vn in subkeys - sublens = _varname_to_lens(vn) - push!(ks, _compose_no_identity(sublens, lens)) - end - elseif val isa PartialArray - subkeys = keys(val) - for sublens in subkeys - push!(ks, _compose_no_identity(sublens, lens)) - end - elseif val isa ArrayLikeBlock - if !(val.inds in alb_inds_seen) - push!(ks, IndexLens(Tuple(val.inds))) - push!(alb_inds_seen, val.inds) - end - else - push!(ks, lens) - end - end - return ks -end - -function Base.values(pa::PartialArray) - vs = Any[] - albs_seen = Set{ArrayLikeBlock}() - for ind in CartesianIndices(pa.mask) - @inbounds if !pa.mask[ind] - continue - end - val = getindex(pa.data, ind) - if val isa VarNamedTuple || val isa PartialArray - subvalues = values(val) - vs = push!!(vs, subvalues...) 
- elseif val isa ArrayLikeBlock - if !(val in albs_seen) - vs = push!!(vs, val.block) - push!(albs_seen, val) - end - else - vs = push!!(vs, val) - end - end - return vs -end - -function Base.length(pa::PartialArray) - len = 0 - for ind in CartesianIndices(pa.mask) - @inbounds if !pa.mask[ind] - continue - end - val = getindex(pa.data, ind) - if val isa VarNamedTuple || val isa PartialArray - len += length(val) - else - # Note we don't need to special case here for ArrayLikeBlocks. That's because - # we treat every index pointing to the same ArrayLikeBlock as contributing to - # the length. - len += 1 - end - end - return len -end - """ _dense_array(pa::PartialArray) @@ -1152,11 +1090,7 @@ function _map_recursive!!(func, pa::PartialArray, vn) @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] val = pa.data[i] - # The first two checks on the below line are just a performance optimisation: - # They may short circuit at compile time. - is_alb = - (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock - ind = is_alb ? val.inds : Tuple(i) + ind = val isa ArrayLikeBlock ? val.inds : Tuple(i) new_vn = IndexLens(ind) ∘ vn new_data[i] = _map_recursive!!(func, pa.data[i], new_vn) end @@ -1303,10 +1237,7 @@ function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] val = @inbounds pa.data[i] - # The first two checks on the below line are just a performance optimisation: - # They may short circuit at compile time. 
- is_alb = - (et <: ArrayLikeBlock || ArrayLikeBlock <: et) && val isa ArrayLikeBlock + is_alb = val isa ArrayLikeBlock if is_alb if val in albs_seen continue @@ -1321,47 +1252,13 @@ function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) return result end -function Base.keys(vnt::VarNamedTuple) - result = VarName[] - for sym in keys(vnt.data) - subdata = vnt.data[sym] - if subdata isa VarNamedTuple - subkeys = keys(subdata) - append!(result, [AbstractPPL.prefix(sk, VarName{sym}()) for sk in subkeys]) - elseif subdata isa PartialArray - subkeys = keys(subdata) - append!(result, [VarName{sym}(lens) for lens in subkeys]) - else - push!(result, VarName{sym}()) - end - end - return result -end - -function Base.values(vnt::VarNamedTuple) - # TODO(mhauru) Same comments as for keys for type stability and Any vs Union{} - result = Any[] - for sym in keys(vnt.data) - subdata = vnt.data[sym] - if subdata isa VarNamedTuple || subdata isa PartialArray - subvalues = values(subdata) - append!(result, subvalues) - else - push!(result, subdata) - end - end - return result -end +Base.keys(vnt::VarNamedTuple) = mapreduce(first, push!, vnt; init=VarName[]) +Base.values(vnt::VarNamedTuple) = mapreduce(pair -> pair.second, push!, vnt; init=Any[]) function Base.length(vnt::VarNamedTuple) len = 0 - for sym in keys(vnt.data) - subdata = vnt.data[sym] - if subdata isa VarNamedTuple || subdata isa PartialArray - len += length(subdata) - else - len += 1 - end + for subdata in vnt.data + len += subdata isa VarNamedTuple || subdata isa PartialArray ? 
length(subdata) : 1 end return len end From fc29cc66cc013701f1ed472f8b9e0cbcb845455d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 11:23:58 +0000 Subject: [PATCH 10/56] Add more VNT constructors --- src/varnamedtuple.jl | 48 +++++++++++++++++++++++++++++++++++++++++-- test/varnamedtuple.jl | 34 ++++++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index d165ca3a5..bd33397d7 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -820,8 +820,25 @@ A `NamedTuple`-like structure with `VarName` keys. `VarNamedTuple` is a data structure for storing arbitrary data, keyed by `VarName`s, in an efficient and type stable manner. It is mainly used through `getindex`, `setindex!!`, and -`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Anther notable methods -is `merge`, which recursively merges two `VarNamedTuple`s. +`haskey`, all of which accept `VarName`s and only `VarName`s as keys. Other notable methods +are `merge` and `subset`. + +`VarNamedTuple` has an ordering to its elements, and two `VarNamedTuple`s with the same keys +and values but in different orders are considered different for equality and hashing. +Iterations such as `keys` and `values` respect this ordering. The ordering is dependent on +the order in which elements were inserted into the `VarNamedTuple`, though isn't always +equal to it. More specifically + +* Any new keys that have a joint parent `VarName` with an existing key are inserted after + that key. For instance, if one first inserts, in order, `@varname(a.x)`, `@varname(b)`, + and `@varname(a.y)`, the resulting order will be + `(@varname(a.x), @varname(a.y), @varname(b))`. +* `IndexLens` keys`, like `@varname(a[3])` or `@varname(b[2,3,4:5])`, are always iterated + in the same order an `Array` with the same indices would be iterated. 
For instance,
+  if one first inserts, in order, `@varname(a[2])`, `@varname(b)`, and `@varname(a[1])`,
+  the resulting order will be `(@varname(a[1]), @varname(a[2]), @varname(b))`.
+
+Otherwise insertion order is respected.
 
 The there are two major limitations to indexing by VarNamedTuples:
 
@@ -844,10 +861,37 @@ related to `VarName`s with `IndexLens` components.
 """
 struct VarNamedTuple{Names,Values}
     data::NamedTuple{Names,Values}
+
+    function VarNamedTuple(data::NamedTuple{Names,Values}) where {Names,Values}
+        return new{Names,Values}(data)
+    end
 end
 
 VarNamedTuple(; kwargs...) = VarNamedTuple((; kwargs...))
 
+"""
+    VarNamedTuple(d)
+    VarNamedTuple(nt::NamedTuple)
+
+Create a `VarNamedTuple` from a collection or a `NamedTuple`.
+
+Any collection `d` is assumed to be an iterable of key-value pairs, where the keys are
+`VarName`s. This could be an `AbstractDict`, a vector of `Pair`s or `Tuple`s, etc. The
+only exception is `NamedTuple`s, for which the `Symbol` keys are converted to `VarName`s.
+
+Note that `VarNamedTuple` has an ordering to its elements, and two `VarNamedTuple`s with the
+same keys and values but in different orders are considered different. If `d` does not
+guarantee an iteration order, then the order of the elements in the resulting
+`VarNamedTuple` is undefined. 
+""" +function VarNamedTuple(d) + vnt = VarNamedTuple() + for (k, v) in d + vnt = setindex!!(vnt, v, k) + end + return vnt +end + Base.:(==)(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = vnt1.data == vnt2.data Base.isequal(vnt1::VarNamedTuple, vnt2::VarNamedTuple) = isequal(vnt1.data, vnt2.data) Base.hash(vnt::VarNamedTuple, h::UInt) = hash(vnt.data, h) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index fe0417f2b..5efd9fe49 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -1,6 +1,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics +using OrderedCollections: OrderedDict using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple using DynamicPPL.VarNamedTuples: @@ -19,6 +20,7 @@ function test_invariants(vnt::VarNamedTuple) # These will be needed repeatedly. vnt_keys = keys(vnt) vnt_values = values(vnt) + # Check that for all keys in vnt, haskey is true, and resetting the value is a no-op. for k in vnt_keys @test haskey(vnt, k) @@ -34,6 +36,7 @@ function test_invariants(vnt::VarNamedTuple) @test isequal(vnt, vnt2) @test hash(vnt) == hash(vnt2) end + # Check that the printed representation can be parsed back to an equal VarNamedTuple. # The below eval test is a bit fragile: If any elements in vnt don't respect the same # reconstructability-from-repr property, this will fail. Likewise if any element uses @@ -44,27 +47,33 @@ function test_invariants(vnt::VarNamedTuple) @test equality === true || equality === missing @test isequal(vnt, vnt3) @test hash(vnt) == hash(vnt3) + # Check that merge with an empty VarNamedTuple is a no-op. @test isequal(merge(vnt, VarNamedTuple()), vnt) @test isequal(merge(VarNamedTuple(), vnt), vnt) + # Check that the VNT can be constructed back from its keys and values. 
vnt4 = VarNamedTuple() for (k, v) in zip(vnt_keys, vnt_values) vnt4 = setindex!!(vnt4, v, k) end @test isequal(vnt, vnt4) + # Check that vnt isempty only if it has no keys was_empty = isempty(vnt) - @test was_empty == isempty(vnt_keys) - @test was_empty == isempty(vnt_values) + @test isequal(was_empty, isempty(vnt_keys)) + @test isequal(was_empty, isempty(vnt_values)) + # Check that vnt can be emptied @test empty(vnt) === VarNamedTuple() emptied_vnt = empty!!(copy(vnt)) @test isempty(emptied_vnt) @test isempty(keys(emptied_vnt)) @test isempty(values(emptied_vnt)) + # Check that the copy protected the original vnt from being modified. @test isempty(vnt) == was_empty + # Check that map is a no-op when using identity functions. @test isequal(map_pairs!!(pair -> pair.second, copy(vnt)), vnt) @test isequal(map_values!!(identity, copy(vnt)), vnt) @@ -84,12 +93,33 @@ Base.size(st::SizedThing) = st.size vnt1 = setindex!!(vnt1, [1, 2, 3], @varname(b)) vnt1 = setindex!!(vnt1, "a", @varname(c.d.e)) test_invariants(vnt1) + vnt2 = VarNamedTuple(; a=1.0, b=[1, 2, 3], c=VarNamedTuple(; d=VarNamedTuple(; e="a")) ) test_invariants(vnt2) @test vnt1 == vnt2 + vnt3 = VarNamedTuple((; + a=1.0, b=[1, 2, 3], c=VarNamedTuple((; d=VarNamedTuple((; e="a")))) + )) + test_invariants(vnt3) + @test vnt1 == vnt3 + + vnt4 = VarNamedTuple( + OrderedDict( + @varname(a) => 1.0, @varname(b) => [1, 2, 3], @varname(c.d.e) => "a" + ), + ) + test_invariants(vnt4) + @test vnt1 == vnt4 + + vnt5 = VarNamedTuple(( + (@varname(a), 1.0), (@varname(b), [1, 2, 3]), (@varname(c.d.e), "a") + )) + test_invariants(vnt5) + @test vnt1 == vnt5 + pa1 = PartialArray{Float64,1}() pa1 = setindex!!(pa1, 1.0, 16) pa2 = PartialArray{Float64,1}(; min_size=(16,)) From c6d067720823792297d92242423c8f7e17c527e9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 11:24:39 +0000 Subject: [PATCH 11/56] Add VNT subset --- src/varnamedtuple.jl | 30 +++++++++++++++++++++++++++++- src/vntvarinfo.jl | 11 +++++++++++ 
test/varnamedtuple.jl | 42 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 2 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index bd33397d7..cc7648447 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -6,7 +6,7 @@ using AbstractPPL: AbstractPPL using Distributions: Distributions, Distribution using BangBang using Accessors -using ..DynamicPPL: _compose_no_identity +using ..DynamicPPL: DynamicPPL, _compose_no_identity export VarNamedTuple, map_pairs!!, map_values!!, apply!! @@ -1060,6 +1060,31 @@ Base.merge(x1::VarNamedTuple, x2::VarNamedTuple) = _merge_recursive(x1, x2) return Expr(:block, exs...) end +""" + subset(vnt::VarNamedTuple, vns) + +Create a new `VarNamedTuple` containing only the variables subsumed by ones in `vns`. +""" +function DynamicPPL.subset(vnt::VarNamedTuple, vns) + # TODO(mhauru) This could be done more efficiently by generating the code directly, + # because we could short-circuit: For instance, if `vns` contains `a`, we could + # directly include the whole subtree under `a`, without checking each individual + # variable under it. + return mapfoldl( + identity, + function (init, pair) + name, value = pair + return if any(vn -> subsumes(vn, name), vns) + setindex!!(init, value, name) + else + init + end + end, + vnt; + init=VarNamedTuple(), + ) +end + """ apply!!(func, vnt::VarNamedTuple, name::VarName) @@ -1218,6 +1243,9 @@ function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) return _mapreduce_recursive(f, op, vnt, init) end +# Our mapreduce is always left-associative. 
+Base.mapfoldl(f, op, vnt::VarNamedTuple; init=nothing) = mapreduce(f, op, vnt; init=init) + _mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) _mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 184fbd201..ad698d169 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -245,3 +245,14 @@ function unflatten(vi::VNTVarInfo, vec::AbstractVector) end return VNTVarInfo(new_values, vi.accs) end + +function subset(varinfo::VNTVarInfo, vns) + new_values = subset(varinfo.values, vns) + return VNTVarInfo(new_values, map(copy, getaccs(varinfo))) +end + +function Base.merge(varinfo_left::VNTVarInfo, varinfo_right::VNTVarInfo) + new_values = merge(varinfo_left.values, varinfo_right.values) + new_accs = map(copy, getaccs(varinfo_right)) + return VNTVarInfo(new_values, new_accs) +end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 5efd9fe49..655b8e9e5 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -3,7 +3,7 @@ module VarNamedTupleTests using Combinatorics: Combinatorics using OrderedCollections: OrderedDict using Test: @inferred, @test, @test_throws, @testset -using DynamicPPL: DynamicPPL, @varname, VarNamedTuple +using DynamicPPL: DynamicPPL, @varname, VarNamedTuple, subset using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map_pairs!!, map_values!!, apply!! using AbstractPPL: VarName, concretize, prefix @@ -77,6 +77,10 @@ function test_invariants(vnt::VarNamedTuple) # Check that map is a no-op when using identity functions. @test isequal(map_pairs!!(pair -> pair.second, copy(vnt)), vnt) @test isequal(map_values!!(identity, copy(vnt)), vnt) + + # Check that subsetting works as expected. + @test isequal(subset(vnt, vnt_keys), vnt) + @test isequal(subset(vnt, VarName[]), VarNamedTuple()) end """ A type that has a size but is not an Array. 
Used in ArrayLikeBlock tests.""" @@ -470,6 +474,42 @@ Base.size(st::SizedThing) = st.size test_invariants(vnt2) end + @testset "subset" begin + vnt = VarNamedTuple() + vnt = setindex!!(vnt, 1.0, @varname(a)) + vnt = setindex!!(vnt, [1, 2, 3], @varname(b)) + vnt = setindex!!(vnt, [10], @varname(c.x.y)) + vnt = setindex!!(vnt, :1, @varname(d[1])) + vnt = setindex!!(vnt, :2, @varname(d[2])) + vnt = setindex!!(vnt, :3, @varname(d[3])) + vnt = setindex!!(vnt, 2.0, @varname(e.f[3, 3].g.h[2, 4, 1].i)) + vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14])) + test_invariants(vnt) + + @test subset(vnt, VarName[]) == VarNamedTuple() + @test subset(vnt, (@varname(z),)) == VarNamedTuple() + @test subset(vnt, (@varname(d[4]),)) == VarNamedTuple() + # TODO(mhauru) Not sure what to do about the below. AbstractPPL considers d[1,1] to + # subsume d[1], but that breaks my idea of how VNT subset should work. + @test subset(vnt, (@varname(d[1, 1]),)) == VarNamedTuple() broken = true + @test subset(vnt, [@varname(a)]) == VarNamedTuple(; a=1.0) + @test subset(vnt, [@varname(b), @varname(d[1])]) == + VarNamedTuple((@varname(b) => [1, 2, 3], @varname(d[1]) => :1)) + @test subset(vnt, [@varname(d[2:3])]) == + VarNamedTuple((@varname(d[2]) => :2, @varname(d[3]) => :3)) + @test subset(vnt, [@varname(d)]) == VarNamedTuple(( + @varname(d[1]) => :1, @varname(d[2]) => :2, @varname(d[3]) => :3 + )) + @test subset(vnt, [@varname(c.x.y)]) == VarNamedTuple((@varname(c.x.y) => [10],)) + @test subset(vnt, [@varname(c)]) == VarNamedTuple((@varname(c.x.y) => [10],)) + @test subset(vnt, [@varname(e.f[3, 3].g.h[2, 4, 1].i)]) == + VarNamedTuple((@varname(e.f[3, 3].g.h[2, 4, 1].i) => 2.0,)) + @test subset(vnt, [@varname(p[2, 1][2:4, 5:5, 11:14])]) == + VarNamedTuple((@varname(p[2, 1][2:4, 5:5, 11:14]) => SizedThing((3, 1, 4)),)) + # Cutting the last range a bit short should mean that nothing is returned. 
+ @test subset(vnt, [@varname(p[2, 1][2:4, 5:5, 11:13])]) == VarNamedTuple() + end + @testset "keys and values" begin vnt = VarNamedTuple() @test @inferred(keys(vnt)) == VarName[] From c18258cfbc717452ffdc550943c9f0c26c85a5be Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 15:45:18 +0000 Subject: [PATCH 12/56] Make _compose_no_identity handle typed_identity too --- src/utils.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index ed9f3aa13..11261334c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -951,6 +951,8 @@ Return `typeof(x)` stripped of its type parameters. """ basetypeof(x::T) where {T} = Base.typename(T).wrapper +const MaybeTypedIdentity = Union{typeof(typed_identity),typeof(identity)} + # TODO(mhauru) Might add another specialisation to _compose_no_identity, where if # ReshapeTransforms are composed with each other or with a an UnwrapSingeltonTransform, only # the latter one would be kept. @@ -963,6 +965,6 @@ This helps avoid trivial cases of `ComposedFunction` that would cause unnecessar conflicts. 
""" _compose_no_identity(f, g) = f ∘ g -_compose_no_identity(::typeof(identity), g) = g -_compose_no_identity(f, ::typeof(identity)) = f -_compose_no_identity(::typeof(identity), ::typeof(identity)) = identity +_compose_no_identity(::MaybeTypedIdentity, g) = g +_compose_no_identity(f, ::MaybeTypedIdentity) = f +_compose_no_identity(::MaybeTypedIdentity, ::MaybeTypedIdentity) = typed_identity From b91e6ff2f31c0cef2f22e9ae7e283512ceabb20a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 16:00:40 +0000 Subject: [PATCH 13/56] Myriad improvements to VNTVarInfo, overhaul varinfo.jl tests to use VNTVarInfo only --- src/DynamicPPL.jl | 2 +- src/test_utils/varinfo.jl | 42 +---- src/threadsafe.jl | 7 +- src/vntvarinfo.jl | 195 +++++++++++++++------ test/varinfo.jl | 351 +++++++++++--------------------------- 5 files changed, 249 insertions(+), 348 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5b831e100..7c1f53081 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -185,7 +185,7 @@ abstract type AbstractVarInfo <: AbstractModelTrace end # Necessary forward declarations include("utils.jl") include("varnamedtuple.jl") -using .VarNamedTuples: VarNamedTuple, map!!, apply!! +using .VarNamedTuples: VarNamedTuple, map_pairs!!, map_values!!, apply!! include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 79b92ce13..25f4fd04f 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -33,40 +33,14 @@ of the varinfo instances. 
function setup_varinfos( model::Model, example_values::NamedTuple, varnames; include_threadsafe::Bool=false ) - # # VarInfo - # vi_untyped_metadata = DynamicPPL.untyped_varinfo(model) - # vi_untyped_vnv = DynamicPPL.untyped_vector_varinfo(model) - # vi_typed_metadata = DynamicPPL.typed_varinfo(model) - # vi_typed_vnv = DynamicPPL.typed_vector_varinfo(model) - - # # SimpleVarInfo - # svi_typed = SimpleVarInfo(example_values) - # svi_untyped = SimpleVarInfo(OrderedDict{VarName,Any}()) - # svi_vnv = SimpleVarInfo(DynamicPPL.VarNamedVector()) - - # varinfos = map(( - # vi_untyped_metadata, - # vi_untyped_vnv, - # vi_typed_metadata, - # vi_typed_vnv, - # svi_typed, - # svi_untyped, - # svi_vnv, - # )) do vi - # # Set them all to the same values and evaluate logp. - # vi = update_values!!(vi, example_values, varnames) - # last(DynamicPPL.evaluate!!(model, vi)) - # end - # - varinfos = map((DynamicPPL.typed_varinfo(model),)) do vi - # Set them all to the same values and evaluate logp. - vi = update_values!!(vi, example_values, varnames) - last(DynamicPPL.evaluate!!(model, vi)) + vi = DynamicPPL.VarInfo(model) + vi = update_values!!(vi, example_values, varnames) + last(DynamicPPL.evaluate!!(model, vi)) + + varinfos = if include_threadsafe + (vi, DynamicPPL.ThreadSafeVarInfo(deepcopy(vi))) + else + (vi,) end - - if include_threadsafe - varinfos = (varinfos..., map(DynamicPPL.ThreadSafeVarInfo ∘ deepcopy, varinfos)...) 
- end - return varinfos end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c7ab106a2..44b4da316 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -36,6 +36,9 @@ function getacc(vi::ThreadSafeVarInfo, accname::Val) return foldl(combine, other_accs; init=main_acc) end +function Base.copy(vi::ThreadSafeVarInfo) + return ThreadSafeVarInfo(copy(vi.varinfo), deepcopy(vi.accs_by_thread)) +end hasacc(vi::ThreadSafeVarInfo, accname::Val) = hasacc(vi.varinfo, accname) acckeys(vi::ThreadSafeVarInfo) = acckeys(vi.varinfo) @@ -195,8 +198,8 @@ end getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) -function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector) - return Accessors.@set vi.varinfo = unflatten(vi.varinfo, x) +function unflatten!!(vi::ThreadSafeVarInfo, x::AbstractVector) + return Accessors.@set vi.varinfo = unflatten!!(vi.varinfo, x) end function subset(varinfo::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index ad698d169..1ae9bc9d8 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -31,6 +31,8 @@ setaccs!!(vi::VNTVarInfo, accs::AccumulatorTuple) = VNTVarInfo(vi.values, accs) transformation(::VNTVarInfo) = DynamicTransformation() +Base.copy(vi::VNTVarInfo) = VNTVarInfo(copy(vi.values), copy(getaccs(vi))) + Base.haskey(vi::VNTVarInfo, vn::VarName) = haskey(vi.values, vn) Base.length(vi::VNTVarInfo) = length(vi.values) @@ -40,25 +42,37 @@ function Base.getindex(vi::VNTVarInfo, vn::VarName) return tv.transform(tv.val) end +function Base.getindex(vi::VNTVarInfo, vn::VarName, dist::Distribution) + val = getindex_internal(vi, vn) + return from_maybe_linked_internal(vi, vn, dist, val) +end + Base.isempty(vi::VNTVarInfo) = isempty(vi.values) +Base.empty(vi::VNTVarInfo) = VNTVarInfo(empty(vi.values), map(reset, vi.accs)) +BangBang.empty!!(vi::VNTVarInfo) = VNTVarInfo(empty!!(vi.values), map(reset, vi.accs)) -# TODO(mhauru) This should be called 
setindex_internal!!, but that's not the current
-# convention.
-function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName)
+function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName)
     old_tv = getindex(vi.values, vn)
     new_tv = TransformedValue(val, old_tv.linked, old_tv.transform)
     new_values = setindex!!(vi.values, new_tv, vn)
     return VNTVarInfo(new_values, vi.accs)
 end
 
+BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) = push!!(vi, vn, val)
+
 # TODO(mhauru) The arguments are in the wrong order, but this is the current convetion.
 function BangBang.push!!(vi::VNTVarInfo, vn::VarName, val, transform=typed_identity)
+    # TODO(mhauru) We should move away from having all values vectorised by default.
+    # That messes with our use of unflatten though, so will require some thought.
+    transform = _compose_no_identity(transform, from_vec_transform(val))
+    val = to_vec_transform(val)(val)
     new_tv = TransformedValue(val, false, transform)
     new_values = setindex!!(vi.values, new_tv, vn)
     return VNTVarInfo(new_values, vi.accs)
 end
 
 Base.keys(vi::VNTVarInfo) = keys(vi.values)
+Base.values(vi::VNTVarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[])
 
 function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName)
     old_tv = getindex(vi.values, vn)
@@ -68,7 +82,7 @@ function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName)
 end
 
 function set_transformed!!(vi::VNTVarInfo, linked::Bool)
-    new_values = map!!(vi.values) do tv
+    new_values = map_values!!(vi.values) do tv
         TransformedValue(tv.val, linked, tv.transform)
     end
     return VNTVarInfo(new_values, vi.accs)
@@ -79,6 +93,8 @@ function getindex_internal(vi::VNTVarInfo, vn::VarName)
     return tv.val
 end
 
+# TODO(mhauru) This is mimicking old behaviour, but is now wrong: The internal
+# representation does not have to be a Vector. 
getindex_internal(vi::VNTVarInfo, ::Colon) = values_as(vi, Vector) function is_transformed(vi::VNTVarInfo, vn::VarName) @@ -86,15 +102,16 @@ function is_transformed(vi::VNTVarInfo, vn::VarName) return tv.linked end -# TODO(mhauru) Other VarInfos have something like this. Do we need it? -# function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) -# return from_vec_transform(dist) -# end - -function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) - return getindex(vi.values, vn).transform +# TODO(mhauru) Other VarInfos have something like this. Do we need it? Or should we use the +# below version? +function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) + return from_vec_transform(dist) end +# function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) +# return getindex(vi.values, vn).transform +# end + function from_linked_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) return from_linked_vec_transform(dist) end @@ -113,14 +130,17 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = vi.values - for vn in vns - new_values = apply!!(new_values, vn) do tv - dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac - return new_tv + new_values = map_pairs!!(new_values) do pair + vn, tv = pair + if !any(x -> subsumes(x, vn), vns) + # Not one of the target variables. 
+ return tv end + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -135,15 +155,13 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = vi.values - vns = keys(vi) - for vn in vns - new_values = apply!!(new_values, vn) do tv - dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac - return new_tv - end + new_values = map_pairs!!(new_values) do pair + vn, tv = pair + dist = getindex(dists, vn) + transform = from_linked_vec_transform(dist) + new_tv, logjac = change_transform(tv, transform, true) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -155,13 +173,17 @@ end function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) cumulative_logjac = zero(LogProbType) new_values = vi.values - for vn in vns - new_values = apply!!(new_values, vn) do tv - transform = typed_identity - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac - return new_tv + new_values = map_pairs!!(new_values) do pair + vn, tv = pair + if !any(x -> subsumes(x, vn), vns) + # Not one of the target variables. + return tv end + current_val = tv.transform(tv.val) + transform = from_vec_transform(current_val) + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -175,14 +197,12 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) # map!!, but it doesn't have access to the VarName. 
cumulative_logjac = zero(LogProbType) new_values = vi.values - vns = keys(vi) - for vn in vns - new_values = apply!!(new_values, vn) do tv - transform = typed_identity - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac - return new_tv - end + new_values = map_values!!(new_values) do tv + current_val = tv.transform(tv.val) + transform = from_vec_transform(current_val) + new_tv, logjac = change_transform(tv, transform, false) + cumulative_logjac += logjac + return new_tv end vi = VNTVarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) @@ -191,10 +211,54 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) return vi end +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) +end + +function link!!( + t::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vns::VarNameTuple, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) +end + +function invlink!!( + t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, + # and so we need to specialize to avoid this. + return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) +end + +function invlink!!( + ::DynamicTransformation, + vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vns::VarNameTuple, + model::Model, +) + # By default this will simply evaluate the model with `DynamicTransformationContext`, and so + # we need to specialize to avoid this. 
+ return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) +end + # TODO(mhauru) I don't think this should return the internal values, but that's the current # convention. function values_as(vi::VNTVarInfo, ::Type{Vector}) - return mapreduce(tv -> tovec(tv.val), vcat, vi.values; init=Union{}[]) + return mapfoldl(pair -> tovec(pair.second.val), vcat, vi.values; init=Union{}[]) +end + +function values_as(vi::VNTVarInfo, ::Type{T}) where {T<:AbstractDict} + return mapfoldl(identity, function (cumulant, pair) + vn, tv = pair + val = tv.transform(tv.val) + return setindex!!(cumulant, val, vn) + end, vi.values; init=T()) end # TODO(mhauru) These two are now redundant, just conforming to the old interface @@ -225,22 +289,41 @@ function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitF return untyped_varinfo(Random.default_rng(), model, init_strategy) end -function unflatten(vi::VNTVarInfo, vec::AbstractVector) - index = 1 - new_values = map!!(vi.values) do tv - # TODO(mhauru) This is quite crude, assuming that the value stored currently is - # an AbstractArray of some kind that has a size, and that reshape makes sense here. - # I may fix this later, but I'm also tempted to just get rid of unflatten entirely. - # This works for now for making most tests pass. +""" + VectorChunkIterator{T<:AbstractVector} + +A tiny struct for getting chunks of a vector sequentially. + +The only function provided is `get_next_chunk!`, which takes a length and returns +a view into the next chunk of that length, updating the internal index. 
+""" +mutable struct VectorChunkIterator{T<:AbstractVector} + vec::T + index::Int +end + +function get_next_chunk!(vci::VectorChunkIterator, len::Int) + i = vci.index + chunk = @view vci.vec[i:(i + len - 1)] + vci.index += len + return chunk +end + +function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) + # You may wonder, why have a whole struct for this, rather than just an index variable + # that the mapping function would close over. I wonder too. But for some reason type + # inference fails on such an index variable, turning it into a Core.Box. + vci = VectorChunkIterator(vec, 1) + new_values = map_values!!(vi.values) do tv old_val = tv.val - len = length(old_val) - new_val = reshape(vec[index:(index + len - 1)], size(old_val)) - # If the old_val was a scalar then new_val is a 0-dimensional array. - # Convert it to a scalar. - if !(old_val isa AbstractArray) && length(old_val) == 1 - new_val = new_val[1] + if !(old_val isa AbstractVector) + error( + "Can not unflatten a VarInfo for which existing values are not vectors:" * + " Got value of type $(typeof(old_val)).", + ) end - index += len + len = length(old_val) + new_val = get_next_chunk!(vci, len) return TransformedValue(new_val, tv.linked, tv.transform) end return VNTVarInfo(new_values, vi.accs) diff --git a/test/varinfo.jl b/test/varinfo.jl index a7948cc32..0a8e58eef 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -14,124 +14,59 @@ function check_varinfo_keys(varinfo, vns) end end -""" -Return the value of `vn` in `vi`. If one doesn't exist, sample and set it. 
-""" -function randr(vi::DynamicPPL.VarInfo, vn::VarName, dist::Distribution) - if !haskey(vi, vn) - r = rand(dist) - push!!(vi, vn, r, dist) - r - else - vi[vn] - end -end - @testset "varinfo.jl" begin - @testset "VarInfo with NT of Metadata" begin - @model gdemo(x, y) = begin - s ~ InverseGamma(2, 3) - m ~ truncated(Normal(0.0, sqrt(s)), 0.0, 2.0) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) - end - model = gdemo(1.0, 2.0) - - _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform()) - tvi = DynamicPPL.typed_varinfo(vi) - - meta = vi.metadata - for f in fieldnames(typeof(tvi.metadata)) - fmeta = getfield(tvi.metadata, f) - for vn in fmeta.vns - @test tvi[vn] == vi[vn] - ind = meta.idcs[vn] - tind = fmeta.idcs[vn] - @test meta.dists[ind] == fmeta.dists[tind] - @test meta.is_transformed[ind] == fmeta.is_transformed[tind] - range = meta.ranges[ind] - trange = fmeta.ranges[tind] - @test all(meta.vals[range] .== fmeta.vals[trange]) - end - end - end - @testset "Base" begin # Test Base functions: # in, keys, haskey, isempty, push!!, empty!!, # getindex, setindex!, getproperty, setproperty! - function test_base(vi_original) - vi = deepcopy(vi_original) - @test getlogjoint(vi) == 0 - @test isempty(vi[:]) - - vn = @varname x - dist = Normal(0, 1) - r = rand(dist) - - @test isempty(vi) - @test !haskey(vi, vn) - @test !(vn in keys(vi)) - vi = push!!(vi, vn, r, dist) - @test !isempty(vi) - @test haskey(vi, vn) - @test vn in keys(vi) - - @test length(vi[vn]) == 1 - @test vi[vn] == r - @test vi[:] == [r] - vi = DynamicPPL.setindex!!(vi, 2 * r, vn) - @test vi[vn] == 2 * r - @test vi[:] == [2 * r] - - # TODO(mhauru) Implement these functions for other VarInfo types too. 
- if vi isa DynamicPPL.UntypedVectorVarInfo - delete!(vi, vn) - @test isempty(vi) - vi = push!!(vi, vn, r, dist) - end - - vi = empty!!(vi) - @test isempty(vi) - vi = push!!(vi, vn, r, dist) - @test !isempty(vi) - end - - test_base(VarInfo()) - test_base(DynamicPPL.typed_varinfo(VarInfo())) - test_base(SimpleVarInfo()) - test_base(SimpleVarInfo(OrderedDict{VarName,Any}())) - test_base(SimpleVarInfo(DynamicPPL.VarNamedVector())) + vi = VarInfo() + @test getlogjoint(vi) == 0 + @test isempty(vi[:]) + + vn = @varname x + r = rand() + + @test isempty(vi) + @test !haskey(vi, vn) + @test !(vn in keys(vi)) + vi = push!!(vi, vn, r) + @test !isempty(vi) + @test haskey(vi, vn) + @test vn in keys(vi) + + @test length(vi[vn]) == 1 + @test vi[vn] == r + @test vi[:] == [r] + vi = DynamicPPL.setindex!!(vi, 2 * r, vn) + @test vi[vn] == 2 * r + @test vi[:] == [2 * r] + + vi = empty!!(vi) + @test isempty(vi) + vi = push!!(vi, vn, r) + @test !isempty(vi) end @testset "get/set/acclogp" begin - function test_varinfo_logp!(vi) - @test DynamicPPL.getlogjoint(vi) === 0.0 - vi = DynamicPPL.setlogprior!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 1.0 - @test DynamicPPL.getloglikelihood(vi) === 0.0 - @test DynamicPPL.getlogjoint(vi) === 1.0 - vi = DynamicPPL.acclogprior!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 2.0 - @test DynamicPPL.getloglikelihood(vi) === 0.0 - @test DynamicPPL.getlogjoint(vi) === 2.0 - vi = DynamicPPL.setloglikelihood!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 2.0 - @test DynamicPPL.getloglikelihood(vi) === 1.0 - @test DynamicPPL.getlogjoint(vi) === 3.0 - vi = DynamicPPL.accloglikelihood!!(vi, 1.0) - @test DynamicPPL.getlogprior(vi) === 2.0 - @test DynamicPPL.getloglikelihood(vi) === 2.0 - @test DynamicPPL.getlogjoint(vi) === 4.0 - end - vi = VarInfo() - test_varinfo_logp!(vi) - test_varinfo_logp!(DynamicPPL.typed_varinfo(vi)) - test_varinfo_logp!(SimpleVarInfo()) - test_varinfo_logp!(SimpleVarInfo(OrderedDict())) - 
test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) + @test DynamicPPL.getlogjoint(vi) === 0.0 + vi = DynamicPPL.setlogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 1.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 1.0 + vi = DynamicPPL.acclogprior!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 0.0 + @test DynamicPPL.getlogjoint(vi) === 2.0 + vi = DynamicPPL.setloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 1.0 + @test DynamicPPL.getlogjoint(vi) === 3.0 + vi = DynamicPPL.accloglikelihood!!(vi, 1.0) + @test DynamicPPL.getlogprior(vi) === 2.0 + @test DynamicPPL.getloglikelihood(vi) === 2.0 + @test DynamicPPL.getlogjoint(vi) === 4.0 end @testset "logp accumulators" begin @@ -150,7 +85,7 @@ end lp_d = logpdf(Normal(), values.d) m = demo() | (; c=values.c, d=values.d) - vi = DynamicPPL.unflatten(VarInfo(m), collect(values)) + vi = DynamicPPL.unflatten!!(VarInfo(m), collect(values)) vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) @test getlogprior(vi) == lp_a + lp_b @@ -284,39 +219,23 @@ end end @testset "is_transformed flag" begin - # Test is_transformed and set_transformed!! - function test_varinfo!(vi) - vn_x = @varname x - dist = Normal(0, 1) - r = rand(dist) - - push!!(vi, vn_x, r, dist) + vi = VarInfo() + vn_x = @varname x + r = rand() - # is_transformed is set by default - @test !is_transformed(vi, vn_x) + vi = push!!(vi, vn_x, r) - vi = set_transformed!!(vi, true, vn_x) - @test is_transformed(vi, vn_x) + # is_transformed is unset by default + @test !is_transformed(vi, vn_x) - vi = set_transformed!!(vi, false, vn_x) - @test !is_transformed(vi, vn_x) - end - vi = VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(DynamicPPL.typed_varinfo(vi))) - end + vi = set_transformed!!(vi, true, vn_x) + @test is_transformed(vi, vn_x) - @testset "push!! 
to VarInfo with NT of Metadata" begin - vn_x = @varname x - vn_y = @varname y - untyped_vi = VarInfo() - untyped_vi = push!!(untyped_vi, vn_x, 1.0, Normal(0, 1)) - typed_vi = DynamicPPL.typed_varinfo(untyped_vi) - typed_vi = push!!(typed_vi, vn_y, 2.0, Normal(0, 1)) - @test typed_vi[vn_x] == 1.0 - @test typed_vi[vn_y] == 2.0 + vi = set_transformed!!(vi, false, vn_x) + @test !is_transformed(vi, vn_x) end + # TODO(mhauru) Move this to a different file. @testset "returned on MCMCChains.Chains" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS chain = make_chain_from_prior(model, 10) @@ -354,39 +273,23 @@ end # change the VarInfo object. # TODO(penelopeysm): Move this to InitFromUniform tests rather than here. vi = VarInfo() - meta = vi.metadata _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) - @test all(x -> !is_transformed(vi, x), meta.vns) + vals = values(vi) + + all_transformed(vi) = mapreduce(p -> p.second.linked, &, vi.values; init=true) + any_transformed(vi) = mapreduce(p -> p.second.linked, |, vi.values; init=false) + + @test !any_transformed(vi) # Check that linking and invlinking set the `is_transformed` flag accordingly - v = copy(meta.vals) vi = link!!(vi, model) - @test all(x -> is_transformed(vi, x), meta.vns) + @test all_transformed(vi) vi = invlink!!(vi, model) - @test all(x -> !is_transformed(vi, x), meta.vns) - @test meta.vals ≈ v atol = 1e-10 - - # Check that linking and invlinking preserves the values - vi = DynamicPPL.typed_varinfo(vi) - meta = vi.metadata - v_s = copy(meta.s.vals) - v_m = copy(meta.m.vals) - v_x = copy(meta.x.vals) - v_y = copy(meta.y.vals) - - @test all(x -> !is_transformed(vi, x), meta.s.vns) - @test all(x -> !is_transformed(vi, x), meta.m.vns) - vi = link!!(vi, model) - @test all(x -> is_transformed(vi, x), meta.s.vns) - @test all(x -> is_transformed(vi, x), meta.m.vns) - vi = invlink!!(vi, model) - @test all(x -> !is_transformed(vi, x), meta.s.vns) - @test all(x -> !is_transformed(vi, x), 
meta.m.vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 + @test !any_transformed(vi) + @test values(vi) ≈ vals atol = 1e-10 # Transform only one variable - all_vns = vcat(meta.s.vns, meta.m.vns, meta.x.vns, meta.y.vns) + all_vns = keys(vi) for vn in [ @varname(s), @varname(m), @@ -400,14 +303,11 @@ end @test !isempty(target_vns) @test !isempty(other_vns) vi = link!!(vi, (vn,), model) - @test all(x -> is_transformed(vi, x), target_vns) - @test all(x -> !is_transformed(vi, x), other_vns) + @test all_transformed(subset(vi, target_vns)) + @test !any_transformed(subset(vi, other_vns)) vi = invlink!!(vi, (vn,), model) - @test all(x -> !is_transformed(vi, x), all_vns) - @test meta.s.vals ≈ v_s atol = 1e-10 - @test meta.m.vals ≈ v_m atol = 1e-10 - @test meta.x.vals ≈ v_x atol = 1e-10 - @test meta.y.vals ≈ v_y atol = 1e-10 + @test !any_transformed(vi) + @test values(vi) ≈ vals atol = 1e-10 end end @@ -417,46 +317,17 @@ end vn = @varname(x) dist = truncated(Normal(); lower=0) - function test_linked_varinfo(model, vi) - # vn and dist are taken from the containing scope - vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test is_transformed(vi, vn) - @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) - @test getloglikelihood(vi) == 0.0 - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) - @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) - end - - ### `VarInfo` - # Need to run once since we can't specify that we want to _sample_ - # in the unconstrained space for `VarInfo` without having `vn` - # present in the `varinfo`. 
- - ## `untyped_varinfo` - vi = DynamicPPL.untyped_varinfo(model) - vi = DynamicPPL.set_transformed!!(vi, true, vn) - test_linked_varinfo(model, vi) - - ## `typed_varinfo` - vi = DynamicPPL.typed_varinfo(model) + vi = DynamicPPL.VarInfo(model) vi = DynamicPPL.set_transformed!!(vi, true, vn) - test_linked_varinfo(model, vi) - - ### `SimpleVarInfo` - ## `SimpleVarInfo{<:NamedTuple}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true) - test_linked_varinfo(model, vi) - - ## `SimpleVarInfo{<:VarNamedVector}` - vi = DynamicPPL.set_transformed!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - test_linked_varinfo(model, vi) + vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test is_transformed(vi, vn) + @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getloglikelihood(vi) == 0.0 + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) end @testset "values_as" begin @@ -471,32 +342,16 @@ end @testset "$(short_varinfo_name(vi))" for vi in varinfos # Just making sure. DynamicPPL.TestUtils.test_values(vi, example_values, vns) - - @testset "NamedTuple" begin - vals = values_as(vi, NamedTuple) - for vn in vns - if haskey(vals, Symbol(vn)) - # Assumed to be of form `(var"m[1]" = 1.0, ...)`. - @test getindex(vals, Symbol(vn)) == getindex(vi, vn) - else - # Assumed to be of form `(m = [1.0, ...], ...)`. - @test get(vals, vn) == getindex(vi, vn) - end - end + vals = values_as(vi, OrderedDict) + # All varnames in `vns` should be subsumed by one of `keys(vals)`. 
+ @test all(vns) do vn + any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) end - - @testset "OrderedDict" begin - vals = values_as(vi, OrderedDict) - # All varnames in `vns` should be subsumed by one of `keys(vals)`. - @test all(vns) do vn - any(DynamicPPL.subsumes(vn_left, vn) for vn_left in keys(vals)) - end - # Iterate over `keys(vals)` because we might have scenarios such as - # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is - # the varname present in `vns`, not `@varname(m)`. - for vn in keys(vals) - @test getindex(vals, vn) == getindex(vi, vn) - end + # Iterate over `keys(vals)` because we might have scenarios such as + # `vals = OrderedDict(@varname(m) => [1.0])` but `@varname(m[1])` is + # the varname present in `vns`, not `@varname(m)`. + for vn in keys(vals) + @test getindex(vals, vn) == getindex(vi, vn) end end end @@ -546,8 +401,8 @@ end @test DynamicPPL.is_transformed(varinfo_linked, vn) end @test length(varinfo[:]) > length(varinfo_linked[:]) - varinfo_linked_unflattened = DynamicPPL.unflatten( - varinfo_linked, varinfo_linked[:] + varinfo_linked_unflattened = DynamicPPL.unflatten!!( + copy(varinfo_linked), varinfo_linked[:] ) @test length(varinfo_linked_unflattened[:]) == length(varinfo_linked[:]) @@ -591,13 +446,7 @@ end model, (; x=1.0), (@varname(x),); include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Skip the inconcrete `SimpleVarInfo` types, since checking for type - # stability for them doesn't make much sense anyway. - if varinfo isa SimpleVarInfo{<:AbstractDict} || - varinfo isa DynamicPPL.ThreadSafeVarInfo{<:SimpleVarInfo{<:AbstractDict}} - continue - end - @inferred DynamicPPL.unflatten(varinfo, varinfo[:]) + @inferred DynamicPPL.unflatten!!(varinfo, varinfo[:]) end end @@ -718,15 +567,6 @@ end @test varinfo_subset[:] == ground_truth end end - - # For certain varinfos we should have errors. 
- # `SimpleVarInfo{<:NamedTuple}` can only handle varnames with `identity`. - varinfo = varinfos[findfirst(Base.Fix2(isa, SimpleVarInfo{<:NamedTuple}), varinfos)] - @testset "$(short_varinfo_name(varinfo)): failure cases" begin - @test_throws ArgumentError subset( - varinfo, [@varname(s), @varname(m), @varname(x[1])] - ) - end end @testset "merge" begin @@ -817,9 +657,9 @@ end @testset "merge different dimensions" begin vn = @varname(x) vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0, Normal()) + vi_single = push!!(vi_single, vn, 1.0) vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6], Dirichlet(2, 1.0)) + vi_double = push!!(vi_double, vn, [0.5, 0.6]) @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] @test merge(vi_double, vi_single)[vn] == 1.0 end @@ -830,8 +670,9 @@ end n = length(varinfo[:]) # `Bool`. - @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(true, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten!!(varinfo, fill(true, n))) isa + typeof(float(1)) # `Int`. - @test getlogjoint(DynamicPPL.unflatten(varinfo, fill(1, n))) isa typeof(float(1)) + @test getlogjoint(DynamicPPL.unflatten!!(varinfo, fill(1, n))) isa typeof(float(1)) end end From 8018f451a6207a4908f1886ab928c25751dd712b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 17:09:22 +0000 Subject: [PATCH 14/56] Fix a couple of ArrayLikeBlock bugs --- src/varnamedtuple.jl | 11 +++++++---- test/varnamedtuple.jl | 12 +++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 5d83afc5e..a0640fb3b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -525,6 +525,8 @@ function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) wh end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) + # The original, non-bare inds is needed later for ArrayLikeBlock checks. 
+ orig_inds = inds inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) if !(checkbounds(Bool, pa.mask, inds...) && all(@inbounds(getindex(pa.mask, inds...)))) @@ -561,7 +563,7 @@ function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) if !(first_elem isa ArrayLikeBlock) throw(err) end - if inds != first_elem.inds + if orig_inds != first_elem.inds # The requested indices do not match the ones used to set the value. throw(err) end @@ -655,6 +657,7 @@ function _needs_arraylikeblock(value, inds::Vararg{INDEX_TYPES}) end function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) + orig_inds = inds inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) pa = if checkbounds(Bool, pa.mask, inds...) @@ -679,7 +682,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) # some notion of size, and that size matches the indices that are being set. In this # case we wrap the value in an ArrayLikeBlock, and set all the individual indices # to point to that. - alb = ArrayLikeBlock(value, inds) + alb = ArrayLikeBlock(value, orig_inds) new_data = setindex!!(new_data, fill(alb, inds_size...), inds...) else new_data = setindex!!(new_data, value, inds...) @@ -1180,11 +1183,11 @@ end function _map_recursive!!(func, alb::ArrayLikeBlock, vn) new_block = _map_recursive!!(func, alb.block, vn) - if size(new_block) != size(alb.block) + if vnt_size(new_block) != vnt_size(alb.block) throw( DimensionMismatch( "map_pairs!! can't change the size of an ArrayLikeBlock. 
Tried to change " * - "from $(size(alb.block)) to $(size(new_block)).", + "from $(vnt_size(alb.block)) to $(vnt_size(new_block)).", ), ) end diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 655b8e9e5..18737f3d7 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -6,7 +6,7 @@ using Test: @inferred, @test, @test_throws, @testset using DynamicPPL: DynamicPPL, @varname, VarNamedTuple, subset using DynamicPPL.VarNamedTuples: PartialArray, ArrayLikeBlock, map_pairs!!, map_values!!, apply!! -using AbstractPPL: VarName, concretize, prefix +using AbstractPPL: AbstractPPL, VarName, concretize, prefix using BangBang: setindex!!, empty!! """ @@ -305,6 +305,16 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, vn)) == x test_invariants(vnt) + vnt = VarNamedTuple() + vnt = @inferred(setindex!!(vnt, SizedThing((3,)), vn)) + @test haskey(vnt, vn) + @test vn in keys(vnt) + @test @inferred(getindex(vnt, vn)) == SizedThing((3,)) + # TODO(mhauru) The below test_invariants fails because AbstractPPL's ConretizedSlice + # objects don't respect the eval(Meta.parse(repr(...))) == ... property. + # test_invariants(vnt) + + vnt = VarNamedTuple() y = fill("a", (3, 2, 4)) x = y[:, 2, :] a = (; b=[nothing, nothing, (; c=(; d=reshape(y, (1, 3, 2, 4, 1))))]) From 1cbcda7b551528feaf24e5041614732edc2038d0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:02:15 +0000 Subject: [PATCH 15/56] Fix PartialArray map bug --- src/varnamedtuple.jl | 19 ++++++++- test/varnamedtuple.jl | 90 ++++++++++++++++++++++++++++++++----------- 2 files changed, 84 insertions(+), 25 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index a0640fb3b..4c739e681 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1169,12 +1169,27 @@ function _map_recursive!!(func, pa::PartialArray, vn) # We need to allocate a new data array. 
similar(pa.data, new_et) end + # Keep a dictionary of already-seen ArrayLikeBlocks to avoid redundant computations. + # This matters not only for performance, but also for correctness, because + # _map_recursive!! may mutate the value, and we don't want to mutate it multiple times. + albs_seen = Dict{ArrayLikeBlock,ArrayLikeBlock}() @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] val = pa.data[i] - ind = val isa ArrayLikeBlock ? val.inds : Tuple(i) + is_alb = val isa ArrayLikeBlock + if is_alb + if val in keys(albs_seen) + new_data[i] = albs_seen[val] + continue + end + end + ind = is_alb ? val.inds : Tuple(i) new_vn = IndexLens(ind) ∘ vn - new_data[i] = _map_recursive!!(func, pa.data[i], new_vn) + new_val = _map_recursive!!(func, pa.data[i], new_vn) + new_data[i] = new_val + if is_alb + albs_seen[val] = new_val + end end end # The above type inference may be overly conservative, so we concretise the eltype. diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 18737f3d7..a18885be7 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -10,13 +10,15 @@ using AbstractPPL: AbstractPPL, VarName, concretize, prefix using BangBang: setindex!!, empty!! """ - test_invariants(vnt::VarNamedTuple) + test_invariants(vnt::VarNamedTuple; skip=()) Test properties that should hold for all VarNamedTuples. Uses @test for all the tests. Intended to be called inside a @testset. + +`skip` is a tuple of symbols indicating which tests are to be skipped. """ -function test_invariants(vnt::VarNamedTuple) +function test_invariants(vnt::VarNamedTuple; skip=()) # These will be needed repeatedly. vnt_keys = keys(vnt) vnt_values = values(vnt) @@ -41,12 +43,14 @@ function test_invariants(vnt::VarNamedTuple) # The below eval test is a bit fragile: If any elements in vnt don't respect the same # reconstructability-from-repr property, this will fail. Likewise if any element uses # in its repr print out types that are not in scope in this module, it will fail. 
- vnt3 = eval(Meta.parse(repr(vnt))) - equality = (vnt == vnt3) - # The value may be `missing` if vnt itself has values that are missing. - @test equality === true || equality === missing - @test isequal(vnt, vnt3) - @test hash(vnt) == hash(vnt3) + if !(:parseeval in skip) + vnt3 = eval(Meta.parse(repr(vnt))) + equality = (vnt == vnt3) + # The value may be `missing` if vnt itself has values that are missing. + @test equality === true || equality === missing + @test isequal(vnt, vnt3) + @test hash(vnt) == hash(vnt3) + end # Check that merge with an empty VarNamedTuple is a no-op. @test isequal(merge(vnt, VarNamedTuple()), vnt) @@ -310,9 +314,9 @@ Base.size(st::SizedThing) = st.size @test haskey(vnt, vn) @test vn in keys(vnt) @test @inferred(getindex(vnt, vn)) == SizedThing((3,)) - # TODO(mhauru) The below test_invariants fails because AbstractPPL's ConretizedSlice + # TODO(mhauru) The below skip is needed because AbstractPPL's ConretizedSlice # objects don't respect the eval(Meta.parse(repr(...))) == ... property. - # test_invariants(vnt) + test_invariants(vnt; skip=(:parseeval,)) vnt = VarNamedTuple() y = fill("a", (3, 2, 4)) @@ -927,15 +931,21 @@ Base.size(st::SizedThing) = st.size vnt = @inferred( setindex!!(vnt, SizedThing((2, 2)), @varname(y.z[3, 2:3, 3, 2:3, 4])) ) + concretized_vn = concretize(@varname(v[:]), [0, 0]) + vnt = @inferred(setindex!!(vnt, SizedThing((2,)), concretized_vn)) vnt = @inferred(setindex!!(vnt, "", @varname(w[4][3][2, 1]))) - test_invariants(vnt) + # TODO(mhauru) The below skip is needed because AbstractPPL's ConretizedSlice + # objects don't respect the eval(Meta.parse(repr(...))) == ... property. 
+ test_invariants(vnt; skip=(:parseeval,)) struct AnotherSizedThing{T<:Tuple} size::T end Base.size(st::AnotherSizedThing) = st.size + call_counter = 0 function f_val(val) + call_counter += 1 if val isa Int return val + 10 elseif val isa AbstractVector{Int} @@ -956,8 +966,9 @@ Base.size(st::SizedThing) = st.size f_pair(pair) = f_val(pair.second) val_reduction = mapreduce(pair -> pair.second, vcat, vnt; init=Any[]) - @test val_reduction == - vcat(Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), "") + @test val_reduction == vcat( + Any[], 1, [2, 2], [3.0], "a", 5.0, SizedThing((2, 2)), SizedThing((2,)), "" + ) key_reduction = mapreduce(pair -> pair.first, vcat, vnt; init=Any[]) @test key_reduction == vcat( @varname(a), @@ -967,18 +978,35 @@ Base.size(st::SizedThing) = st.size @varname(e.f[3].g.h[2].i), @varname(e.f[3].g.h[2].j), @varname(y.z[3, 2:3, 3, 2:3, 4]), + concretized_vn, @varname(w[4][3][2, 1]), ) + + call_counter = 0 reduction = mapreduce(f_pair, vcat, vnt; init=Any[]) - @test reduction == - vcat(Any[], 11, [12, 12], [2.0], "ab", 6.0, AnotherSizedThing((2, 2)), "b") + @test reduction == vcat( + Any[], + 11, + [12, 12], + [2.0], + "ab", + 6.0, + AnotherSizedThing((2, 2)), + AnotherSizedThing((2,)), + "b", + ) + # Check that f_pair gets called exactly once per element. + @test call_counter == length(keys(vnt)) # TODO(mhauru) This should hopefully be type stable, but fails to be so because of # some complex VarNames being too much for constant propagation. See comment in # src/varnamedtuple.jl for more. + call_counter = 0 vnt_mapped = map_pairs!!(f_pair, copy(vnt)) + # Check that f_pair gets called exactly once per element. 
+ @test call_counter == length(keys(vnt)) @test vnt_mapped == map_values!!(f_val, copy(vnt)) - test_invariants(vnt_mapped) + test_invariants(vnt_mapped; skip=(:parseeval,)) @test @inferred(getindex(vnt_mapped, @varname(a))) == 11 @test @inferred(getindex(vnt_mapped, @varname(b[1:2]))) == [12, 12] @test @inferred(getindex(vnt_mapped, @varname(c.d))) == [2.0] @@ -986,29 +1014,38 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt_mapped, @varname(e.f[3].g.h[2].j))) == 6.0 @test @inferred(getindex(vnt_mapped, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) + @test @inferred(getindex(vnt_mapped, concretized_vn)) == AnotherSizedThing((2,)) @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" + call_counter = 0 vnt_applied = @inferred(apply!!(f_val, vnt, @varname(a))) - test_invariants(vnt_applied) + @test call_counter == 1 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [2, 2] vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(b[1:2]))) - test_invariants(vnt_applied) + # Unlike map_pairs!!, apply!! operates on the whole value at once, rather than + # element-wise, so this is only one more call. 
+ @test call_counter == 2 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 @test @inferred(getindex(vnt_applied, @varname(b[1:2]))) == [12, 12] vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(c.d))) - test_invariants(vnt_applied) + @test call_counter == 3 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) - test_invariants(vnt_applied) + @test call_counter == 4 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))) - test_invariants(vnt_applied) + @test call_counter == 5 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 6.0 @@ -1016,12 +1053,19 @@ Base.size(st::SizedThing) = st.size # know at compile time that this sets the only one, thus allowing the element type # to be AnotherSizedThing. 
vnt_applied = apply!!(f_val, vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4])) - test_invariants(vnt_applied) + @test call_counter == 6 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == AnotherSizedThing((2, 2)) + vnt_applied = apply!!(f_val, vnt_applied, concretized_vn) + @test call_counter == 7 + test_invariants(vnt_applied; skip=(:parseeval,)) + @test @inferred(getindex(vnt_applied, concretized_vn)) == AnotherSizedThing((2,)) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(w[4][3][2, 1]))) - test_invariants(vnt_applied) + @test call_counter == 8 + test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(w[4][3][2, 1]))) == "b" # map a function that maps every key => value pair to key => key. From 573cd5afea464c2148ee5767011d676e4ab15093 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:04:36 +0000 Subject: [PATCH 16/56] In VNTVarInfo, handle variables with varying dimensions correctly --- src/contexts/init.jl | 2 +- src/threadsafe.jl | 6 ++++++ src/vntvarinfo.jl | 26 ++++++++++++++++---------- test/varinfo.jl | 20 -------------------- 4 files changed, 23 insertions(+), 31 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index f137e07d6..b118280d0 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -363,7 +363,7 @@ function tilde_assume!!( vi = setindex!!(vi, val_to_insert, vn) else vi = if vi isa VNTVarInfo - push!!(vi, vn, val_to_insert, inverse(transform)) + push!!(vi, vn, val_to_insert, inverse(transform), size(x)) else push!!(vi, vn, val_to_insert, dist) end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 44b4da316..f168eb7c1 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -71,6 +71,12 @@ function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distributi return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) end +function 
BangBang.push!!( + vi::ThreadSafeVarInfo, vn::VarName, r, transform=typed_identity, orig_size=size(r) +) + return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, transform, orig_size) +end + syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 1ae9bc9d8..756bf8e34 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -6,12 +6,15 @@ end # TODO(mhauru) Make this renaming permanent. const VarInfo = VNTVarInfo -struct TransformedValue{ValType,TransformType} +struct TransformedValue{ValType,TransformType,SizeType} val::ValType linked::Bool transform::TransformType + size::SizeType end +VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size + VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators()) function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) @@ -53,7 +56,7 @@ BangBang.empty!!(vi::VNTVarInfo) = VNTVarInfo(empty!!(vi.values), map(reset, vi. function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName) old_tv = getindex(vi.values, vn) - new_tv = TransformedValue(val, old_tv.linked, old_tv.transform) + new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) return VNTVarInfo(new_values, vi.accs) end @@ -61,12 +64,14 @@ end BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) = push!!(vi, vn, val) # TODO(mhauru) The arguments are in the wrong order, but this is the current convetion. -function BangBang.push!!(vi::VNTVarInfo, vn::VarName, val, transform=typed_identity) +function BangBang.push!!( + vi::VNTVarInfo, vn::VarName, val, transform=typed_identity, orig_size=size(val) +) # TODO(mhauru) We should move away from having all values vectorised by default. # That messes with our use of unflatten though, so will require some thought. 
transform = _compose_no_identity(transform, from_vec_transform(val)) val = to_vec_transform(val)(val) - new_tv = TransformedValue(val, false, transform) + new_tv = TransformedValue(val, false, transform, orig_size) new_values = setindex!!(vi.values, new_tv, vn) return VNTVarInfo(new_values, vi.accs) end @@ -76,14 +81,14 @@ Base.values(vi::VNTVarInfo) = mapreduce(p -> p.second.val, push!, vi.values; ini function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) - new_tv = TransformedValue(old_tv.val, linked, old_tv.transform) + new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) return VNTVarInfo(new_values, vi.accs) end function set_transformed!!(vi::VNTVarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv - TransformedValue(tv.val, linked, tv.transform) + TransformedValue(tv.val, linked, tv.transform, tv.size) end return VNTVarInfo(new_values, vi.accs) end @@ -121,9 +126,11 @@ function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) end function change_transform(tv::TransformedValue, new_transform, linked) + # Note that the transform may change the size of `val`, but it doesn't change the + # tv.size, since that one tracks the original size of the value before any transforms. val_untransformed, logjac1 = with_logabsdet_jacobian(tv.transform, tv.val) val_new, logjac2 = with_logabsdet_jacobian(inverse(new_transform), val_untransformed) - return TransformedValue(val_new, linked, new_transform), logjac1 + logjac2 + return TransformedValue(val_new, linked, new_transform, tv.size), logjac1 + logjac2 end function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) @@ -154,8 +161,7 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) # map!!, but it doesn't have access to the VarName. 
dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_pairs!!(new_values) do pair + new_values = map_pairs!!(vi.values) do pair vn, tv = pair dist = getindex(dists, vn) transform = from_linked_vec_transform(dist) @@ -324,7 +330,7 @@ function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) end len = length(old_val) new_val = get_next_chunk!(vci, len) - return TransformedValue(new_val, tv.linked, tv.transform) + return TransformedValue(new_val, tv.linked, tv.transform, tv.size) end return VNTVarInfo(new_values, vi.accs) end diff --git a/test/varinfo.jl b/test/varinfo.jl index 0a8e58eef..0bea67402 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -369,26 +369,6 @@ end model, value_true, varnames; include_threadsafe=true ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: this is broken since we'll end up trying to set - # - # varinfo[@varname(x[4:5])] = [x[4],] - # - # upon linking (since `x[4:5]` will be projected onto a 1-dimensional - # space). In the case of `SimpleVarInfo{<:NamedTuple}`, this results in - # calling `setindex!!(varinfo.values, [x[4],], @varname(x[4:5]))`, which - # in turn attempts to call `setindex!(varinfo.values.x, [x[4],], 4:5)`, - # i.e. a vector of length 1 (`[x[4],]`) being assigned to 2 indices (`4:5`). - @test_broken false - continue - end - - if DynamicPPL.has_varnamedvector(varinfo) && mutating - # NOTE: Can't handle mutating `link!` and `invlink!` `VarNamedVector`. - @test_broken false - continue - end - # Evaluate the model once to update the logp of the varinfo. 
varinfo = last(DynamicPPL.evaluate!!(model, varinfo)) From c353cbc7ba9fa117bfd330f3d806831ce92be176 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:25:31 +0000 Subject: [PATCH 17/56] Fix two small bugs --- src/utils.jl | 4 ++++ src/varnamedtuple.jl | 11 +++++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 11261334c..0e03c5cdc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,7 @@ +# subset is defined here to avoid circular dependencies between files. Methods for it are +# defined in other files. +function subset end + # singleton for indicating if no default arguments are present struct NoDefault end const NO_DEFAULT = NoDefault() diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 4c739e681..2f7e38ca6 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -388,6 +388,13 @@ function BangBang.empty!!(pa::PartialArray) return pa end +# This is a tad hacky: We use _mapreduce_recursive which requires a prefix VarName. We give +# it the non-sense @varname(_), and then strip it away with the mapping function, returning +# only the optic. +function Base.keys(pa::PartialArray) + return _mapreduce_recursive(pair -> first(pair).optic, push!, pa, @varname(_), Any[]) +end + # Length could be defined as a special case of mapreduce, but it's harder to keep it type # stable that way: If the element type is abstract, we end up calling _mapreduce_recursive # on an abstract type, which makes the type of the cumulant Any. @@ -1500,8 +1507,8 @@ function AbstractPPL.hasvalue(vnt::VarNamedTuple, vn::VarName, dist::MV_DIST_TYP # Note that _getindex, rather than getindex, skips the need to denseify PartialArrays. val = _getindex(vnt, vn) if !(val isa VarNamedTuple || val isa PartialArray) - # There is _a_ value. Where it's the right kind, we do not know, but returning true - # is no worse than `hasvalue` returning true for e.g. UnivariateDistributions + # There is _a_ value. 
Whether it's the right kind, we do not know, but returning + # true is no worse than `hasvalue` returning true for e.g. UnivariateDistributions # whenever there is at least some value. return true end From a36bb150d2c44ab365514650be00d4ba005b9280 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:59:29 +0000 Subject: [PATCH 18/56] Allow nested PartialArrays with ArrayLikeBlocks --- src/varnamedtuple.jl | 15 +++++++++++++-- test/varnamedtuple.jl | 2 ++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 2f7e38ca6..37158442b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -110,6 +110,8 @@ Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`. By default, this falls back onto `Base.size`, but can be overloaded for custom types. This notion of type is used to determine whether a value can be set into a `PartialArray` as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details. + +A special return value of `Val(:pass)` indicates that the size check should be skipped. """ vnt_size(x) = size(x) @@ -294,6 +296,13 @@ end # The size of the .data field is an implementation detail. _internal_size(pa::PartialArray, args...) = size(pa.data, args...) +# Even though a PartialArray has no well-defined size, we still allow it to be used as an +# ArrayLikeBlock. This enables setting values for keys like @varname(x[1:3][1]), which will +# be stored as a PartialArray wrapped in an ArrayLikeBlock, stored in another PartialArray. +# Note that this bypasses _any_ size checks, so that e.g. @varname(x[1:3][1,15]) is also a +# valid key. +vnt_size(pa::PartialArray) = Val(:pass) + function Base.copy(pa::PartialArray) # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively # copy. 
@@ -677,7 +686,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) new_data = pa.data if _needs_arraylikeblock(value, inds...) inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) - if vnt_size(value) != inds_size + if vnt_size(value) !== Val(:pass) && vnt_size(value) != inds_size throw( DimensionMismatch( "Assigned value has size $(vnt_size(value)), which does not match " * @@ -1205,7 +1214,9 @@ end function _map_recursive!!(func, alb::ArrayLikeBlock, vn) new_block = _map_recursive!!(func, alb.block, vn) - if vnt_size(new_block) != vnt_size(alb.block) + sz_new = vnt_size(new_block) + sz_old = vnt_size(alb.block) + if sz_new !== Val(:pass) && sz_old !== Val(:pass) && sz_new != sz_old throw( DimensionMismatch( "map_pairs!! can't change the size of an ArrayLikeBlock. Tried to change " * diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index a18885be7..1937ea189 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -333,6 +333,8 @@ Base.size(st::SizedThing) = st.size vnt = VarNamedTuple() vnt = @inferred(setindex!!(vnt, 1, @varname(a[1][1]))) @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 + vnt = @inferred(setindex!!(vnt, 1, @varname(ab[1:2][1]))) + @test @inferred(getindex(vnt, @varname(a[1][1]))) == 1 vnt = @inferred(setindex!!(vnt, [1], @varname(b[1].c[1]))) @test @inferred(getindex(vnt, @varname(b[1].c[1]))) == [1] vnt = @inferred(setindex!!(vnt, [1], @varname(e[3, 2].f[2, 2][10, 10]))) From bf05554ff96d2e7e72b8a64c497c296af9f24851 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 18:59:59 +0000 Subject: [PATCH 19/56] Stop testing for NamedDist with unconcrete VarName --- test/compiler.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 5101bd602..e4a9a2474 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -341,21 +341,21 @@ module Issue537 end end @testset "user-defined variable name" begin @model 
f1() = x ~ NamedDist(Normal(), :y) - @model f2() = x ~ NamedDist(Normal(), @varname(y[2][:, 1])) + @model f2() = x ~ NamedDist(Normal(), @varname(y[2][5, 1])) @model f3() = x ~ NamedDist(Normal(), @varname(y[1])) vi1 = VarInfo(f1()) vi2 = VarInfo(f2()) vi3 = VarInfo(f3()) - @test haskey(vi1.metadata, :y) - @test first(Base.keys(vi1.metadata.y)) == @varname(y) - @test haskey(vi2.metadata, :y) - @test first(Base.keys(vi2.metadata.y)) == @varname(y[2][:, 1]) - @test haskey(vi3.metadata, :y) - @test first(Base.keys(vi3.metadata.y)) == @varname(y[1]) + @test haskey(vi1, @varname(y)) + @test first(Base.keys(vi1)) == @varname(y) + @test haskey(vi2, @varname(y[2][5, 1])) + @test first(Base.keys(vi2)) == @varname(y[2][5, 1]) + @test haskey(vi3, @varname(y[1])) + @test first(Base.keys(vi3)) == @varname(y[1]) # Conditioning f1_c = f1() | (y=1,) - f2_c = f2() | NamedTuple((Symbol(@varname(y[2][:, 1])) => 1,)) + f2_c = f2() | NamedTuple((Symbol(@varname(y[2][5, 1])) => 1,)) f3_c = f3() | NamedTuple((Symbol(@varname(y[1])) => 1,)) @test f1_c() == 1 # TODO(torfjelde): We need conditioning for `Dict`. From 7857eaee73c3854bde5ae30f1626f70c91691432 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 19:39:42 +0000 Subject: [PATCH 20/56] Misc bugfixes --- docs/src/api.md | 2 +- ext/DynamicPPLMarginalLogDensitiesExt.jl | 2 +- src/DynamicPPL.jl | 4 +- src/abstract_varinfo.jl | 10 ++--- src/contexts/init.jl | 3 +- src/logdensityfunction.jl | 23 ++++++++---- src/test_utils/ad.jl | 2 +- src/values_as_in_model.jl | 2 +- test/logdensityfunction.jl | 45 ++++++++++------------- test/model.jl | 47 +++++++----------------- 10 files changed, 61 insertions(+), 79 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index f687fd90a..084639c07 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -468,7 +468,7 @@ DynamicPPL.maybe_invlink_before_eval!! ```@docs Base.merge(::AbstractVarInfo) DynamicPPL.subset -DynamicPPL.unflatten +DynamicPPL.unflatten!! 
``` ### Evaluation Contexts diff --git a/ext/DynamicPPLMarginalLogDensitiesExt.jl b/ext/DynamicPPLMarginalLogDensitiesExt.jl index e28560872..8e53d8709 100644 --- a/ext/DynamicPPLMarginalLogDensitiesExt.jl +++ b/ext/DynamicPPLMarginalLogDensitiesExt.jl @@ -212,7 +212,7 @@ function DynamicPPL.VarInfo( if unmarginalized_params !== nothing full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params end - return DynamicPPL.unflatten(original_vi, full_params) + return DynamicPPL.unflatten!!(original_vi, full_params) end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d6248a4d0..84b8b2e68 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -196,14 +196,14 @@ include("model.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") -include("varnamedvector.jl") +# include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") # include("varinfo.jl") include("vntvarinfo.jl") -include("simple_varinfo.jl") +# include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 898b6caf9..ef1d92042 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -838,14 +838,14 @@ function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model ) # TODO(mhauru) This assumes that the user has defined the bijector using the same - # variable ordering as what `vi[:]` and `unflatten(vi, x)` use. This is a bad user + # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user # interface, and it's also dangerous for any AbstractVarInfo types that may not respect # a particular ordering, such as SimpleVarInfo{Dict}. b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) # Set parameters and add the logjac term. 
- vi = unflatten(vi, y) + vi = unflatten!!(vi, y) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, logjac) end @@ -910,7 +910,7 @@ function invlink!!( # Mildly confusing: we need to _add_ the logjac of the inverse transform, # because we are trying to remove the logjac of the forward transform # that was previously accumulated when linking. - vi = unflatten(vi, x) + vi = unflatten!!(vi, x) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, inv_logjac) end @@ -1013,11 +1013,11 @@ end # Utilities """ - unflatten(vi::AbstractVarInfo, x::AbstractVector) + unflatten!!(vi::AbstractVarInfo, x::AbstractVector) Return a new instance of `vi` with the values of `x` assigned to the variables. """ -function unflatten end +function unflatten!! end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index b118280d0..8899ba4d0 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -363,7 +363,8 @@ function tilde_assume!!( vi = setindex!!(vi, val_to_insert, vn) else vi = if vi isa VNTVarInfo - push!!(vi, vn, val_to_insert, inverse(transform), size(x)) + x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () + vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) else push!!(vi, vn, val_to_insert, dist) end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index adcb319c8..e7c83ecb4 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -311,13 +311,22 @@ This function returns a VarNamedTuple mapping all VarNames to their correspondin `RangeAndLinked`. """ function get_ranges_and_linked(vi::VNTVarInfo) - offset = 1 - vnt = map!!(vi.values) do tv - val = tv.val - range = offset:(offset + length(val) - 1) - offset += length(val) - RangeAndLinked(range, tv.linked, size(val)) - end + # TODO(mhauru) Check that the closure doesn't cause type instability here. 
+ vnt = VarNamedTuple() + vnt, _ = mapreduce( + identity, + function ((vnt, offset), pair) + vn, tv = pair + val = tv.val + range = offset:(offset + length(val) - 1) + offset += length(val) + ral = RangeAndLinked(range, tv.linked, size(val)) + vnt = setindex!!(vnt, ral, vn) + return vnt, offset + end, + vi.values; + init=(VarNamedTuple(), 1), + ) return vnt end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index a030b479e..6bcd9547e 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -242,7 +242,7 @@ Everything else is optional, and can be categorised into several groups: Finally, note that these only reflect the parameters used for _evaluating_ the gradient. If you also want to control the parameters used for _preparing_ the gradient, then you need to manually set these parameters in - the VarInfo object, for example using `vi = DynamicPPL.unflatten(vi, + the VarInfo object, for example using `vi = DynamicPPL.unflatten!!(vi, prep_params)`. You could then evaluate the gradient at a different set of parameters using the `params` keyword argument. diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 992cbdc8d..f7440d6ff 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -118,7 +118,7 @@ julia> # Perform computations in unconstrained space, e.g. changing the values o θ = [!varinfo[@varname(x)], rand(rng)]; julia> # Update the `VarInfo` with the new values. - varinfo_linked = DynamicPPL.unflatten(varinfo_linked, θ); + varinfo_linked = DynamicPPL.unflatten!!(varinfo_linked, θ); julia> # Determine the expected support of `y`. lb, ub = θ[1] == 1 ? 
(0, 1) : (11, 12) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 77ae0ccab..777b91ee4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -17,31 +17,24 @@ using Mooncake: Mooncake @testset "LogDensityFunction: Correctness" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - @testset "$varinfo_func" for varinfo_func in [ - # DynamicPPL.untyped_varinfo, - DynamicPPL.typed_varinfo, - # DynamicPPL.untyped_vector_varinfo, - # DynamicPPL.typed_vector_varinfo, - ] - unlinked_vi = varinfo_func(m) - @testset "$islinked" for islinked in (false, true) - vi = if islinked - DynamicPPL.link!!(unlinked_vi, m) - else - unlinked_vi - end - ranges = DynamicPPL.get_ranges_and_linked(vi) - params = [x for x in vi[:]] - # Iterate over all variables - for vn in keys(vi) - # Check that `getindex_internal` returns the same thing as using the ranges - # directly - range_with_linked = ranges[vn] - @test params[range_with_linked.range] == - DynamicPPL.tovec(DynamicPPL.getindex_internal(vi, vn)) - # Check that the link status is correct - @test range_with_linked.is_linked == islinked - end + @testset "$islinked" for islinked in (false, true) + unlinked_vi = DynamicPPL.VarInfo(m) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ranges = DynamicPPL.get_ranges_and_linked(vi) + params = [x for x in vi[:]] + # Iterate over all variables + for vn in keys(vi) + # Check that `getindex_internal` returns the same thing as using the ranges + # directly + range_with_linked = ranges[vn] + @test params[range_with_linked.range] == + DynamicPPL.tovec(DynamicPPL.getindex_internal(vi, vn)) + # Check that the link status is correct + @test range_with_linked.is_linked == islinked end end end @@ -104,8 +97,8 @@ end @testset "LogDensityFunction: Type stability" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - unlinked_vi = DynamicPPL.VarInfo(m) @testset "$islinked" for islinked in (false, true) + 
unlinked_vi = DynamicPPL.VarInfo(m) vi = if islinked DynamicPPL.link!!(unlinked_vi, m) else diff --git a/test/model.jl b/test/model.jl index 29b9650a5..05688c224 100644 --- a/test/model.jl +++ b/test/model.jl @@ -26,8 +26,7 @@ function innermost_distribution_type(d::Distributions.Product) end is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_type_stable_varinfo(varinfo::DynamicPPL.NTVarInfo) = true -is_type_stable_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +is_type_stable_varinfo(varinfo::DynamicPPL.VNTVarInfo) = true const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @@ -230,24 +229,13 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() for i in 1:10 # Sample with large variations. r_raw = randn(length(vi[:])) * 10 - vi = DynamicPPL.unflatten(vi, r_raw) + vi = DynamicPPL.unflatten!!(vi, r_raw) @test vi[@varname(m)] == r_raw[1] @test vi[@varname(x)] != r_raw[2] model(vi) end end - @testset "Dynamic constraints, VectorVarInfo" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - for i in 1:10 - for vi_constructor in - [DynamicPPL.typed_vector_varinfo, DynamicPPL.untyped_vector_varinfo] - vi = vi_constructor(model) - @test vi[@varname(x)] >= vi[@varname(m)] - end - end - end - @testset "rand" begin model = GDEMO_DEFAULT @@ -510,26 +498,17 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end model = product_dirichlet() - varinfos = [ - DynamicPPL.untyped_varinfo(model), - DynamicPPL.typed_varinfo(model), - DynamicPPL.typed_simple_varinfo(model), - DynamicPPL.untyped_simple_varinfo(model), - ] - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - logjoint = getlogjoint(varinfo) # unlinked space - varinfo_linked = DynamicPPL.link(varinfo, model) - varinfo_linked_result = last( - DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) - ) - # getlogjoint should return the same result as before it was linked - @test 
getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) - @test getlogjoint(varinfo_linked) ≈ logjoint - # getlogjoint_internal shouldn't - @test getlogjoint_internal(varinfo_linked) ≈ - getlogjoint_internal(varinfo_linked_result) - @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) - end + varinfo = DynamicPPL.VarInfo(model) + logjoint = getlogjoint(varinfo) # unlinked space + varinfo_linked = DynamicPPL.link(varinfo, model) + varinfo_linked_result = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked))) + # getlogjoint should return the same result as before it was linked + @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ logjoint + # getlogjoint_internal shouldn't + @test getlogjoint_internal(varinfo_linked) ≈ + getlogjoint_internal(varinfo_linked_result) + @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) end @testset "predict" begin From 16fe15056559a3cc1a2e783727f1a76b18c9d350 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Jan 2026 19:41:21 +0000 Subject: [PATCH 21/56] Stop running SVI and VNT tests --- test/runtests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index e0b42904c..e04b664fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -54,9 +54,9 @@ include("test_util.jl") include("accumulators.jl") include("compiler.jl") include("varnamedtuple.jl") - include("varnamedvector.jl") + # include("varnamedvector.jl") include("varinfo.jl") - include("simple_varinfo.jl") + # include("simple_varinfo.jl") include("model.jl") include("distribution_wrappers.jl") include("linking.jl") From 51a518f19a84dc60eb6976f0a13372ecbf835ab6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 09:43:41 +0000 Subject: [PATCH 22/56] Fix LDF bug --- src/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl 
b/src/logdensityfunction.jl index e7c83ecb4..44fdad5a8 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -320,7 +320,7 @@ function get_ranges_and_linked(vi::VNTVarInfo) val = tv.val range = offset:(offset + length(val) - 1) offset += length(val) - ral = RangeAndLinked(range, tv.linked, size(val)) + ral = RangeAndLinked(range, tv.linked, tv.size) vnt = setindex!!(vnt, ral, vn) return vnt, offset end, From 1950a9345e3c1740ecc17f27c94c5786f9c959e6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 12:41:17 +0000 Subject: [PATCH 23/56] Fix some bugs, simplify (inv)linking --- src/contexts/init.jl | 15 +++---- src/test_utils/varinfo.jl | 2 +- src/vntvarinfo.jl | 83 +++++++++++++-------------------------- 3 files changed, 33 insertions(+), 67 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 8899ba4d0..45d6356f1 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -308,7 +308,6 @@ end function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) - in_varinfo = haskey(vi, vn) val, transform = init(ctx.rng, vn, dist, ctx.strategy) x, inv_logjac = with_logabsdet_jacobian(transform, val) # Determine whether to insert a transformed value into the VarInfo. @@ -317,7 +316,7 @@ function tilde_assume!!( # check the rest of the VarInfo to see if other variables are linked. # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. - insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) + insert_transformed_value = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) val_to_insert, logjac = if insert_transformed_value # Calculate the forward logjac and sum them up. lt = link_transform(dist) @@ -359,15 +358,11 @@ function tilde_assume!!( end # Add the new value to the VarInfo. `push!!` errors if the value already # exists, hence the need for setindex!!. 
- if in_varinfo - vi = setindex!!(vi, val_to_insert, vn) + vi = if vi isa VNTVarInfo + x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () + vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) else - vi = if vi isa VNTVarInfo - x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () - vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) - else - push!!(vi, vn, val_to_insert, dist) - end + push!!(vi, vn, val_to_insert, dist) end # Neither of these set the `trans` flag so we have to do it manually if # necessary. diff --git a/src/test_utils/varinfo.jl b/src/test_utils/varinfo.jl index 25f4fd04f..bbfb0b662 100644 --- a/src/test_utils/varinfo.jl +++ b/src/test_utils/varinfo.jl @@ -35,7 +35,7 @@ function setup_varinfos( ) vi = DynamicPPL.VarInfo(model) vi = update_values!!(vi, example_values, varnames) - last(DynamicPPL.evaluate!!(model, vi)) + vi = last(DynamicPPL.evaluate!!(model, vi)) varinfos = if include_threadsafe (vi, DynamicPPL.ThreadSafeVarInfo(deepcopy(vi))) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 756bf8e34..a0392334c 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -125,28 +125,24 @@ function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) return getindex(vi.values, vn).transform end -function change_transform(tv::TransformedValue, new_transform, linked) - # Note that the transform may change the size of `val`, but it doesn't change the - # tv.size, since that one tracks the original size of the value before any transforms. 
- val_untransformed, logjac1 = with_logabsdet_jacobian(tv.transform, tv.val) - val_new, logjac2 = with_logabsdet_jacobian(inverse(new_transform), val_untransformed) - return TransformedValue(val_new, linked, new_transform, tv.size), logjac1 + logjac2 -end - function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_pairs!!(new_values) do pair + new_values = map_pairs!!(vi.values) do pair vn, tv = pair - if !any(x -> subsumes(x, vn), vns) + if vns !== nothing && !any(x -> subsumes(x, vn), vns) # Not one of the target variables. return tv end dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac + vec_transform = from_vec_transform(dist) + link_transform = from_linked_vec_transform(dist) + val_untransformed, logjac1 = with_logabsdet_jacobian(vec_transform, tv.val) + val_new, logjac2 = with_logabsdet_jacobian( + inverse(link_transform), val_untransformed + ) + new_tv = TransformedValue(val_new, true, link_transform, tv.size) + cumulative_logjac += logjac1 + logjac2 return new_tv end vi = VNTVarInfo(new_values, vi.accs) @@ -156,39 +152,29 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) return vi end -function link!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) - # TODO(mhauru) This is probably pretty inefficient. Do this better. Would like to use - # map!!, but it doesn't have access to the VarName. 
- dists = extract_priors(model, vi) - cumulative_logjac = zero(LogProbType) - new_values = map_pairs!!(vi.values) do pair - vn, tv = pair - dist = getindex(dists, vn) - transform = from_linked_vec_transform(dist) - new_tv, logjac = change_transform(tv, transform, true) - cumulative_logjac += logjac - return new_tv - end - vi = VNTVarInfo(new_values, vi.accs) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, cumulative_logjac) - end - return vi +function link!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) + return link!!(t, vi, nothing, model) end function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) + dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_pairs!!(new_values) do pair + new_values = map_pairs!!(vi.values) do pair vn, tv = pair - if !any(x -> subsumes(x, vn), vns) + if vns !== nothing && !any(x -> subsumes(x, vn), vns) # Not one of the target variables. return tv end - current_val = tv.transform(tv.val) - transform = from_vec_transform(current_val) - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac + current_val = tv.val + dist = getindex(dists, vn) + vec_transform = from_vec_transform(dist) + link_transform = from_linked_vec_transform(dist) + val_untransformed, logjac1 = with_logabsdet_jacobian(link_transform, current_val) + val_new, logjac2 = with_logabsdet_jacobian( + inverse(vec_transform), val_untransformed + ) + new_tv = TransformedValue(val_new, false, vec_transform, tv.size) + cumulative_logjac += logjac1 + logjac2 return new_tv end vi = VNTVarInfo(new_values, vi.accs) @@ -198,23 +184,8 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) return vi end -function invlink!!(::DynamicTransformation, vi::VNTVarInfo, model::Model) - # TODO(mhauru) This is probably pretty inefficient. Do this better. Would like to use - # map!!, but it doesn't have access to the VarName. 
- cumulative_logjac = zero(LogProbType) - new_values = vi.values - new_values = map_values!!(new_values) do tv - current_val = tv.transform(tv.val) - transform = from_vec_transform(current_val) - new_tv, logjac = change_transform(tv, transform, false) - cumulative_logjac += logjac - return new_tv - end - vi = VNTVarInfo(new_values, vi.accs) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, cumulative_logjac) - end - return vi +function invlink!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) + return invlink!!(t, vi, nothing, model) end function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model) From 051521a83b96ed537b4272acea0b6dff9c8843af Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 12:42:23 +0000 Subject: [PATCH 24/56] Fix some tests --- test/contexts.jl | 214 ++++++++++++++++++++------------------------ test/debug_utils.jl | 54 ----------- test/lkj.jl | 2 +- 3 files changed, 98 insertions(+), 172 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 9621013ac..24f6445f5 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -414,18 +414,6 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() end @testset "InitContext" begin - empty_varinfos = [ - ("untyped+metadata", VarInfo()), - ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), - ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), - ( - "typed+VNV", - DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), - ), - ("SVI+NamedTuple", SimpleVarInfo()), - ("Svi+Dict", SimpleVarInfo(OrderedDict{VarName,Any}())), - ] - @model function test_init_model() x ~ Normal() y ~ MvNormal(fill(x, 2), I) @@ -438,19 +426,17 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # Check that init!! can generate values that weren't there # previously. 
model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - this_vi = deepcopy(empty_vi) - _, vi = DynamicPPL.init!!(model, this_vi, strategy) - @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) - x, y = vi[@varname(x)], vi[@varname(y)] - @test x isa Real - @test y isa AbstractVector{<:Real} - @test length(y) == 2 - (; logprior, loglikelihood) = getlogp(vi) - @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == - logprior - @test logpdf(Normal(), 1.0) == loglikelihood - end + empty_vi = VarInfo() + this_vi = deepcopy(empty_vi) + _, vi = DynamicPPL.init!!(model, this_vi, strategy) + @test Set(keys(vi)) == Set([@varname(x), @varname(y)]) + x, y = vi[@varname(x)], vi[@varname(y)] + @test x isa Real + @test y isa AbstractVector{<:Real} + @test length(y) == 2 + (; logprior, loglikelihood) = getlogp(vi) + @test logpdf(Normal(), x) + logpdf(MvNormal(fill(x, 2), I), y) == logprior + @test logpdf(Normal(), 1.0) == loglikelihood end end @@ -458,40 +444,38 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "replacing old values: $(typeof(strategy))" begin # Check that init!! can overwrite values that were already there. 
model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - # start by generating some rubbish values - vi = deepcopy(empty_vi) - old_x, old_y = 100000.00, [300000.00, 500000.00] - push!!(vi, @varname(x), old_x, Normal()) - push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) - # then overwrite it - _, new_vi = DynamicPPL.init!!(model, vi, strategy) - new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] - # check that the values are (presumably) different - @test old_x != new_x - @test old_y != new_y - end + empty_vi = VarInfo() + # start by generating some rubbish values + vi = deepcopy(empty_vi) + old_x, old_y = 100000.00, [300000.00, 500000.00] + push!!(vi, @varname(x), old_x, Normal()) + push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + # then overwrite it + _, new_vi = DynamicPPL.init!!(model, vi, strategy) + new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] + # check that the values are (presumably) different + @test old_x != new_x + @test old_y != new_y end end function test_rng_respected(strategy::AbstractInitStrategy) @testset "check that RNG is respected: $(typeof(strategy))" begin model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - _, vi1 = DynamicPPL.init!!( - Xoshiro(468), model, deepcopy(empty_vi), strategy - ) - _, vi2 = DynamicPPL.init!!( - Xoshiro(468), model, deepcopy(empty_vi), strategy - ) - _, vi3 = DynamicPPL.init!!( - Xoshiro(469), model, deepcopy(empty_vi), strategy - ) - @test vi1[@varname(x)] == vi2[@varname(x)] - @test vi1[@varname(y)] == vi2[@varname(y)] - @test vi1[@varname(x)] != vi3[@varname(x)] - @test vi1[@varname(y)] != vi3[@varname(y)] - end + empty_vi = VarInfo() + _, vi1 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi2 = DynamicPPL.init!!( + Xoshiro(468), model, deepcopy(empty_vi), strategy + ) + _, vi3 = DynamicPPL.init!!( + Xoshiro(469), model, deepcopy(empty_vi), strategy + ) + @test 
vi1[@varname(x)] == vi2[@varname(x)] + @test vi1[@varname(y)] == vi2[@varname(y)] + @test vi1[@varname(x)] != vi3[@varname(x)] + @test vi1[@varname(y)] != vi3[@varname(y)] end end @@ -578,21 +562,20 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_nt = (; x=my_x, y=my_y) params_dict = Dict(@varname(x) => my_x, @varname(y) => my_y) model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_nt) - ) - @test vi[@varname(x)] == my_x - @test vi[@varname(y)] == my_y - logp_nt = getlogp(vi) - _, vi = DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_dict) - ) - @test vi[@varname(x)] == my_x - @test vi[@varname(y)] == my_y - logp_dict = getlogp(vi) - @test logp_nt == logp_dict - end + empty_vi = VarInfo() + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_nt = getlogp(vi) + _, vi = DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict) + ) + @test vi[@varname(x)] == my_x + @test vi[@varname(y)] == my_y + logp_dict = getlogp(vi) + @test logp_nt == logp_dict end @testset "given only partial parameters" begin @@ -600,56 +583,53 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() params_nt = (; x=my_x) params_dict = Dict(@varname(x) => my_x) model = test_init_model() - @testset "$vi_name" for (vi_name, empty_vi) in empty_varinfos - @testset "with InitFromPrior fallback" begin - _, vi = DynamicPPL.init!!( - Xoshiro(468), - model, - deepcopy(empty_vi), - InitFromParams(params_nt, InitFromPrior()), - ) - @test vi[@varname(x)] == my_x - nt_y = vi[@varname(y)] - @test nt_y isa AbstractVector{<:Real} - @test length(nt_y) == 2 - _, vi = DynamicPPL.init!!( - Xoshiro(469), - model, - deepcopy(empty_vi), - InitFromParams(params_dict, InitFromPrior()), - ) - @test vi[@varname(x)] == 
my_x - dict_y = vi[@varname(y)] - @test dict_y isa AbstractVector{<:Real} - @test length(dict_y) == 2 - # the values should be different since we used different seeds - @test dict_y != nt_y - end + empty_vi = VarInfo() + @testset "with InitFromPrior fallback" begin + _, vi = DynamicPPL.init!!( + Xoshiro(468), + model, + deepcopy(empty_vi), + InitFromParams(params_nt, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + nt_y = vi[@varname(y)] + @test nt_y isa AbstractVector{<:Real} + @test length(nt_y) == 2 + _, vi = DynamicPPL.init!!( + Xoshiro(469), + model, + deepcopy(empty_vi), + InitFromParams(params_dict, InitFromPrior()), + ) + @test vi[@varname(x)] == my_x + dict_y = vi[@varname(y)] + @test dict_y isa AbstractVector{<:Real} + @test length(dict_y) == 2 + # the values should be different since we used different seeds + @test dict_y != nt_y + end - @testset "with no fallback" begin - # These just don't have an entry for `y`. - @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) - ) - @test_throws ErrorException DynamicPPL.init!!( - model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) - ) - # We also explicitly test the case where `y = missing`. - params_nt_missing = (; x=my_x, y=missing) - params_dict_missing = Dict( - @varname(x) => my_x, @varname(y) => missing - ) - @test_throws ErrorException DynamicPPL.init!!( - model, - deepcopy(empty_vi), - InitFromParams(params_nt_missing, nothing), - ) - @test_throws ErrorException DynamicPPL.init!!( - model, - deepcopy(empty_vi), - InitFromParams(params_dict_missing, nothing), - ) - end + @testset "with no fallback" begin + # These just don't have an entry for `y`. 
+ @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_nt, nothing) + ) + @test_throws ErrorException DynamicPPL.init!!( + model, deepcopy(empty_vi), InitFromParams(params_dict, nothing) + ) + # We also explicitly test the case where `y = missing`. + params_nt_missing = (; x=my_x, y=missing) + params_dict_missing = Dict(@varname(x) => my_x, @varname(y) => missing) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_nt_missing, nothing), + ) + @test_throws ErrorException DynamicPPL.init!!( + model, + deepcopy(empty_vi), + InitFromParams(params_dict_missing, nothing), + ) end end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index f950f6b45..343282480 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -67,60 +67,6 @@ model = ModelOuterWorking2() @test check_model(model, VarInfo(model); error_on_failure=true) end - - @testset "subsumes (x then x[1])" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x ~ MvNormal(zeros(2), I) - x[1] ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end - - @testset "subsumes (x[1] then x)" begin - @model function buggy_subsumes_demo_model() - x = Vector{Float64}(undef, 2) - x[1] ~ Normal() - x ~ MvNormal(zeros(2), I) - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end - - @testset "subsumes (x.a then x)" begin - @model 
function buggy_subsumes_demo_model() - x = (a=nothing,) - x.a ~ Normal() - x ~ Normal() - return nothing - end - buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) - - @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) - issuccess = check_model(buggy_model, varinfo) - @test !issuccess - @test_throws ErrorException check_model( - buggy_model, varinfo; error_on_failure=true - ) - end end @testset "NaN in data" begin diff --git a/test/lkj.jl b/test/lkj.jl index 5c5603aba..bab3ce185 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -37,7 +37,7 @@ end last(DynamicPPL.init!!(model, VarInfo(), init_strategy)) for _ in 1:n_samples ] corr_matrices = map(samples) do s - M = reshape(s.metadata.vals, (2, 2)) + M = reshape(DynamicPPL.getindex_internal(s, @varname(x)), (2, 2)) pd_from_triangular(M, uplo) end @test vec(mean(corr_matrices)) ≈ target_mean atol = _lkj_atol From 9812ad0b950f83c1911ce8f5857c047ae1335f2e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 12:44:53 +0000 Subject: [PATCH 25/56] Comment back in include of old VI files --- src/DynamicPPL.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 84b8b2e68..d6248a4d0 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -196,14 +196,14 @@ include("model.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") -# include("varnamedvector.jl") +include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") # include("varinfo.jl") include("vntvarinfo.jl") -# include("simple_varinfo.jl") +include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") From 6d44954d7420ccf80386e01985fe7b288e678921 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 14:19:56 +0000 Subject: [PATCH 26/56] Remote JET extension and experimental.jl --- 
Project.toml | 3 - docs/Project.toml | 2 - docs/src/api.md | 9 --- ext/DynamicPPLJETExt.jl | 56 ----------------- src/DynamicPPL.jl | 22 ------- src/experimental.jl | 98 ------------------------------ test/Project.toml | 2 - test/ext/DynamicPPLJETExt.jl | 113 ----------------------------------- test/runtests.jl | 3 - 9 files changed, 308 deletions(-) delete mode 100644 ext/DynamicPPLJETExt.jl delete mode 100644 src/experimental.jl delete mode 100644 test/ext/DynamicPPLJETExt.jl diff --git a/Project.toml b/Project.toml index 1b899c906..a1a95c822 100644 --- a/Project.toml +++ b/Project.toml @@ -30,7 +30,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" @@ -40,7 +39,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] -DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"] DynamicPPLMooncakeExt = ["Mooncake"] @@ -62,7 +60,6 @@ DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" InteractiveUtils = "1" -JET = "0.9, 0.10, 0.11" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" diff --git a/docs/Project.toml b/docs/Project.toml index d5fa9a637..5bdb0a2db 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -8,7 +8,6 @@ DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" 
-JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392" @@ -24,7 +23,6 @@ DocumenterMermaid = "0.1, 0.2" DynamicPPL = "0.40" FillArrays = "0.13, 1" ForwardDiff = "0.10, 1" -JET = "0.9, 0.10, 0.11" LogDensityProblems = "2" MCMCChains = "5, 6, 7" MarginalLogDensities = "0.4" diff --git a/docs/src/api.md b/docs/src/api.md index 084639c07..bfc5dcc8d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -553,15 +553,6 @@ init get_param_eltype ``` -### Choosing a suitable VarInfo - -There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model: - -```@docs -DynamicPPL.Experimental.determine_suitable_varinfo -DynamicPPL.Experimental.is_suitable_varinfo -``` - ### Converting VarInfos to/from chains It is a fairly common operation to want to convert a collection of `VarInfo` objects into a chains object for downstream analysis. diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl deleted file mode 100644 index cb35c5ffb..000000000 --- a/ext/DynamicPPLJETExt.jl +++ /dev/null @@ -1,56 +0,0 @@ -module DynamicPPLJETExt - -using DynamicPPL: DynamicPPL -using JET: JET - -function DynamicPPL.Experimental.is_suitable_varinfo( - model::DynamicPPL.Model, varinfo::DynamicPPL.AbstractVarInfo; only_dppl::Bool=true -) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(model, varinfo) - # If specified, we only check errors originating somewhere in the DynamicPPL.jl. - # This way we don't just fall back to untyped if the user's code is the issue. 
- result = if only_dppl - JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),)) - else - JET.report_call(f, argtypes) - end - return length(JET.get_reports(result)) == 0, result -end - -function DynamicPPL.Experimental._determine_varinfo_jet( - model::DynamicPPL.Model; only_dppl::Bool=true -) - # Generate a typed varinfo to test model type stability with - varinfo = DynamicPPL.typed_varinfo(model) - - # Check type stability of evaluation (i.e. DefaultContext) - model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) - eval_issuccess, eval_result = DynamicPPL.Experimental.is_suitable_varinfo( - model, varinfo; only_dppl - ) - if !eval_issuccess - @debug "Evaluation with typed varinfo failed with the following issues:" - @debug eval_result - end - - # Check type stability of initialisation (i.e. InitContext) - model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) - init_issuccess, init_result = DynamicPPL.Experimental.is_suitable_varinfo( - model, varinfo; only_dppl - ) - if !init_issuccess - @debug "Initialisation with typed varinfo failed with the following issues:" - @debug init_result - end - - # If neither of them failed, we can return the typed varinfo as it's type stable. - return if (eval_issuccess && init_issuccess) - varinfo - else - # Warn the user that we can't use the type stable one. - @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." 
- DynamicPPL.untyped_varinfo(model) - end -end - -end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d6248a4d0..b84c076be 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -211,7 +211,6 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") include("values_as_in_model.jl") -include("experimental.jl") include("chains.jl") include("bijector.jl") @@ -223,27 +222,6 @@ include("deprecated.jl") if isdefined(Base.Experimental, :register_error_hint) function __init__() - # Better error message if users forget to load JET.jl - Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ - requires_jet = - exc.f === DynamicPPL.Experimental._determine_varinfo_jet && - length(argtypes) >= 2 && - argtypes[1] <: Model && - argtypes[2] <: AbstractContext - requires_jet |= - exc.f === DynamicPPL.Experimental.is_suitable_varinfo && - length(argtypes) >= 3 && - argtypes[1] <: Model && - argtypes[2] <: AbstractContext && - argtypes[3] <: AbstractVarInfo - if requires_jet - print( - io, - "\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).", - ) - end - end - # Same for MarginalLogDensities.jl Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ requires_mld = diff --git a/src/experimental.jl b/src/experimental.jl deleted file mode 100644 index 8c82dca68..000000000 --- a/src/experimental.jl +++ /dev/null @@ -1,98 +0,0 @@ -module Experimental - -using DynamicPPL: DynamicPPL - -# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. -""" - is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...) - -Check if the `model` supports evaluation using the provided `varinfo`. - -!!! warning - Loading JET.jl is required before calling this function. 
- -# Arguments -- `model`: The model to verify the support for. -- `varinfo`: The varinfo to verify the support for. - -# Keyword Arguments -- `only_dppl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`. - -# Returns -- `issuccess`: `true` if the model supports the varinfo, otherwise `false`. -- `report`: The result of `report_call` from JET.jl. -""" -function is_suitable_varinfo end - -# Internal hook for JET.jl to overload. -function _determine_varinfo_jet end - -""" - determine_suitable_varinfo(model; only_dppl::Bool=true) - -Return a suitable varinfo for the given `model`. - -See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). - -!!! warning - For full functionality, this requires JET.jl to be loaded. - If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo. - -# Arguments -- `model`: The model for which to determine the varinfo. - -# Keyword Arguments -- `only_dppl`: If `true`, only consider error reports within DynamicPPL.jl. - -# Examples - -```jldoctest -julia> using DynamicPPL.Experimental: determine_suitable_varinfo - -julia> using JET: JET # needs to be loaded for full functionality - -julia> @model function model_with_random_support() - x ~ Bernoulli() - if x - y ~ Normal() - else - z ~ Normal() - end - end -model_with_random_support (generic function with 2 methods) - -julia> model = model_with_random_support(); - -julia> # Typed varinfo cannot handle this random support model properly - # as using a single execution of the model will not see all random variables. - # Hence, this this model requires untyped varinfo. - vi = determine_suitable_varinfo(model); -┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. 
-└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 - -julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) -true - -julia> # In contrast, a simple model with no random support can be handled by typed varinfo. - @model model_with_static_support() = x ~ Normal() -model_with_static_support (generic function with 2 methods) - -julia> vi = determine_suitable_varinfo(model_with_static_support()); - -julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) -true -``` -""" -function determine_suitable_varinfo(model::DynamicPPL.Model; only_dppl::Bool=true) - # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. - return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing - _determine_varinfo_jet(model; only_dppl) - else - # Warn the user. - @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." - # Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). 
- DynamicPPL.typed_varinfo(model, context) - end -end - -end diff --git a/test/Project.toml b/test/Project.toml index 927954ba4..9c146eb97 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,7 +14,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" @@ -46,7 +45,6 @@ Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" ForwardDiff = "0.10.12, 1" -JET = "0.9, 0.10, 0.11" LogDensityProblems = "2" MCMCChains = "7.2.1" MacroTools = "0.5.6" diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl deleted file mode 100644 index e46c25113..000000000 --- a/test/ext/DynamicPPLJETExt.jl +++ /dev/null @@ -1,113 +0,0 @@ -@testset "DynamicPPLJETExt.jl" begin - @testset "determine_suitable_varinfo" begin - @model function demo1() - x ~ Bernoulli() - if x - y ~ Normal() - else - z ~ Normal() - end - end - model = demo1() - @test DynamicPPL.Experimental.determine_suitable_varinfo(model) isa - DynamicPPL.UntypedVarInfo - - @model demo2() = x ~ Normal() - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa - DynamicPPL.NTVarInfo - - @model function demo3() - # Just making sure that nothing strange happens when type inference fails. - x = Vector(undef, 1) - x[1] ~ Bernoulli() - if x[1] - y ~ Normal() - else - z ~ Normal() - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo3()) isa - DynamicPPL.UntypedVarInfo - - # Evaluation works (and it would even do so in practice), but sampling - # will fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. 
- @model function demo4() - x ~ Bernoulli() - if x - y ~ Normal() - else - y ~ Cauchy() # different distibution, but same transformation - end - end - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa - DynamicPPL.UntypedVarInfo - - # In this model, the type error occurs in the user code rather than in DynamicPPL. - @model function demo5() - x ~ Normal() - xs = Any[] - push!(xs, x) - # `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the - # correct `zero` method. As a result, this code will run, but JET will raise this is an issue. - return sum(xs) - end - # Should pass if we're only checking the tilde statements. - @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa - DynamicPPL.NTVarInfo - # Should fail if we're including errors in the model body. - @test DynamicPPL.Experimental.determine_suitable_varinfo( - demo5(); only_dppl=false - ) isa DynamicPPL.UntypedVarInfo - end - - @testset "demo models" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.ALL_MODELS - if model.f === DynamicPPL.TestUtils.demo_lkjchol - # TODO(mhauru) - # The LKJCholesky model fails with JET. The problem is not with Turing but - # with Distributions, and ultimately this in LinearAlgebra: - # julia> v = @view rand(2,2)[:,1]; - # - # julia> JET.@report_call norm(v) - # ═════ 2 possible errors found ═════ - # blahblah - # The below trivial call to @test is just marking that there's something - # broken here. - @test false broken = true - continue - end - # Use debug logging below. - varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) - # Check that the inferred varinfo is indeed suitable for evaluation - f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f_eval, argtypes_eval) - - # For our demo models, they should all result in typed. 
- is_typed = varinfo isa DynamicPPL.NTVarInfo - @test is_typed - # If the test failed, check what the type stability problem was for - # the typed varinfo. This is mostly useful for debugging from test - # logs. - if !is_typed - @info "Model `$(model.f)` is not type stable with typed varinfo." - typed_vi = DynamicPPL.typed_varinfo(model) - - @info "Evaluating with DefaultContext:" - model = DynamicPPL.setleafcontext(model, DynamicPPL.DefaultContext()) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f, argtypes) - - @info "Initialising with InitContext:" - model = DynamicPPL.setleafcontext(model, DynamicPPL.InitContext()) - f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( - model, varinfo - ) - JET.test_call(f, argtypes) - end - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index e04b664fe..23dda437b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,8 +28,6 @@ using Test using Distributions using LinearAlgebra # Diagonal -using JET: JET - using Combinatorics: combinations using OrderedCollections: OrderedSet @@ -76,7 +74,6 @@ include("test_util.jl") include("logdensityfunction.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") - include("ext/DynamicPPLJETExt.jl") include("ext/DynamicPPLMarginalLogDensitiesExt.jl") end @testset "ad" begin From d5bfa2c1748e0b3d629da1d79de8000d035357a4 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 18:12:04 +0000 Subject: [PATCH 27/56] Reimplement bijector.jl --- src/bijector.jl | 111 +++++++++++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/src/bijector.jl b/src/bijector.jl index 31fe7cd88..576205641 100644 --- a/src/bijector.jl +++ b/src/bijector.jl @@ -1,60 +1,65 @@ +struct BijectorAccumulator <: AbstractAccumulator + bijectors::Vector{Any} + sizes::Vector{Int} +end -""" - bijector(model::Model[, sym2ranges = Val(false)]) 
+BijectorAccumulator() = BijectorAccumulator(Bijectors.Bijector[], UnitRange{Int}[]) + +function Base.:(==)(acc1::BijectorAccumulator, acc2::BijectorAccumulator) + return (acc1.bijectors == acc2.bijectors && acc1.sizes == acc2.sizes) +end + +function Base.copy(acc::BijectorAccumulator) + return BijectorAccumulator(copy(acc.bijectors), copy(acc.sizes)) +end + +accumulator_name(::Type{<:BijectorAccumulator}) = :Bijector + +function _zero(acc::BijectorAccumulator) + return BijectorAccumulator(empty(acc.bijectors), empty(acc.sizes)) +end +reset(acc::BijectorAccumulator) = _zero(acc) +split(acc::BijectorAccumulator) = _zero(acc) +function combine(acc1::BijectorAccumulator, acc2::BijectorAccumulator) + return BijectorAccumulator( + vcat(acc1.bijectors, acc2.bijectors), vcat(acc1.sizes, acc2.sizes) + ) +end + +function accumulate_assume!!(acc::BijectorAccumulator, val, logjac, vn, right) + bijector = _compose_no_identity( + to_linked_vec_transform(right), from_vec_transform(right) + ) + push!(acc.bijectors, bijector) + push!(acc.sizes, prod(output_size(to_vec_transform(right), right); init=1)) + return acc +end + +accumulate_observe!!(acc::BijectorAccumulator, right, left, vn) = acc -Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` -denoting the dimensionality of the latent variables. """ -function Bijectors.bijector( - model::DynamicPPL.Model, - (::Val{sym2ranges})=Val(false); - varinfo=DynamicPPL.VarInfo(model), -) where {sym2ranges} - dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) 
- - num_ranges = sum([ - length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) - ]) - ranges = Vector{UnitRange{Int}}(undef, num_ranges) - idx = 0 - range_idx = 1 - - # ranges might be discontinuous => values are vectors of ranges rather than just ranges - sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() - for sym in keys(varinfo.metadata) - sym_lookup[sym] = Vector{UnitRange{Int}}() - for r in varinfo.metadata[sym].ranges - ranges[range_idx] = idx .+ r - push!(sym_lookup[sym], ranges[range_idx]) - range_idx += 1 - end - - idx += varinfo.metadata[sym].ranges[end][end] - end + bijector(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - bs = map(tuple(dists...)) do d - b = Bijectors.bijector(d) - if d isa Distributions.UnivariateDistribution - b - else - # Wrap a bijector `f` such that it operates on vectors of length `prod(in_size)` - # and produces a vector of length `prod(Bijectors.output(f, in_size))`. - in_size = size(d) - vec_in_length = prod(in_size) - reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) - out_size = Bijectors.output_size(b, in_size) - vec_out_length = prod(out_size) - reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) - reshape_outer ∘ b ∘ reshape_inner - end - end +Returns a `Stacked <: Bijector` which maps from constrained to unconstrained space. + +The input to the bijector is a vector of values for the whole model, like the input to +`unflatten!!`. These are in constrained space, i.e., respecting variable constraints. +The output is a vector of unconstrained values. - if sym2ranges - return ( - Bijectors.Stacked(bs, ranges), - (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), - ) - else - return Bijectors.Stacked(bs, ranges) +`init_strategy` is passed to `DynamicPPL.init!!` to determine what values the model is +evaluated with. This may affect the results if the prior distributions or constraints of +variables are dependent on other variables. 
+""" +function Bijectors.bijector( + model::DynamicPPL.Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + vi = OnlyAccsVarInfo((BijectorAccumulator(),)) + vi = last(DynamicPPL.init!!(model, vi, init_strategy)) + acc = getacc(vi, Val(:Bijector)) + ranges = foldl(acc.sizes; init=UnitRange{Int}[]) do cumulant, sz + last_index = length(cumulant) > 0 ? last(cumulant).stop : 0 + push!(cumulant, (last_index + 1):(last_index + sz)) + return cumulant end + return Bijectors.Stacked(acc.bijectors, ranges) end From eb903e1ac79bf611422cd5cc132081af4ee70acb Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 18:13:06 +0000 Subject: [PATCH 28/56] Move linking code to VarInfo, fix ProductNamedDistribution bijector, etc --- src/contexts/init.jl | 64 ++------------------------- src/model.jl | 4 +- src/test_utils/model_interface.jl | 4 +- src/utils.jl | 19 +++++++- src/vntvarinfo.jl | 73 ++++++++++++++++++++++++++----- test/model.jl | 2 +- 6 files changed, 89 insertions(+), 77 deletions(-) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 45d6356f1..65ea08ec5 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -309,68 +309,10 @@ function tilde_assume!!( ctx::InitContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo ) val, transform = init(ctx.rng, vn, dist, ctx.strategy) - x, inv_logjac = with_logabsdet_jacobian(transform, val) - # Determine whether to insert a transformed value into the VarInfo. - # If the VarInfo alrady had a value for this variable, we will - # keep the same linked status as in the original VarInfo. If not, we - # check the rest of the VarInfo to see if other variables are linked. - # is_transformed(vi) returns true if vi is nonempty and all variables in vi - # are linked. - insert_transformed_value = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) - val_to_insert, logjac = if insert_transformed_value - # Calculate the forward logjac and sum them up. 
- lt = link_transform(dist) - y, fwd_logjac = with_logabsdet_jacobian(lt, x) - transform = _compose_no_identity(transform, lt) - # Note that if we use VectorWithRanges with a full VarInfo, this double-Jacobian - # calculation wastes a lot of time going from linked vectorised -> unlinked -> - # linked, and `inv_logjac` will also just be the negative of `fwd_logjac`. - # - # However, `VectorWithRanges` is only really used with `OnlyAccsVarInfo`, in which - # case this branch is never hit (since `in_varinfo` will always be false). It does - # mean that the combination of InitFromParams{<:VectorWithRanges} with a full, - # linked, VarInfo will be very slow. That should never really be used, though. So - # (at least for now) we can leave this branch in for full generality with other - # combinations of init strategies / VarInfo. - # - # TODO(penelopeysm): Figure out one day how to refactor this. The crux of the issue - # is that the transform used by `VectorWithRanges` is `from_linked_VEC_transform`, - # which is NOT the same as `inverse(link_transform)` (because there is an additional - # vectorisation step). We need `init` and `tilde_assume!!` to share this information - # but it's not clear right now how to do this. In my opinion, there are a couple of - # potential ways forward: - # - # 1. Just remove metadata entirely so that there is never any need to construct - # a linked vectorised value again. This would require us to use VAIMAcc as the only - # way of getting values. I consider this the best option, but it might take a long - # time. - # - # 2. Clean up the behaviour of bijectors so that we can have a complete separation - # between the linking and vectorisation parts of it. That way, `x` can either be - # unlinked, unlinked vectorised, linked, or linked vectorised, and regardless of - # which it is, we should only need to apply at most one linking and one - # vectorisation transform. 
Doing so would allow us to remove the first call to - # `with_logabsdet_jacobian`, and instead compose and/or uncompose the - # transformations before calling `with_logabsdet_jacobian` once. - y, -inv_logjac + fwd_logjac - else - x, -inv_logjac - end - # Add the new value to the VarInfo. `push!!` errors if the value already - # exists, hence the need for setindex!!. - vi = if vi isa VNTVarInfo - x_size = hasmethod(size, Tuple{typeof(x)}) ? size(x) : () - vi = push!!(vi, vn, val_to_insert, inverse(transform), x_size) - else - push!!(vi, vn, val_to_insert, dist) - end - # Neither of these set the `trans` flag so we have to do it manually if - # necessary. - if insert_transformed_value - vi = set_transformed!!(vi, true, vn) - end + x, init_logjac = with_logabsdet_jacobian(transform, val) + vi, logjac = setindex_with_dist!!(vi, x, dist, vn) # `accumulate_assume!!` wants untransformed values as the second argument. - vi = accumulate_assume!!(vi, x, logjac, vn, dist) + vi = accumulate_assume!!(vi, x, init_logjac + logjac, vn, dist) # We always return the untransformed value here, as that will determine # what the lhs of the tilde-statement is set to. return x, vi diff --git a/src/model.jl b/src/model.jl index 8bfeaf6a1..91558ecdc 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1085,7 +1085,9 @@ Base.nameof(model::Model{<:Function}) = nameof(model.f) Generate a sample of type `T` from the prior distribution of the `model`. 
""" function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} - x = last(init!!(rng, model, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}()))) + vi = VarInfo() + vi = setaccs!!(vi, DynamicPPL.AccumulatorTuple()) + x = last(init!!(rng, model, vi)) return values_as(x, T) end diff --git a/src/test_utils/model_interface.jl b/src/test_utils/model_interface.jl index e7fb16fbe..9914c05ca 100644 --- a/src/test_utils/model_interface.jl +++ b/src/test_utils/model_interface.jl @@ -89,10 +89,10 @@ function logprior_true_with_logabsdet_jacobian end Return a collection of `VarName` as they are expected to appear in the model. Even though it is recommended to implement this by hand for a particular `Model`, -a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. +a default implementation using [`VarInfo`](@ref) is provided. """ function varnames(model::Model) - result = collect(keys(last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict()))))) + result = collect(keys(last(DynamicPPL.init!!(model, VarInfo())))) # Concretise the element type. return [x for x in result] end diff --git a/src/utils.jl b/src/utils.jl index 0e03c5cdc..ba79f94b4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -406,7 +406,7 @@ from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform() from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) -struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} +struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} <: Bijectors.Bijector dists::T # The `i`-th input range corresponds to the segment of the input vector # that belongs to the `i`-th distribution. 
@@ -439,13 +439,30 @@ end return expr end +@generated function (inv_trf::Bijectors.Inverse{<:ProductNamedTupleUnvecTransform{names}})( + x::NamedTuple{names} +) where {names} + exprs = Expr[] + for name in names + push!(exprs, :(to_vec_transform(inv_trf.orig.dists.$name)(x.$name))) + end + return :(vcat($(exprs...))) +end + function from_vec_transform(dist::Distributions.ProductNamedTupleDistribution) return ProductNamedTupleUnvecTransform(dist) end + function Bijectors.with_logabsdet_jacobian(f::ProductNamedTupleUnvecTransform, x) return f(x), zero(LogProbType) end +function Bijectors.with_logabsdet_jacobian( + inv_f::Bijectors.Inverse{<:ProductNamedTupleUnvecTransform}, x +) + return inv_f(x), zero(LogProbType) +end + # This function returns the length of the vector that the function from_vec_transform # expects. This helps us determine which segment of a concatenated vector belongs to which # variable. diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index a0392334c..6ce1a861e 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -61,19 +61,39 @@ function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName) return VNTVarInfo(new_values, vi.accs) end -BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) = push!!(vi, vn, val) - -# TODO(mhauru) The arguments are in the wrong order, but this is the current convetion. -function BangBang.push!!( - vi::VNTVarInfo, vn::VarName, val, transform=typed_identity, orig_size=size(val) -) +# TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However, +# we need `dist` to determine the linking transformation (or even just the vectorisation +# transformation, in the case of ProductNamedTupleDistribions), and if we leave the work +# of doing the transformation to the caller, it'll be done even when e.g. using +# OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be removed +# once VAIMAcc is the only way to get values out of an evaluation. 
+function setindex_with_dist!!(vi::VNTVarInfo, val, dist::Distribution, vn::VarName)
+    # Determine whether to insert a transformed value into `vi`.
+    # If the VarInfo already had a value for this variable, we will
+    # keep the same linked status as in the original VarInfo. If not, we
+    # check the rest of the VarInfo to see if other variables are linked.
+    # is_transformed(vi) returns true if vi is nonempty and all variables in vi
+    # are linked.
+    insert_transformed_value = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi)
     # TODO(mhauru) We should move away from having all values vectorised by default.
     # That messes with our use of unflatten though, so will require some thought.
-    transform = _compose_no_identity(transform, from_vec_transform(val))
-    val = to_vec_transform(val)(val)
-    new_tv = TransformedValue(val, false, transform, orig_size)
-    new_values = setindex!!(vi.values, new_tv, vn)
-    return VNTVarInfo(new_values, vi.accs)
+    transform = if insert_transformed_value
+        from_linked_vec_transform(dist)
+    else
+        from_vec_transform(dist)
+    end
+    transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val)
+    val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : ()
+    tv = TransformedValue(transformed_val, insert_transformed_value, transform, val_size)
+    vi = VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs)
+    return vi, logjac
+end
+
+function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName)
+    transform = from_vec_transform(val)
+    transformed_val = inverse(transform)(val)
+    tv = TransformedValue(transformed_val, false, transform, size(val))
+    return VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs)
+end
 
 Base.keys(vi::VNTVarInfo) = keys(vi.values)
@@ -86,6 +106,20 @@ function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName)
     return VNTVarInfo(new_values, vi.accs)
 end
 
+# VNTVarInfo does not care whether the transformation was Static or Dynamic, it just tracks
+# whether one was applied at all.
+function set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation, vn::VarName) + return set_transformed!!(vi, true, vn) +end + +set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) + +function set_transformed!!(vi::VNTVarInfo, ::NoTransformation, vn::VarName) + return set_transformed!!(vi, false, vn) +end + +set_transformed!!(vi::VNTVarInfo, ::NoTransformation) = set_transformed!!(vi, false) + function set_transformed!!(vi::VNTVarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv TransformedValue(tv.val, linked, tv.transform, tv.size) @@ -238,6 +272,23 @@ function values_as(vi::VNTVarInfo, ::Type{T}) where {T<:AbstractDict} end, vi.values; init=T()) end +# TODO(mhauru) I really dislike this sort of conversion to Symbols, but it's the current +# interface provided by rand(::Model). We should change that to return a VarNamedTuple +# instead, and then this method (and any other values_as methods for NamedTuple) could be +# removed. +function values_as(vi::VNTVarInfo, ::Type{NamedTuple}) + return mapfoldl( + identity, + function (cumulant, pair) + vn, tv = pair + val = tv.transform(tv.val) + return setindex!!(cumulant, val, Symbol(vn)) + end, + vi.values; + init=NamedTuple(), + ) +end + # TODO(mhauru) These two are now redundant, just conforming to the old interface # temporarily. function untyped_varinfo( diff --git a/test/model.jl b/test/model.jl index 05688c224..7c5dc2fcc 100644 --- a/test/model.jl +++ b/test/model.jl @@ -311,7 +311,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test logjoint(model, x) != DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian(model, x...) # Ensure `varnames` is implemented. - vi = last(DynamicPPL.init!!(model, SimpleVarInfo(OrderedDict{VarName,Any}()))) + vi = last(DynamicPPL.init!!(model, VarInfo())) @test all(collect(keys(vi)) .== DynamicPPL.TestUtils.varnames(model)) # Ensure `posterior_mean` is implemented. 
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) From 469a71514626f916c4594c2a9592d80b8752b9c3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 18:54:33 +0000 Subject: [PATCH 29/56] Mark a test as broken --- src/abstract_varinfo.jl | 7 +++++-- test/chains.jl | 10 +++++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ef1d92042..0c15cb9c7 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -839,12 +839,15 @@ function link!!( ) # TODO(mhauru) This assumes that the user has defined the bijector using the same # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user - # interface, and it's also dangerous for any AbstractVarInfo types that may not respect - # a particular ordering, such as SimpleVarInfo{Dict}. + # interface. b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) # Set parameters and add the logjac term. + # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant + # that getindex(vi, vn) would apply the default link transform of the distribution. With + # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any + # transform. Neither is correct, rather the transform should be the inverse of b. vi = unflatten!!(vi, y) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, logjac) diff --git a/test/chains.jl b/test/chains.jl index 608a9a9cf..d69d2d4ca 100644 --- a/test/chains.jl +++ b/test/chains.jl @@ -68,8 +68,16 @@ end @testset "ParamsWithStats from LogDensityFunction" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.ALL_MODELS - unlinked_vi = VarInfo(m) + if m.f === DynamicPPL.TestUtils.demo_static_transformation + # TODO(mhauru) These tests are broken for demo_static_transformation because + # vi[vn] doesn't know which transform it should apply to the internally stored + # value. 
This requires a rethink, either of StaticTransformation or of what the + # comparison in this test should be. + @test false broken = true + continue + end @testset "$islinked" for islinked in (false, true) + unlinked_vi = VarInfo(m) vi = if islinked DynamicPPL.link!!(unlinked_vi, m) else From 89a8396a35ab7db7f5ad9d6ba8d28d88b0c5b147 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 19:32:33 +0000 Subject: [PATCH 30/56] Various bugfixes --- src/threadsafe.jl | 15 +++++---------- test/compiler.jl | 10 +++++----- test/contexts.jl | 6 ++++-- test/varinfo.jl | 12 ++++++------ 4 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f168eb7c1..88200680a 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -67,16 +67,6 @@ end has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) -function BangBang.push!!(vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, dist) -end - -function BangBang.push!!( - vi::ThreadSafeVarInfo, vn::VarName, r, transform=typed_identity, orig_size=size(r) -) - return Accessors.@set vi.varinfo = push!!(vi.varinfo, vn, r, transform, orig_size) -end - syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -168,6 +158,11 @@ function getindex(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}, dist::D return getindex(vi.varinfo, vns, dist) end +function setindex_with_dist!!(vi::ThreadSafeVarInfo, val, dist::Distribution, vn::VarName) + vi_inner, logjac = setindex_with_dist!!(vi.varinfo, val, dist, vn) + return Accessors.@set(vi.varinfo = vi_inner), logjac +end + function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vn::VarName) return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vn) end diff --git a/test/compiler.jl b/test/compiler.jl index e4a9a2474..8d0105947 100644 --- a/test/compiler.jl 
+++ b/test/compiler.jl @@ -604,9 +604,9 @@ module Issue537 end # Even if the return-value is `AbstractVarInfo`, we should return # a `Tuple` with `AbstractVarInfo` in the second component too. @model demo() = return __varinfo__ - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) - @test svi == SimpleVarInfo() - @test retval == svi + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) + @test vi == VarInfo() + @test retval == vi # We should not be altering return-values other than at top-level. @model function demo() @@ -615,11 +615,11 @@ module Issue537 end f(x) = return x^2 return f(1.0) end - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) @test retval isa Float64 @model demo() = x ~ Normal() - retval, svi = DynamicPPL.init!!(demo(), SimpleVarInfo()) + retval, vi = DynamicPPL.init!!(demo(), VarInfo()) # Return-value when using `to_submodel` @model inner() = x ~ Normal() diff --git a/test/contexts.jl b/test/contexts.jl index 24f6445f5..435561267 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -448,8 +448,10 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() # start by generating some rubbish values vi = deepcopy(empty_vi) old_x, old_y = 100000.00, [300000.00, 500000.00] - push!!(vi, @varname(x), old_x, Normal()) - push!!(vi, @varname(y), old_y, MvNormal(fill(old_x, 2), I)) + vi, _ = DynamicPPL.setindex_with_dist!!(vi, old_x, Normal(), @varname(x)) + vi, _ = DynamicPPL.setindex_with_dist!!( + vi, old_y, MvNormal(fill(old_x, 2), I), @varname(y) + ) # then overwrite it _, new_vi = DynamicPPL.init!!(model, vi, strategy) new_x, new_y = new_vi[@varname(x)], new_vi[@varname(y)] diff --git a/test/varinfo.jl b/test/varinfo.jl index 0bea67402..1d01a0cf8 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -17,7 +17,7 @@ end @testset "varinfo.jl" begin @testset "Base" begin # Test Base functions: - # in, keys, haskey, isempty, push!!, empty!!, + # in, keys, haskey, isempty, 
setindex!!, empty!!, # getindex, setindex!, getproperty, setproperty! vi = VarInfo() @@ -30,7 +30,7 @@ end @test isempty(vi) @test !haskey(vi, vn) @test !(vn in keys(vi)) - vi = push!!(vi, vn, r) + vi = setindex!!(vi, r, vn) @test !isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -44,7 +44,7 @@ end vi = empty!!(vi) @test isempty(vi) - vi = push!!(vi, vn, r) + vi = setindex!!(vi, r, vn) @test !isempty(vi) end @@ -223,7 +223,7 @@ end vn_x = @varname x r = rand() - vi = push!!(vi, vn_x, r) + vi = setindex!!(vi, r, vn_x) # is_transformed is unset by default @test !is_transformed(vi, vn_x) @@ -637,9 +637,9 @@ end @testset "merge different dimensions" begin vn = @varname(x) vi_single = VarInfo() - vi_single = push!!(vi_single, vn, 1.0) + vi_single = setindex!!(vi_single, 1.0, vn) vi_double = VarInfo() - vi_double = push!!(vi_double, vn, [0.5, 0.6]) + vi_double = setindex!!(vi_double, [0.5, 0.6], vn) @test merge(vi_single, vi_double)[vn] == [0.5, 0.6] @test merge(vi_double, vi_single)[vn] == 1.0 end From 8cf8ab0dfa88afdf1e0efa9ca60cbc138686121a Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 12 Jan 2026 20:05:10 +0000 Subject: [PATCH 31/56] Remove SimpleVarInfo, VarNamedVector, and the old VarInfo type --- benchmarks/benchmarks.jl | 61 +- benchmarks/src/DynamicPPLBenchmarks.jl | 43 +- benchmarks/src/Models.jl | 2 +- docs/src/api.md | 24 +- docs/src/internals/varinfo.md | 295 +--- ext/DynamicPPLChainRulesCoreExt.jl | 2 - src/DynamicPPL.jl | 6 +- src/abstract_varinfo.jl | 101 +- src/contexts/transformation.jl | 2 +- src/logdensityfunction.jl | 50 - src/model.jl | 111 ++ src/simple_varinfo.jl | 647 --------- src/threadsafe.jl | 2 - src/utils.jl | 5 - src/varinfo.jl | 1810 ------------------------ src/varnamedvector.jl | 1674 ---------------------- src/vntvarinfo.jl | 9 + test/model.jl | 10 +- test/runtests.jl | 2 - test/simple_varinfo.jl | 345 ----- test/test_util.jl | 23 - test/varinfo.jl | 36 +- test/varnamedvector.jl | 711 ---------- 23 files changed, 198 
insertions(+), 5773 deletions(-) delete mode 100644 src/simple_varinfo.jl delete mode 100644 src/varinfo.jl delete mode 100644 src/varnamedvector.jl delete mode 100644 test/simple_varinfo.jl delete mode 100644 test/varnamedvector.jl diff --git a/benchmarks/benchmarks.jl b/benchmarks/benchmarks.jl index e8ffa7e0b..5be32fdef 100644 --- a/benchmarks/benchmarks.jl +++ b/benchmarks/benchmarks.jl @@ -9,9 +9,7 @@ using StableRNGs: StableRNG rng = StableRNG(23) -colnames = [ - "Model", "Dim", "AD Backend", "VarInfo", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)" -] +colnames = ["Model", "Dim", "AD Backend", "Linked", "t(eval)/t(ref)", "t(grad)/t(eval)"] function print_results(results_table; to_json=false) if to_json # Print to the given file as JSON @@ -58,31 +56,26 @@ function run(; to_json=false) end # Specify the combinations to test: - # (Model Name, model instance, VarInfo choice, AD backend, linked) + # (Model Name, model instance, AD backend, linked) chosen_combinations = [ ( "Simple assume observe", Models.simple_assume_observe(randn(rng)), - :typed, :forwarddiff, false, ), - ("Smorgasbord", smorgasbord_instance, :typed, :forwarddiff, false), - ("Smorgasbord", smorgasbord_instance, :simple_namedtuple, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :simple_dict, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :untyped_vector, :forwarddiff, true), - ("Smorgasbord", smorgasbord_instance, :typed, :reversediff, true), - ("Smorgasbord", smorgasbord_instance, :typed, :mooncake, true), - ("Smorgasbord", smorgasbord_instance, :typed, :enzyme, true), - ("Loop univariate 1k", loop_univariate1k, :typed, :mooncake, true), - ("Multivariate 1k", multivariate1k, :typed, :mooncake, true), - ("Loop univariate 10k", loop_univariate10k, :typed, :mooncake, true), - ("Multivariate 10k", multivariate10k, :typed, 
:mooncake, true), - ("Dynamic", Models.dynamic(), :typed, :mooncake, true), - ("Submodel", Models.parent(randn(rng)), :typed, :mooncake, true), - ("LDA", lda_instance, :typed, :reversediff, true), + ("Smorgasbord", smorgasbord_instance, :forwarddiff, false), + ("Smorgasbord", smorgasbord_instance, :forwarddiff, true), + ("Smorgasbord", smorgasbord_instance, :reversediff, true), + ("Smorgasbord", smorgasbord_instance, :mooncake, true), + ("Smorgasbord", smorgasbord_instance, :enzyme, true), + ("Loop univariate 1k", loop_univariate1k, :mooncake, true), + ("Multivariate 1k", multivariate1k, :mooncake, true), + ("Loop univariate 10k", loop_univariate10k, :mooncake, true), + ("Multivariate 10k", multivariate10k, :mooncake, true), + ("Dynamic", Models.dynamic(), :mooncake, true), + ("Submodel", Models.parent(randn(rng)), :mooncake, true), + ("LDA", lda_instance, :reversediff, true), ] # Time running a model-like function that does not use DynamicPPL, as a reference point. @@ -94,13 +87,13 @@ function run(; to_json=false) @info "Reference evaluation time: $(reference_time) seconds" results_table = Tuple{ - String,Int,String,String,Bool,Union{Float64,Missing},Union{Float64,Missing} + String,Int,String,Bool,Union{Float64,Missing},Union{Float64,Missing} }[] - for (model_name, model, varinfo_choice, adbackend, islinked) in chosen_combinations - @info "Running benchmark for $model_name, $varinfo_choice, $adbackend, $islinked" + for (model_name, model, adbackend, islinked) in chosen_combinations + @info "Running benchmark for $model_name, $adbackend, $islinked" relative_eval_time, relative_ad_eval_time = try - results = benchmark(model, varinfo_choice, adbackend, islinked) + results = benchmark(model, adbackend, islinked) @info " t(eval) = $(results.primal_time)" @info " t(grad) = $(results.grad_time)" (results.primal_time / reference_time), @@ -115,7 +108,6 @@ function run(; to_json=false) model_name, model_dimension(model, islinked), string(adbackend), - 
string(varinfo_choice), islinked, relative_eval_time, relative_ad_eval_time, @@ -131,9 +123,8 @@ struct TestCase model_name::String dim::Integer ad_backend::String - varinfo::String linked::Bool - TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:5])...) + TestCase(d::Dict{String,Any}) = new((d[c] for c in colnames[1:4])...) end function combine(head_filename::String, base_filename::String) head_results = try @@ -148,23 +139,22 @@ function combine(head_filename::String, base_filename::String) Dict{String,Any}[] end @info "Loaded $(length(base_results)) results from $base_filename" - # Identify unique combinations of (Model, Dim, AD Backend, VarInfo, Linked) + # Identify unique combinations of (Model, Dim, AD Backend, Linked) head_testcases = Dict( - TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in head_results + TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in head_results ) base_testcases = Dict( - TestCase(d) => (d[colnames[6]], d[colnames[7]]) for d in base_results + TestCase(d) => (d[colnames[5]], d[colnames[6]]) for d in base_results ) all_testcases = union(Set(keys(head_testcases)), Set(keys(base_testcases))) @info "$(length(all_testcases)) unique test cases found" sorted_testcases = sort( - collect(all_testcases); by=(c -> (c.model_name, c.linked, c.varinfo, c.ad_backend)) + collect(all_testcases); by=(c -> (c.model_name, c.linked, c.ad_backend)) ) results_table = Tuple{ String, Int, String, - String, Bool, String, String, @@ -179,12 +169,12 @@ function combine(head_filename::String, base_filename::String) sublabels = ["base", "this PR", "speedup"] results_colnames = [ [ - EmptyCells(5), + EmptyCells(4), MultiColumn(3, "t(eval) / t(ref)"), MultiColumn(3, "t(grad) / t(eval)"), MultiColumn(3, "t(grad) / t(ref)"), ], - [colnames[1:5]..., sublabels..., sublabels..., sublabels...], + [colnames[1:4]..., sublabels..., sublabels..., sublabels...], ] sprint_float(x::Float64) = @sprintf("%.2f", x) sprint_float(m::Missing) = "err" @@ -211,7 
+201,6 @@ function combine(head_filename::String, base_filename::String) c.model_name, c.dim, c.ad_backend, - c.varinfo, c.linked, sprint_float(base_eval), sprint_float(head_eval), diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 0dc7ece6e..6bb8672c9 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -1,6 +1,6 @@ module DynamicPPLBenchmarks -using DynamicPPL: VarInfo, SimpleVarInfo, VarName +using DynamicPPL: VarInfo, VarName using DynamicPPL: DynamicPPL using DynamicPPL.TestUtils.AD: run_ad, NoTest using ADTypes: ADTypes @@ -23,7 +23,7 @@ Return the dimension of `model`, accounting for linking, if any. """ function model_dimension(model, islinked) vi = VarInfo() - model(StableRNG(23), vi) + vi = last(DynamicPPL.init!!(StableRNG(23), model, vi)) if islinked vi = DynamicPPL.link(vi, model) end @@ -52,53 +52,24 @@ function to_backend(x::Union{AbstractString,Symbol}) end """ - benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) + benchmark(model, adbackend::Symbol, islinked::Bool) -Benchmark evaluation and gradient calculation for `model` using the selected varinfo type -and AD backend. - -Available varinfo choices: - • `:untyped` → uses `DynamicPPL.untyped_varinfo(model)` - • `:typed` → uses `DynamicPPL.typed_varinfo(model)` - • `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())` - • `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs) +Benchmark evaluation and gradient calculation for `model` using the selected AD backend. The AD backend should be specified as a Symbol (e.g. `:forwarddiff`, `:reversediff`, `:zygote`). `islinked` determines whether to link the VarInfo for evaluation. 
""" -function benchmark(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::Bool) +function benchmark(model, adbackend::Symbol, islinked::Bool) rng = StableRNG(23) - + vi = VarInfo(rng, model) adbackend = to_backend(adbackend) - - vi = if varinfo_choice == :untyped - DynamicPPL.untyped_varinfo(rng, model) - elseif varinfo_choice == :typed - DynamicPPL.typed_varinfo(rng, model) - elseif varinfo_choice == :simple_namedtuple - SimpleVarInfo{Float64}(model(rng)) - elseif varinfo_choice == :simple_dict - retvals = model(rng) - vns = [VarName{k}() for k in keys(retvals)] - SimpleVarInfo{Float64}(Dict(zip(vns, values(retvals)))) - elseif varinfo_choice == :typed_vector - DynamicPPL.typed_vector_varinfo(rng, model) - elseif varinfo_choice == :untyped_vector - DynamicPPL.untyped_vector_varinfo(rng, model) - else - error("Unknown varinfo choice: $varinfo_choice") - end - - adbackend = to_backend(adbackend) - if islinked vi = DynamicPPL.link(vi, model) end - return run_ad( model, adbackend; varinfo=vi, benchmark=true, test=NoTest(), verbose=false ) end -end # module +end diff --git a/benchmarks/src/Models.jl b/benchmarks/src/Models.jl index 2c881aa95..76d4b2e93 100644 --- a/benchmarks/src/Models.jl +++ b/benchmarks/src/Models.jl @@ -2,7 +2,7 @@ Models for benchmarking Turing.jl. Each model returns a NamedTuple of all the random variables in the model that are not -observed (this is used for constructing SimpleVarInfos). +observed. """ module Models diff --git a/docs/src/api.md b/docs/src/api.md index bfc5dcc8d..a506c793e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -333,27 +333,18 @@ Please see the documentation of [AbstractPPL.jl](https://github.com/TuringLang/A ### Data Structures of Variables -DynamicPPL provides different data structures used in for storing samples and accumulation of the log-probabilities, all of which are subtypes of [`AbstractVarInfo`](@ref). 
+DynamicPPL provides a data structure for storing samples and accumulation of the log-probabilities, called [`VarInfo`](@ref). +The interface that `VarInfo` respects is described by the abstract type [`AbstractVarInfo`](@ref). +Internally DynamicPPL also uses a couple of other subtypes of `AbstractVarInfo`. ```@docs AbstractVarInfo ``` -But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. - -#### `VarInfo` - ```@docs VarInfo ``` -```@docs -DynamicPPL.untyped_varinfo -DynamicPPL.typed_varinfo -DynamicPPL.untyped_vector_varinfo -DynamicPPL.typed_vector_varinfo -``` - One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). The [Transformations section below](#Transformations) describes the methods used for this. In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions. @@ -367,14 +358,11 @@ set_transformed!! Base.empty! ``` -#### `SimpleVarInfo` - -```@docs -SimpleVarInfo -``` - #### `VarNamedTuple` +`VarInfo` is only a thin wrapper around [`VarNamedTuple`](@ref), which stores arbitrary data keyed by `VarName`s. +For more details on `VarNamedTuple`, see the Internals section of our documentation. + ```@docs DynamicPPL.VarNamedTuples.VarNamedTuple DynamicPPL.VarNamedTuples.vnt_size diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index b04913aaf..c57ea1fcf 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -8,293 +8,50 @@ VarInfo It contains - - a `logp` field for accumulation of the log-density evaluation, and - - a `metadata` field for storing information about the realizations of the different variables. 
+  - a `VarNamedTuple` field called `values`,
+  - an `AccumulatorTuple` called `accs`, to hold accumulators.
 
-Representing `logp` is fairly straight-forward: we'll just use a `Real` or an array of `Real`, depending on the context.
+`values` takes care of storing information related to values of individual random variables, while `accs` keeps track of information that we keep accumulating in the course of evaluating through a model.
 
-**Representing `metadata` is a bit trickier**. This is supposed to contain all the necessary information for each `VarName` to enable the different executions of the model + extraction of different properties of interest after execution, e.g. the realization / value corresponding to a variable `@varname(x)`.
+Variables are recognised by their `VarName`.
+We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information.
+For instance, a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`.
+`VarName`s also allow things such as setting values for `x[1]` and `x[2]` and getting a value for `x` as a whole.
 
-!!! note
-
-    We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information, e.g. a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`.
+To ensure that `VarInfo` is simple and intuitive to work with, we want it to replicate the following functionality of `Dict`:
 
-To ensure that `VarInfo` is simple and intuitive to work with, we want `VarInfo`, and hence the underlying `metadata`, to replicate the following functionality of `Dict`:
+  - `keys(::VarInfo)`: return all the `VarName`s present.
+  - `haskey(::VarInfo, ::VarName)`: check if a particular `VarName` is present.
+ - `getindex(::VarInfo, ::VarName)`: return the realization corresponding to a particular `VarName`. + - `setindex!!(::VarInfo, val, ::VarName)`: set the realization corresponding to a particular `VarName`. + - `delete!!(::VarInfo, ::VarName)`: delete the realization corresponding to a particular `VarName`. + - `empty!!(::VarInfo)`: delete all data. + - `merge(::VarInfo, ::VarInfo)`: merge two containers according to similar rules as `Dict`. - - `keys(::Dict)`: return all the `VarName`s present in `metadata`. - - `haskey(::Dict)`: check if a particular `VarName` is present in `metadata`. - - `getindex(::Dict, ::VarName)`: return the realization corresponding to a particular `VarName`. - - `setindex!(::Dict, val, ::VarName)`: set the realization corresponding to a particular `VarName`. - - `push!(::Dict, ::Pair)`: add a new key-value pair to the container. - - `delete!(::Dict, ::VarName)`: delete the realization corresponding to a particular `VarName`. - - `empty!(::Dict)`: delete all realizations in `metadata`. - - `merge(::Dict, ::Dict)`: merge two `metadata` structures according to similar rules as `Dict`. +Note that we only define the BangBang methods such as `setindex!!`, rather than the mutating ones likes `setindex!`. +This is due to the design of `VarNamedTuple`, which is explained on its own page in these docs. -*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`. One can access a vectorised version of a variable's value with the following vector-like functions: +*But* for general-purpose samplers, we often want to work with a simple flattened structure, typically a `Vector{<:Real}`. +One can access a vectorised version of a variable's value with the following vector-like functions: - `getindex_internal(::VarInfo, ::VarName)`: get the flattened value of a single variable. - `getindex_internal(::VarInfo, ::Colon)`: get the flattened values of all variables. 
  - `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values
-  - `setindex_internal!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable.
-  - `setindex_internal!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values
+  - `setindex_internal!!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable.
+  - `setindex_internal!!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values
  - `length_internal(::VarInfo)`: return the length of the flat representation of `metadata`.
 
 The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised.
 
-Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space. `getindex_internal` and `setindex_internal!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able sample in unconstrained space. One can also manually set a transformation by giving `setindex_internal!` a fourth, optional argument, that is a function that maps internally stored value to the actual value of the variable.
+Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space.
+`getindex_internal` and `setindex_internal!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able to sample in unconstrained space.
+One can also manually set a transformation by giving `setindex_internal!!` a fourth, optional argument, that is a function that maps internally stored value to the actual value of the variable.
 
-Finally, we want want the underlying representation used in `metadata` to have a few performance-related properties:
+Finally, we want the underlying storage to have a few performance-related properties:
 
  1. Type-stable when possible, but functional when not.
  2. Efficient storage and iteration when possible, but functional when not.
 
 The "but functional when not" is important as we want to support arbitrary models, which means that we can't always have these performance properties.
 
-In the following sections, we'll outline how we achieve this in [`VarInfo`](@ref).
-
-## Type-stability
-
-Ensuring type-stability is somewhat non-trivial to address since we want this to be the case even when models mix continuous (typically `Float64`) and discrete (typically `Int`) variables.
-
-Suppose we have an implementation of `metadata` which implements the functionality outlined in the previous section. The way we approach this in `VarInfo` is to use a `NamedTuple` with a separate `metadata` *for each distinct `Symbol` used*. For example, if we have a model of the form
-
-```@example varinfo-design
-using DynamicPPL, Distributions, FillArrays
-
-@model function demo()
-    x ~ product_distribution(Fill(Bernoulli(0.5), 2))
-    y ~ Normal(0, 1)
-    return nothing
-end
-```
-
-then we construct a type-stable representation by using a `NamedTuple{(:x, :y), Tuple{Vx, Vy}}` where
-
-  - `Vx` is a container with `eltype` `Bool`, and
-  - `Vy` is a container with `eltype` `Float64`.
-
-Since `VarName` contains the `Symbol` used in its type, something like `getindex(varinfo, @varname(x))` can be resolved to `getindex(varinfo.metadata.x, @varname(x))` at compile-time.
- -For example, with the model above we have - -```@example varinfo-design -# Type-unstable `VarInfo` -varinfo_untyped = DynamicPPL.untyped_varinfo(demo()) -typeof(varinfo_untyped.metadata) -``` - -```@example varinfo-design -# Type-stable `VarInfo` -varinfo_typed = DynamicPPL.typed_varinfo(demo()) -typeof(varinfo_typed.metadata) -``` - -They both work as expected but one results in concrete typing and the other does not: - -```@example varinfo-design -varinfo_untyped[@varname(x)], varinfo_untyped[@varname(y)] -``` - -```@example varinfo-design -varinfo_typed[@varname(x)], varinfo_typed[@varname(y)] -``` - -Notice that the untyped `VarInfo` uses `Vector{Real}` to store the boolean entries while the typed uses `Vector{Bool}`. This is because the untyped version needs the underlying container to be able to handle both the `Bool` for `x` and the `Float64` for `y`, while the typed version can use a `Vector{Bool}` for `x` and a `Vector{Float64}` for `y` due to its usage of `NamedTuple`. - -!!! warning - - Of course, this `NamedTuple` approach is *not* necessarily going to help us in scenarios where the `Symbol` does not correspond to a unique type, e.g. - - ```julia - x[1] ~ Bernoulli(0.5) - x[2] ~ Normal(0, 1) - ``` - - In this case we'll end up with a `NamedTuple((:x,), Tuple{Vx})` where `Vx` is a container with `eltype` `Union{Bool, Float64}` or something worse. This is *not* type-stable but will still be functional. - - In practice, we rarely observe such mixing of types, therefore in DynamicPPL, and more widely in Turing.jl, we use a `NamedTuple` approach for type-stability with great success. - -!!! warning - - Another downside with such a `NamedTuple` approach is that if we have a model with lots of tilde-statements, e.g. `a ~ Normal()`, `b ~ Normal()`, ..., `z ~ Normal()` will result in a `NamedTuple` with 27 entries, potentially leading to long compilation times. - - For these scenarios it can be useful to fall back to "untyped" representations. 
- -Hence we obtain a "type-stable when possible"-representation by wrapping it in a `NamedTuple` and partially resolving the `getindex`, `setindex!`, etc. methods at compile-time. When type-stability is *not* desired, we can simply use a single `metadata` for all `VarName`s instead of a `NamedTuple` wrapping a collection of `metadata`s. - -## Efficient storage and iteration - -Efficient storage and iteration we achieve through implementation of the `metadata`. In particular, we do so with [`DynamicPPL.VarNamedVector`](@ref): - -```@docs -DynamicPPL.VarNamedVector -``` - -In a [`DynamicPPL.VarNamedVector{<:VarName,T}`](@ref), we achieve the desiderata by storing the values for different `VarName`s contiguously in a `Vector{T}` and keeping track of which ranges correspond to which `VarName`s. - -This does require a bit of book-keeping, in particular when it comes to insertions and deletions. Internally, this is handled by assigning each `VarName` a unique `Int` index in the `varname_to_index` field, which is then used to index into the following fields: - - - `varnames::Vector{<:VarName}`: the `VarName`s in the order they appear in the `Vector{T}`. - - `ranges::Vector{UnitRange{Int}}`: the ranges of indices in the `Vector{T}` that correspond to each `VarName`. - - `transforms::Vector`: the transforms associated with each `VarName`. - -Mutating functions, e.g. `setindex_internal!(vnv::VarNamedVector, val, vn::VarName)`, are then treated according to the following rules: - - 1. If `vn` is not already present: add it to the end of `vnv.varnames`, add the `val` to the underlying `vnv.vals`, etc. - - 2. If `vn` is already present in `vnv`: - - 1. If `val` has the *same length* as the existing value for `vn`: replace existing value. - 2. If `val` has a *smaller length* than the existing value for `vn`: replace existing value and mark the remaining indices as "inactive" by increasing the entry in `vnv.num_inactive` field. - 3. 
If `val` has a *larger length* than the existing value for `vn`: expand the underlying `vnv.vals` to accommodate the new value, update all `VarName`s occuring after `vn`, and update the `vnv.ranges` to point to the new range for `vn`. - -This means that `VarNamedVector` is allowed to grow as needed, while "shrinking" (i.e. insertion of smaller elements) is handled by simply marking the redundant indices as "inactive". This turns out to be efficient for use-cases that we are generally interested in. - -For example, we want to optimize code-paths which effectively boil down to inner-loop in the following example: - -```julia -# Construct a `VarInfo` with types inferred from `model`. -varinfo = VarInfo(model) - -# Repeatedly sample from `model`. -for _ in 1:num_samples - rand!(rng, model, varinfo) - - # Do something with `varinfo`. - # ... -end -``` - -There are typically a few scenarios where we encounter changing representation sizes of a random variable `x`: - - 1. We're working with a transformed version `x` which is represented in a lower-dimensional space, e.g. transforming a `x ~ LKJ(2, 1)` to unconstrained `y = f(x)` takes us from 2-by-2 `Matrix{Float64}` to a 1-length `Vector{Float64}`. - 2. `x` has a random size, e.g. in a mixture model with a prior on the number of components. Here the size of `x` can vary widly between every realization of the `Model`. - -In scenario (1), we're usually *shrinking* the representation of `x`, and so we end up not making any allocations for the underlying `Vector{T}` but instead just marking the redundant part as "inactive". - -In scenario (2), we end up increasing the allocated memory for the randomly sized `x`, eventually leading to a vector that is large enough to hold realizations without needing to reallocate. But this can still lead to unnecessary memory usage, which might be undesirable. Hence one has to make a decision regarding the trade-off between memory usage and performance for the use-case at hand. 
- -To help with this, we have the following functions: - -```@docs -DynamicPPL.has_inactive -DynamicPPL.num_inactive -DynamicPPL.num_allocated -DynamicPPL.is_contiguous -DynamicPPL.contiguify! -``` - -For example, one might encounter the following scenario: - -```@example varinfo-design -vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) -println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") - -for i in 1:5 - x = fill(true, rand(1:100)) - DynamicPPL.update!(vnv, x, @varname(x)) - println( - "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", - ) -end -``` - -We can then insert a call to [`DynamicPPL.contiguify!`](@ref) after every insertion whenever the allocation grows too large to reduce overall memory usage: - -```@example varinfo-design -vnv = DynamicPPL.VarNamedVector(@varname(x) => [true]) -println("Before insertion: number of allocated entries $(DynamicPPL.num_allocated(vnv))") - -for i in 1:5 - x = fill(true, rand(1:100)) - DynamicPPL.update!(vnv, x, @varname(x)) - if DynamicPPL.num_allocated(vnv) > 10 - DynamicPPL.contiguify!(vnv) - end - println( - "After insertion #$(i) of length $(length(x)): number of allocated entries $(DynamicPPL.num_allocated(vnv))", - ) -end -``` - -This does incur a runtime cost as it requires re-allocation of the `ranges` in addition to a `resize!` of the underlying `Vector{T}`. However, this also ensures that the the underlying `Vector{T}` is contiguous, which is important for performance. Hence, if we're about to do a lot of work with the `VarNamedVector` without insertions, etc., it can be worth it to do a sweep to ensure that the underlying `Vector{T}` is contiguous. - -!!! note - - Higher-dimensional arrays, e.g. `Matrix`, are handled by simply vectorizing them before storing them in the `Vector{T}`, and composing the `VarName`'s transformation with a `DynamicPPL.ReshapeTransform`. 
- -Continuing from the example from the previous section, we can use a `VarInfo` with a `VarNamedVector` as the `metadata` field: - -```@example varinfo-design -# Type-unstable -varinfo_untyped_vnv = DynamicPPL.untyped_vector_varinfo(varinfo_untyped) -varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)] -``` - -```@example varinfo-design -# Type-stable -varinfo_typed_vnv = DynamicPPL.typed_vector_varinfo(varinfo_typed) -varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)] -``` - -If we now try to `delete!` `@varname(x)` - -```@example varinfo-design -haskey(varinfo_untyped_vnv, @varname(x)) -``` - -```@example varinfo-design -DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) -``` - -```@example varinfo-design -# `delete!` -DynamicPPL.delete!(varinfo_untyped_vnv.metadata, @varname(x)) -DynamicPPL.has_inactive(varinfo_untyped_vnv.metadata) -``` - -```@example varinfo-design -haskey(varinfo_untyped_vnv, @varname(x)) -``` - -Or insert a differently-sized value for `@varname(x)` - -```@example varinfo-design -DynamicPPL.insert!(varinfo_untyped_vnv.metadata, fill(true, 1), @varname(x)) -varinfo_untyped_vnv[@varname(x)] -``` - -```@example varinfo-design -DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) -``` - -```@example varinfo-design -DynamicPPL.update!(varinfo_untyped_vnv.metadata, fill(true, 4), @varname(x)) -varinfo_untyped_vnv[@varname(x)] -``` - -```@example varinfo-design -DynamicPPL.num_allocated(varinfo_untyped_vnv.metadata, @varname(x)) -``` - -### Performance summary - -In the end, we have the following "rough" performance characteristics for `VarNamedVector`: - -| Method | Is blazingly fast? 
| -|:----------------------------------------:|:--------------------------------------------------------------------------------------------:| -| `getindex` | ${\color{green} \checkmark}$ | -| `setindex!` on a new `VarName` | ${\color{green} \checkmark}$ | -| `delete!` | ${\color{red} \times}$ | -| `update!` on existing `VarName` | ${\color{green} \checkmark}$ if smaller or same size / ${\color{red} \times}$ if larger size | -| `values_as(::VarNamedVector, Vector{T})` | ${\color{green} \checkmark}$ if contiguous / ${\color{orange} \div}$ otherwise | - -## Other methods - -```@docs -DynamicPPL.replace_raw_storage(::DynamicPPL.VarNamedVector, vals::AbstractVector) -``` - -```@docs; canonical=false -DynamicPPL.values_as(::DynamicPPL.VarNamedVector) -``` +To understand how these are achieved, we refer the reader to the documentation on `VarNamedTuple`, which underpins `VarInfo`. diff --git a/ext/DynamicPPLChainRulesCoreExt.jl b/ext/DynamicPPLChainRulesCoreExt.jl index 12b816c60..37c9444b3 100644 --- a/ext/DynamicPPLChainRulesCoreExt.jl +++ b/ext/DynamicPPLChainRulesCoreExt.jl @@ -16,6 +16,4 @@ ChainRulesCore.@non_differentiable BangBang.push!!( # No need + causes issues for some AD backends, e.g. Zygote. ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x) -ChainRulesCore.@non_differentiable DynamicPPL.recontiguify_ranges!(ranges) - end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b84c076be..b5a77be03 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -46,7 +46,6 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, - SimpleVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, @@ -178,7 +177,7 @@ Abstract supertype for data structures that capture random variables when execut probabilistic model and accumulate log densities such as the log likelihood or the log joint probability of the model. -See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref). 
+See also: [`VarInfo`](@ref) """ abstract type AbstractVarInfo <: AbstractModelTrace end @@ -196,14 +195,11 @@ include("model.jl") include("varname.jl") include("distribution_wrappers.jl") include("submodel.jl") -include("varnamedvector.jl") include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") -# include("varinfo.jl") include("vntvarinfo.jl") -include("simple_varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 0c15cb9c7..1c5159626 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -502,64 +502,12 @@ If no `Type` is provided, return values as stored in `varinfo`. # Examples -`SimpleVarInfo` with `NamedTuple`: - -```jldoctest -julia> data = (x = 1.0, m = [2.0]); - -julia> values_as(SimpleVarInfo(data)) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`SimpleVarInfo` with `OrderedDict`: - -```jldoctest -julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]); - -julia> values_as(SimpleVarInfo(data)) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), NamedTuple) -(x = 1.0, m = [2.0]) - -julia> values_as(SimpleVarInfo(data), OrderedDict) -OrderedDict{Any, Any} with 2 entries: - x => 1.0 - m => [2.0] - -julia> values_as(SimpleVarInfo(data), Vector) -2-element Vector{Float64}: - 1.0 - 2.0 -``` - -`VarInfo` with `NamedTuple` of `Metadata`: - ```jldoctest julia> # Just use an example model to construct the `VarInfo` because we're lazy. 
- vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); + vi = DynamicPPL.VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; -julia> # For the sake of brevity, let's just check the type. - md = values_as(vi); md.s isa Union{DynamicPPL.Metadata, DynamicPPL.VarNamedVector} -true - julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) @@ -573,32 +521,6 @@ julia> values_as(vi, Vector) 1.0 2.0 ``` - -`VarInfo` with `Metadata`: - -```jldoctest -julia> # Just use an example model to construct the `VarInfo` because we're lazy. - vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); - -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; - -julia> # For the sake of brevity, let's just check the type. - values_as(vi) isa Union{DynamicPPL.Metadata, Vector} -true - -julia> values_as(vi, NamedTuple) -(s = 1.0, m = 2.0) - -julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: - s => 1.0 - m => 2.0 - -julia> values_as(vi, Vector) -2-element Vector{Real}: - 1.0 - 2.0 -``` """ function values_as end @@ -625,13 +547,6 @@ function Base.eltype(vi::AbstractVarInfo) return eltype(T) end -""" - has_varnamedvector(varinfo::VarInfo) - -Returns `true` if `varinfo` uses `VarNamedVector` as metadata. -""" -has_varnamedvector(vi::AbstractVarInfo) = false - # TODO: Should relax constraints on `vns` to be `AbstractVector{<:Any}` and just try to convert # the `eltype` to `VarName`? This might be useful when someone does `[@varname(x[1]), @varname(m)]` which # might result in a `Vector{Any}`. 
@@ -828,8 +743,6 @@ function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation model = setleafcontext(model, DynamicTransformationContext{false}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, t) @@ -897,8 +810,6 @@ function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - # Note that in practice this method is only called for SimpleVarInfo, because VarInfo - # has a dedicated implementation model = setleafcontext(model, DynamicTransformationContext{true}()) vi = last(evaluate!!(model, vi)) return set_transformed!!(vi, NoTransformation()) @@ -983,12 +894,12 @@ julia> # Change the `default_transformation` for our model to be a julia> model = demo(); -julia> vi = SimpleVarInfo(x=1.0) -SimpleVarInfo((x = 1.0,), 0.0) +julia> vi = setindex!!(VarInfo(), 1.0, @varname(x)); + +julia> vi[@varname(x)] +1.0 -julia> # Uses the `inverse` of `MyBijector`, which we have defined as `identity` - vi_linked = link!!(vi, model) -Transformed SimpleVarInfo((x = 1.0,), 0.0) +julia> vi_linked = link!!(vi, model); julia> # Now performs a single `invlink!!` before model evaluation. logjoint(model, vi_linked) diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index c2eee2863..0914d7a79 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -7,7 +7,7 @@ constrained space if `isinverse` or unconstrained if `!isinverse`. Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the `DynamicTransformationContext` methods with more efficient implementations. 
`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know -how to do the transformation, used by e.g. `SimpleVarInfo`. +how to do the transformation. """ struct DynamicTransformationContext{isinverse} <: AbstractContext end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 44fdad5a8..4f8ac4933 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -13,8 +13,6 @@ using DynamicPPL: OnlyAccsVarInfo, RangeAndLinked, VectorWithRanges, - # Metadata, - VarNamedVector, default_accumulators, float_type_with_fallback, getlogjoint, @@ -296,11 +294,6 @@ tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtyp # Helper functions to extract ranges and link status # ###################################################### -# This fails for SimpleVarInfo, but honestly there is no reason to support that here. The -# fact is that evaluation doesn't use a VarInfo, it only uses it once to generate the ranges -# and link status. So there is no motivation to use SimpleVarInfo inside a -# LogDensityFunction any more, we can just always use typed VarInfo. In fact one could argue -# that there is no purpose in supporting untyped VarInfo either. 
""" get_ranges_and_linked(varinfo::VarInfo) @@ -329,46 +322,3 @@ function get_ranges_and_linked(vi::VNTVarInfo) ) return vnt end - -# function get_ranges_and_linked(varinfo::VarInfo{<:NamedTuple{syms}}) where {syms} -# all_ranges = VarNamedTuple() -# offset = 1 -# for sym in syms -# md = varinfo.metadata[sym] -# this_md_others, offset = get_ranges_and_linked_metadata(md, offset) -# all_ranges = merge(all_ranges, this_md_others) -# end -# return all_ranges -# end -# function get_ranges_and_linked(varinfo::VarInfo{<:Union{Metadata,VarNamedVector}}) -# all_ranges, _ = get_ranges_and_linked_metadata(varinfo.metadata, 1) -# return all_ranges -# end -# function get_ranges_and_linked_metadata(md::Metadata, start_offset::Int) -# all_ranges = VarNamedTuple() -# offset = start_offset -# for (vn, idx) in md.idcs -# is_linked = md.is_transformed[idx] -# range = md.ranges[idx] .+ (start_offset - 1) -# orig_size = varnamesize(vn) -# all_ranges = BangBang.setindex!!( -# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn -# ) -# offset += length(range) -# end -# return all_ranges, offset -# end -# function get_ranges_and_linked_metadata(vnv::VarNamedVector, start_offset::Int) -# all_ranges = VarNamedTuple() -# offset = start_offset -# for (vn, idx) in vnv.varname_to_index -# is_linked = vnv.is_unconstrained[idx] -# range = vnv.ranges[idx] .+ (start_offset - 1) -# orig_size = varnamesize(vn) -# all_ranges = BangBang.setindex!!( -# all_ranges, RangeAndLinked(range, is_linked, orig_size), vn -# ) -# offset += length(range) -# end -# return all_ranges, offset -# end diff --git a/src/model.jl b/src/model.jl index 91558ecdc..cd36ee44b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1151,6 +1151,117 @@ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) return getloglikelihood(last(evaluate!!(model, varinfo))) end +""" + logjoint(model::Model, values::Union{NamedTuple,AbstractDict}) + +Return the log joint probability of variables `values` for the 
probabilistic `model`. + +See [`logprior`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logjoint(demo([1.0]), (m = 100.0, )) +-9902.33787706641 + +julia> # Using a `OrderedDict`. + logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-9902.33787706641 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) +-9902.33787706641 +``` +""" +function logjoint(model::Model, values::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator())) + vi = OnlyAccsVarInfo(accs) + _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing)) + return getlogjoint(vi) +end + +""" + logprior(model::Model, values::Union{NamedTuple,AbstractDict}) + +Return the log prior probability of variables `values` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`loglikelihood`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + logprior(demo([1.0]), (m = 100.0, )) +-5000.918938533205 + +julia> # Using a `OrderedDict`. + logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-5000.918938533205 + +julia> # Truth. 
+ logpdf(Normal(), 100.0) +-5000.918938533205 +``` +""" +function logprior(model::Model, values::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogPriorAccumulator(),)) + vi = OnlyAccsVarInfo(accs) + _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing)) + return getlogprior(vi) +end + +""" + loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict}) + +Return the log likelihood of variables `values` for the probabilistic `model`. + +See also [`logjoint`](@ref) and [`logprior`](@ref). + +# Examples +```jldoctest; setup=:(using Distributions) +julia> @model function demo(x) + m ~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end + end +demo (generic function with 2 methods) + +julia> # Using a `NamedTuple`. + loglikelihood(demo([1.0]), (m = 100.0, )) +-4901.418938533205 + +julia> # Using a `OrderedDict`. + loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) +-4901.418938533205 + +julia> # Truth. + logpdf(Normal(100.0, 1.0), 1.0) +-4901.418938533205 +``` +""" +function Distributions.loglikelihood(model::Model, values::Union{NamedTuple,AbstractDict}) + accs = AccumulatorTuple((LogLikelihoodAccumulator(),)) + vi = OnlyAccsVarInfo(accs) + _, vi = DynamicPPL.init!!(model, vi, InitFromParams(values, nothing)) + return getloglikelihood(vi) +end + # Implemented & documented in DynamicPPLMCMCChainsExt function predict end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl deleted file mode 100644 index 4add65d6d..000000000 --- a/src/simple_varinfo.jl +++ /dev/null @@ -1,647 +0,0 @@ -""" - $(TYPEDEF) - -A simple wrapper of the parameters with a `logp` field for -accumulation of the logdensity. - -Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`. - -# Fields -$(FIELDS) - -# Notes -The major differences between this and `NTVarInfo` are: -1. `SimpleVarInfo` does not require linearization. -2. `SimpleVarInfo` can use more efficient bijectors. -3. 
`SimpleVarInfo` is only type-stable if `NT<:NamedTuple` and either - a) no indexing is used in tilde-statements, or - b) the values have been specified with the correct shapes. - -# Examples -## General usage -```jldoctest simplevarinfo-general; setup=:(using Distributions) -julia> using StableRNGs - -julia> @model function demo() - m ~ Normal() - x = Vector{Float64}(undef, 2) - for i in eachindex(x) - x[i] ~ Normal() - end - return x - end -demo (generic function with 2 methods) - -julia> m = demo(); - -julia> rng = StableRNG(42); - -julia> # In the `NamedTuple` version we need to provide the place-holder values for - # the variables which are using "containers", e.g. `Array`. - # In this case, this means that we need to specify `x` but not `m`. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo((x = ones(2), ))); - -julia> # (✓) Vroom, vroom! FAST!!! - vi[@varname(x[1])] -0.4471218424633827 - -julia> # We can also access arbitrary varnames pointing to `x`, e.g. - vi[@varname(x)] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> vi[@varname(x[1:2])] -2-element Vector{Float64}: - 0.4471218424633827 - 1.3736306979834252 - -julia> # (×) If we don't provide the container... - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); -ERROR: FieldError: type NamedTuple has no field `x`, available fields: `m` -[...] - -julia> # If one does not know the varnames, we can use a `OrderedDict` instead. - _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo{Float64}(OrderedDict{VarName,Any}())); - -julia> # (✓) Sort of fast, but only possible at runtime. - vi[@varname(x[1])] --1.019202452456547 - -julia> # In addtion, we can only access varnames as they appear in the model! - vi[@varname(x)] -ERROR: x was not found in the dictionary provided -[...] - -julia> vi[@varname(x[1:2])] -ERROR: x[1:2] was not found in the dictionary provided -[...] 
-``` - -_Technically_, it's possible to use any implementation of `AbstractDict` in place of -`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening -of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is -the preferred implementation of `AbstractDict` to use here. - -You can also sample in _transformed_ space: - -```jldoctest simplevarinfo-general -julia> @model demo_constrained() = x ~ Exponential() -demo_constrained (generic function with 2 methods) - -julia> m = demo_constrained(); - -julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); - -julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ -1.8632965762164932 - -julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ --0.21080155351918753 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true - -julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); - -julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.6225185067787314 - -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; - -julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! -true -``` - -Evaluation in transformed space of course also works: - -```jldoctest simplevarinfo-general -julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) Positive probability mass on negative numbers! 
- getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --1.3678794411714423 - -julia> # While if we forget to indicate that it's transformed: - vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) - -julia> # (✓) No probability mass on negative numbers! - getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) --Inf -``` - -## Indexing -Using `NamedTuple` as underlying storage. - -```jldoctest -julia> svi_nt = SimpleVarInfo((m = (a = [1.0], ), )); - -julia> svi_nt[@varname(m)] -(a = [1.0],) - -julia> svi_nt[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_nt[@varname(m.a[1])] -1.0 - -julia> svi_nt[@varname(m.a[2])] -ERROR: BoundsError: attempt to access 1-element Vector{Float64} at index [2] -[...] - -julia> svi_nt[@varname(m.b)] -ERROR: FieldError: type NamedTuple has no field `b`, available fields: `a` -[...] -``` - -Using `OrderedDict` as underlying storage. -```jldoctest -julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], ))); - -julia> svi_dict[@varname(m)] -(a = [1.0],) - -julia> svi_dict[@varname(m.a)] -1-element Vector{Float64}: - 1.0 - -julia> svi_dict[@varname(m.a[1])] -1.0 - -julia> svi_dict[@varname(m.a[2])] -ERROR: m.a[2] was not found in the dictionary provided -[...] - -julia> svi_dict[@varname(m.b)] -ERROR: m.b was not found in the dictionary provided -[...] 
-``` -""" -struct SimpleVarInfo{NT,Accs<:AccumulatorTuple where {N},C<:AbstractTransformation} <: - AbstractVarInfo - "underlying representation of the realization represented" - values::NT - "tuple of accumulators for things like log prior and log likelihood" - accs::Accs - "represents whether it assumes variables to be transformed" - transformation::C -end - -function Base.:(==)(vi1::SimpleVarInfo, vi2::SimpleVarInfo) - return vi1.values == vi2.values && - vi1.accs == vi2.accs && - vi1.transformation == vi2.transformation -end - -transformation(vi::SimpleVarInfo) = vi.transformation - -function SimpleVarInfo(values, accs) - return SimpleVarInfo(values, accs, NoTransformation()) -end -function SimpleVarInfo{T}(values) where {T<:Real} - return SimpleVarInfo(values, default_accumulators(T)) -end -function SimpleVarInfo(values) - return SimpleVarInfo{LogProbType}(values) -end -function SimpleVarInfo(values::Union{<:NamedTuple,<:AbstractDict{<:VarName}}) - return if isempty(values) - # Can't infer from values, so we just use default. - SimpleVarInfo{LogProbType}(values) - else - # Infer from `values`. - SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(values)))}(values) - end -end - -# Using `kwargs` to specify the values. -function SimpleVarInfo{T}(; kwargs...) where {T<:Real} - return SimpleVarInfo{T}(NamedTuple(kwargs)) -end -function SimpleVarInfo(; kwargs...) - return SimpleVarInfo(NamedTuple(kwargs)) -end - -# Constructor from `Model`. 
-function SimpleVarInfo{T}( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) where {T<:Real} - return last(init!!(rng, model, SimpleVarInfo{T}(), init_strategy)) -end -function SimpleVarInfo{T}( - model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() -) where {T<:Real} - return SimpleVarInfo{T}(Random.default_rng(), model, init_strategy) -end -# Constructors without type param -function SimpleVarInfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return SimpleVarInfo{LogProbType}(rng, model, init_strategy) -end -function SimpleVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return SimpleVarInfo{LogProbType}(Random.default_rng(), model, init_strategy) -end - -# Constructor from `VarInfo`. -# function SimpleVarInfo(vi::NTVarInfo, ::Type{D}) where {D} -# values = values_as(vi, D) -# return SimpleVarInfo(values, copy(getaccs(vi))) -# end -# function SimpleVarInfo{T}(vi::NTVarInfo, ::Type{D}) where {T<:Real,D} -# values = values_as(vi, D) -# accs = map(acc -> convert_eltype(T, acc), getaccs(vi)) -# return SimpleVarInfo(values, accs) -# end - -function untyped_simple_varinfo(model::Model) - varinfo = SimpleVarInfo(OrderedDict{VarName,Any}()) - return last(init!!(model, varinfo)) -end - -function typed_simple_varinfo(model::Model) - varinfo = SimpleVarInfo{Float64}() - return last(init!!(model, varinfo)) -end - -function unflatten(svi::SimpleVarInfo, x::AbstractVector) - vals = unflatten(svi.values, x) - return SimpleVarInfo(vals, svi.accs, svi.transformation) -end - -function BangBang.empty!!(vi::SimpleVarInfo) - return resetaccs!!(Accessors.@set vi.values = empty!!(vi.values)) -end -Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) - -getaccs(vi::SimpleVarInfo) = vi.accs -setaccs!!(vi::SimpleVarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs - -""" - keys(vi::SimpleVarInfo) - -Return an iterator of keys 
present in `vi`. -""" -Base.keys(vi::SimpleVarInfo) = keys(vi.values) -Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values)) - -function Base.show(io::IO, mime::MIME"text/plain", svi::SimpleVarInfo) - if !(svi.transformation isa NoTransformation) - print(io, "Transformed ") - end - - return print(io, "SimpleVarInfo(", svi.values, ", ", repr(mime, getaccs(svi)), ")") -end - -function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return from_maybe_linked_internal(vi, vn, dist, getindex(vi, vn)) -end -function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn, dist) - end - return recombine(dist, vals_linked, length(vns)) -end - -Base.getindex(vi::SimpleVarInfo, vn::VarName) = getindex_internal(vi, vn) - -# `SimpleVarInfo` doesn't necessarily vectorize, so we can have arrays other than -# just `Vector`. -function Base.getindex(vi::SimpleVarInfo, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(getindex, vi), vns) -end -# HACK: Needed to disambiguate. -Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) - -Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) - -getindex_internal(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn) -# `AbstractDict` -function getindex_internal( - vi::SimpleVarInfo{<:Union{AbstractDict,VarNamedVector}}, vn::VarName -) - return getvalue(vi.values, vn) -end - -Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) - -function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) - # For `NamedTuple` we treat the symbol in `vn` as the _property_ to set. - return Accessors.@set vi.values = set!!(vi.values, vn, val) -end - -# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with -# same symbol and same type of, say, `IndexLens`, for improved `.~` performance. 
-function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) - for (vn, val) in zip(vns, vals) - vi = BangBang.setindex!!(vi, val, vn) - end - return vi -end - -function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName) - # For dictlike objects, we treat the entire `vn` as a _key_ to set. - dict = values_as(vi) - # Attempt to split into `parent` and `child` optic. - parent, child, issuccess = splitoptic(getoptic(vn)) do optic - o = optic === nothing ? identity : optic - haskey(dict, VarName{getsym(vn)}(o)) - end - # When combined with `VarInfo`, `nothing` is equivalent to `identity`. - keyoptic = parent === nothing ? identity : parent - - dict_new = if !issuccess - # Split doesn't exist ⟹ we're working with a new key. - BangBang.setindex!!(dict, val, vn) - else - # Split exists ⟹ trying to set an existing key. - vn_key = VarName{getsym(vn)}(keyoptic) - BangBang.setindex!!(dict, set!!(dict[vn_key], child, val), vn_key) - end - return Accessors.@set vi.values = dict_new -end - -# `NamedTuple` -function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, ::VarName{sym,typeof(identity)}, value, ::Distribution -) where {sym} - return Accessors.@set vi.values = merge(vi.values, NamedTuple{(sym,)}((value,))) -end -function BangBang.push!!( - vi::SimpleVarInfo{<:NamedTuple}, vn::VarName{sym}, value, ::Distribution -) where {sym} - return Accessors.@set vi.values = set!!(vi.values, vn, value) -end - -# `AbstractDict` -function BangBang.push!!( - vi::SimpleVarInfo{<:AbstractDict}, vn::VarName, value, ::Distribution -) - vi.values[vn] = value - return vi -end - -function BangBang.push!!( - vi::SimpleVarInfo{<:VarNamedVector}, vn::VarName, value, ::Distribution -) - # The semantics of push!! for SimpleVarInfo and VarNamedVector are different. For - # SimpleVarInfo, push!! allows the key to exist already, for VarNamedVector it does not. - # Hence we need to call update!! here, which has the same semantics as push!! 
does for - # SimpleVarInfo. - return Accessors.@set vi.values = setindex!!(vi.values, value, vn) -end - -const SimpleOrThreadSafeSimple{T,V,C} = Union{ - SimpleVarInfo{T,V,C},ThreadSafeVarInfo{<:SimpleVarInfo{T,V,C}} -} - -# Necessary for `matchingvalue` to work properly. -Base.eltype(::SimpleOrThreadSafeSimple{<:Any,V}) where {V} = V - -# `subset` -function subset(varinfo::SimpleVarInfo, vns::AbstractVector{<:VarName}) - return SimpleVarInfo( - _subset(varinfo.values, vns), map(copy, getaccs(varinfo)), varinfo.transformation - ) -end - -function _subset(x::AbstractDict, vns::AbstractVector{VN}) where {VN<:VarName} - vns_present = collect(keys(x)) - vns_found = filter( - vn_present -> any(subsumes(vn, vn_present) for vn in vns), vns_present - ) - C = ConstructionBase.constructorof(typeof(x)) - if isempty(vns_found) - return C() - else - return C(vn => x[vn] for vn in vns_found) - end -end - -function _subset(x::NamedTuple, vns) - # NOTE: Here we can only handle `vns` that contain `identity` as optic. - if any(Base.Fix1(!==, identity) ∘ getoptic, vns) - throw( - ArgumentError( - "Cannot subset `NamedTuple` with non-`identity` `VarName`. " * - "For example, `@varname(x)` is allowed, but `@varname(x[1])` is not.", - ), - ) - end - - syms = map(getsym, vns) - x_syms = filter(Base.Fix2(in, syms), keys(x)) - return NamedTuple{Tuple(x_syms)}(Tuple(map(Base.Fix1(getindex, x), x_syms))) -end - -_subset(x::VarNamedVector, vns) = subset(x, vns) - -# `merge` -function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) - values = merge(varinfo_left.values, varinfo_right.values) - accs = map(copy, getaccs(varinfo_right)) - transformation = merge_transformations( - varinfo_left.transformation, varinfo_right.transformation - ) - return SimpleVarInfo(values, accs, transformation) -end - -function set_transformed!!(vi::SimpleVarInfo, trans) - return set_transformed!!(vi, trans ? 
DynamicTransformation() : NoTransformation()) -end -function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation) - return Accessors.@set vi.transformation = transformation -end -function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans) -end -function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) - # We keep this method around just to obey the AbstractVarInfo interface. - # However, note that this would only be a valid operation if it would be a - # no-op, which we check here. - if trans != is_transformed(vi) - error( - "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", - ) - end - return vi -end - -is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) -is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi) -function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) - return is_transformed(vi.varinfo, vn) -end -is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo) - -values_as(vi::SimpleVarInfo) = vi.values -values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values -function values_as(vi::SimpleVarInfo, ::Type{Vector}) - isempty(vi) && return Any[] - return mapreduce(tovec, vcat, values(vi.values)) -end -function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values))) -end -function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple}) - return NamedTuple((Symbol(k), v) for (k, v) in vi.values) -end -function values_as(vi::SimpleVarInfo, ::Type{T}) where {T} - return values_as(vi.values, T) -end - -""" - logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log joint probability of variables `θ` for the probabilistic `model`. - -See [`logprior`](@ref) and [`loglikelihood`](@ref). 
- -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logjoint(demo([1.0]), (m = 100.0, )) --9902.33787706641 - -julia> # Using a `OrderedDict`. - logjoint(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --9902.33787706641 - -julia> # Truth. - logpdf(Normal(100.0, 1.0), 1.0) + logpdf(Normal(), 100.0) --9902.33787706641 -``` -""" -logjoint(model::Model, θ::Union{NamedTuple,AbstractDict}) = - logjoint(model, SimpleVarInfo(θ)) - -""" - logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log prior probability of variables `θ` for the probabilistic `model`. - -See also [`logjoint`](@ref) and [`loglikelihood`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - logprior(demo([1.0]), (m = 100.0, )) --5000.918938533205 - -julia> # Using a `OrderedDict`. - logprior(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --5000.918938533205 - -julia> # Truth. - logpdf(Normal(), 100.0) --5000.918938533205 -``` -""" -logprior(model::Model, θ::Union{NamedTuple,AbstractDict}) = - logprior(model, SimpleVarInfo(θ)) - -""" - loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) - -Return the log likelihood of variables `θ` for the probabilistic `model`. - -See also [`logjoint`](@ref) and [`logprior`](@ref). - -# Examples -```jldoctest; setup=:(using Distributions) -julia> @model function demo(x) - m ~ Normal() - for i in eachindex(x) - x[i] ~ Normal(m, 1.0) - end - end -demo (generic function with 2 methods) - -julia> # Using a `NamedTuple`. - loglikelihood(demo([1.0]), (m = 100.0, )) --4901.418938533205 - -julia> # Using a `OrderedDict`. 
- loglikelihood(demo([1.0]), OrderedDict(@varname(m) => 100.0)) --4901.418938533205 - -julia> # Truth. - logpdf(Normal(100.0, 1.0), 1.0) --4901.418938533205 -``` -""" -Distributions.loglikelihood(model::Model, θ::Union{NamedTuple,AbstractDict}) = - loglikelihood(model, SimpleVarInfo(θ)) - -# Allow usage of `NamedBijector` too. -function link!!( - t::StaticTransformation{<:Bijectors.NamedTransform}, - vi::SimpleVarInfo{<:NamedTuple}, - ::Model, -) - b = inverse(t.bijector) - x = vi.values - y, logjac = with_logabsdet_jacobian(b, x) - vi_new = Accessors.@set(vi.values = y) - if hasacc(vi_new, Val(:LogJacobian)) - vi_new = acclogjac!!(vi_new, logjac) - end - return set_transformed!!(vi_new, t) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.NamedTransform}, - vi::SimpleVarInfo{<:NamedTuple}, - ::Model, -) - b = t.bijector - y = vi.values - x, inv_logjac = with_logabsdet_jacobian(b, y) - vi_new = Accessors.@set(vi.values = x) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - if hasacc(vi_new, Val(:LogJacobian)) - vi_new = acclogjac!!(vi_new, inv_logjac) - end - return set_transformed!!(vi_new, NoTransformation()) -end - -# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything. -from_internal_transform(vi::SimpleVarInfo, ::VarName) = identity -from_internal_transform(vi::SimpleVarInfo, ::VarName, dist) = identity -# TODO: Should the following methods specialize on the case where we have a `StaticTransformation{<:Bijectors.NamedTransform}`? 
-from_linked_internal_transform(vi::SimpleVarInfo, ::VarName) = identity -function from_linked_internal_transform(vi::SimpleVarInfo, ::VarName, dist) - return invlink_transform(dist) -end - -has_varnamedvector(vi::SimpleVarInfo) = vi.values isa VarNamedVector diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 88200680a..d83cb289d 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -65,8 +65,6 @@ function map_accumulators!!(func::Function, vi::ThreadSafeVarInfo) return vi end -has_varnamedvector(vi::ThreadSafeVarInfo) = has_varnamedvector(vi.varinfo) - syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) diff --git a/src/utils.jl b/src/utils.jl index ba79f94b4..4a0eea96c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,9 +9,6 @@ const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. VarNameTuple = NTuple{N,VarName} where {N} -# TODO(mhauru) This is currently used in the transformation functions of NoDist, -# ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in -# SimpleVarInfo and maybe other places. """ The type for all log probability variables. @@ -506,8 +503,6 @@ end # UnivariateDistributions need to be handled as a special case, because size(dist) is (), # which makes the usual machinery think we are dealing with a 0-dim array, whereas in # actuality we are dealing with a scalar. -# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and -# VarNamedVector takes over from Metadata. 
function from_linked_vec_transform(dist::UnivariateDistribution) f_invlink = invlink_transform(dist) f_vec = from_vec_transform(inverse(f_invlink), size(dist)) diff --git a/src/varinfo.jl b/src/varinfo.jl deleted file mode 100644 index d1ea7dae3..000000000 --- a/src/varinfo.jl +++ /dev/null @@ -1,1810 +0,0 @@ -#### -#### Types for typed and untyped VarInfo -#### - -#################### -# VarInfo metadata # -#################### - -""" -The `Metadata` struct stores some metadata about the parameters of the model. This helps -query certain information about a variable, such as its distribution, which samplers -sample this variable, its value and whether this value is transformed to real space or -not. - -Let `md` be an instance of `Metadata`: -- `md.vns` is the vector of all `VarName` instances. -- `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, and `md.is_transformed`. -- `md.vns[md.idcs[vn]] == vn`. -- `md.dists[md.idcs[vn]]` is the distribution of `vn`. -- `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. -- `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. -- `md.is_transformed` is a BitVector of true/false flags for whether a variable has been - transformed. `md.is_transformed[md.idcs[vn]]` is the value corresponding to `vn`. - -To make `md::Metadata` type stable, all the `md.vns` must have the same symbol -and distribution type. However, one can have a Julia variable, say `x`, that is a -matrix or a hierarchical array sampled in partitions, e.g. -`x[1][:] ~ MvNormal(zeros(2), I); x[2][:] ~ MvNormal(ones(2), I)`, and is managed by -a single `md::Metadata` so long as all the distributions on the RHS of `~` are of the -same type. Type unstable `Metadata` will still work but will have inferior performance. 
-When sampling, the first iteration uses a type unstable `Metadata` for all the -variables then a specialized `Metadata` is used for each symbol along with a function -barrier to make the rest of the sampling type stable. -""" -struct Metadata{ - TIdcs<:Dict{<:VarName,Int}, - TDists<:AbstractVector{<:Distribution}, - TVN<:AbstractVector{<:VarName}, - TVal<:AbstractVector{<:Real}, -} - # Mapping from the `VarName` to its integer index in `vns`, `ranges` and `dists` - idcs::TIdcs # Dict{<:VarName,Int} - - # Vector of identifiers for the random variables, where `vns[idcs[vn]] == vn` - vns::TVN # AbstractVector{<:VarName} - - # Vector of index ranges in `vals` corresponding to `vns` - # Each `VarName` `vn` has a single index or a set of contiguous indices in `vals` - ranges::Vector{UnitRange{Int}} - - # Vector of values of all the univariate, multivariate and matrix variables - # The value(s) of `vn` is/are `vals[ranges[idcs[vn]]]` - vals::TVal # AbstractVector{<:Real} - - # Vector of distributions correpsonding to `vns` - dists::TDists # AbstractVector{<:Distribution} - - is_transformed::BitVector -end - -function Base.:(==)(md1::Metadata, md2::Metadata) - return ( - md1.idcs == md2.idcs && - md1.vns == md2.vns && - md1.ranges == md2.ranges && - md1.vals == md2.vals && - md1.dists == md2.dists && - md1.is_transformed == md2.is_transformed - ) -end - -########### -# VarInfo # -########### - -""" - struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo - metadata::Tmeta - accs::Accs - end - -A light wrapper over some kind of metadata. - -The type of the metadata can be one of a number of options. It may either be a -`Metadata` or a `VarNamedVector`, _or_, it may be a `NamedTuple` which maps -symbols to `Metadata` or `VarNamedVector` instances. Here, a _symbol_ refers -to a Julia variable and may consist of one or more `VarName`s which appear on -the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both -have the same symbol `x`. 
- -Several type aliases are provided for these forms of VarInfos: -- `VarInfo{<:Metadata}` is `UntypedVarInfo` -- `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` -- `VarInfo{<:NamedTuple}` is `NTVarInfo` - -The NamedTuple form, i.e. `NTVarInfo`, is useful for maintaining type stability -of model evaluation. However, the element type of NamedTuples are not contained -in its type itself: thus, there is no way to use the type system to determine -whether the elements of the NamedTuple are `Metadata` or `VarNamedVector`. - -Note that for NTVarInfo, it is the user's responsibility to ensure that each -symbol is visited at least once during model evaluation, regardless of any -stochastic branching. -""" -struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo - metadata::Tmeta - accs::Accs -end -function VarInfo(meta=Metadata()) - return VarInfo(meta, default_accumulators()) -end - -""" - VarInfo( - [rng::Random.AbstractRNG], - model, - [init_strategy::AbstractInitStrategy] - ) - -Generate a `VarInfo` object for the given `model`, by initialising it with the -given `rng` and `init_strategy`. - -!!! warning - - This function currently returns a `VarInfo` with its metadata field set to - a `NamedTuple` of `Metadata`. This is an implementation detail. In general, - this function may return any kind of object that satisfies the - `AbstractVarInfo` interface. If you require precise control over the type - of `VarInfo` returned, use the internal functions `untyped_varinfo`, - `typed_varinfo`, `untyped_vector_varinfo`, or `typed_vector_varinfo` - instead. 
-""" -function VarInfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return typed_varinfo(rng, model, init_strategy) -end -function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return VarInfo(Random.default_rng(), model, init_strategy) -end - -const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} -const UntypedVarInfo = VarInfo{<:Metadata} -# TODO: NTVarInfo carries no information about the type of the actual metadata -# i.e. the elements of the NamedTuple. It could be Metadata or it could be -# VarNamedVector. -# Resolving this ambiguity would likely require us to replace NamedTuple with -# something which carried both its keys as well as its values' types as type -# parameters. -const NTVarInfo = VarInfo{<:NamedTuple} -const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ - VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} -} - -function Base.:(==)(vi1::VarInfo, vi2::VarInfo) - return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) -end - -# NOTE: This is kind of weird, but it effectively preserves the "old" -# behavior where we're allowed to call `link!` on the same `VarInfo` -# multiple times. -transformation(::VarInfo) = DynamicTransformation() - -# No-op if we're already working with a `VarNamedVector`. 
"""
    metadata_to_varnamedvector(md)

Convert a `Metadata` instance to an equivalent `VarNamedVector`.

A `VarNamedVector` passes through unchanged. For a `Metadata`, the per-variable
transform is chosen based on whether the variable is currently transformed
(linked): linked variables get `from_linked_vec_transform(dist)`, others
`from_vec_transform(dist)`.
"""
metadata_to_varnamedvector(vnv::VarNamedVector) = vnv
function metadata_to_varnamedvector(md::Metadata)
    # Copy all fields so the resulting VarNamedVector does not alias `md`.
    idcs = copy(md.idcs)
    vns = copy(md.vns)
    ranges = copy(md.ranges)
    vals = copy(md.vals)
    is_trans = map(Base.Fix1(is_transformed, md), md.vns)
    transforms = map(md.dists, is_trans) do dist, trans
        if trans
            return from_linked_vec_transform(dist)
        else
            return from_vec_transform(dist)
        end
    end

    return VarNamedVector(
        OrderedDict{eltype(keys(idcs)),Int}(idcs), vns, ranges, vals, transforms, is_trans
    )
end

"""
    has_varnamedvector(vi::VarInfo)

Return `true` if any metadata field of `vi` is a `VarNamedVector` (either the single
metadata of an untyped `VarInfo`, or any entry of the `NamedTuple` of an `NTVarInfo`).
"""
function has_varnamedvector(vi::VarInfo)
    return vi.metadata isa VarNamedVector ||
        (vi isa NTVarInfo && any(Base.Fix2(isa, VarNamedVector), values(vi.metadata)))
end

########################
# VarInfo constructors #
########################

"""
    untyped_varinfo([rng, ]model[, init_strategy])

Construct a VarInfo object for the given `model`, which has just a single
`Metadata` as its metadata field.

# Arguments
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function untyped_varinfo(
    rng::Random.AbstractRNG,
    model::Model,
    init_strategy::AbstractInitStrategy=InitFromPrior(),
)
    return last(init!!(rng, model, VarInfo(Metadata()), init_strategy))
end
function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
    return untyped_varinfo(Random.default_rng(), model, init_strategy)
end

"""
    typed_varinfo(vi::UntypedVarInfo)

Convert an untyped `VarInfo` into a type-stable (typed) one.

This function finds all the unique `sym`s from the instances of `VarName{sym}` found in
`vi.metadata.vns`. It then extracts the metadata associated with each symbol from the
global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `metadata` as
a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each
symbol.
"""
function typed_varinfo(vi::UntypedVarInfo)
    meta = vi.metadata
    new_metas = Metadata[]
    # Symbols of all instances of `VarName{sym}` in `vi.vns`
    syms_tuple = Tuple(syms(vi))
    for s in syms_tuple
        # Find all indices in `vns` with symbol `s`
        inds = findall(vn -> getsym(vn) === s, meta.vns)
        n = length(inds)
        # New `vns`
        sym_vns = getindex.((meta.vns,), inds)
        # New idcs
        sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns))
        # New dists
        sym_dists = getindex.((meta.dists,), inds)
        # New is_transformed
        sym_is_transformed = meta.is_transformed[inds]

        # Extract new ranges and vals
        _ranges = getindex.((meta.ranges,), inds)
        # `copy.()` is a workaround to reduce the eltype from Real to Int or Float64
        _vals = [copy.(meta.vals[_ranges[i]]) for i in 1:n]
        # Recompute contiguous ranges for the per-symbol value vector.
        sym_ranges = Vector{eltype(_ranges)}(undef, n)
        start = 0
        for i in 1:n
            sym_ranges[i] = (start + 1):(start + length(_vals[i]))
            start += length(_vals[i])
        end
        sym_vals = foldl(vcat, _vals)

        push!(
            new_metas,
            Metadata(
                sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_is_transformed
            ),
        )
    end
    nt = NamedTuple{syms_tuple}(Tuple(new_metas))
    return VarInfo(nt, copy(vi.accs))
end
function typed_varinfo(vi::NTVarInfo)
    # This function preserves the behaviour of typed_varinfo(vi) where vi is
    # already a NTVarInfo
    has_varnamedvector(vi) && error(
        "Cannot convert VarInfo with NamedTuple of VarNamedVector to VarInfo with NamedTuple of Metadata",
    )
    return vi
end
"""
    typed_varinfo([rng, ]model[, init_strategy])

Return a VarInfo object for the given `model`, which has a NamedTuple of
`Metadata` structs as its metadata field.

# Arguments
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function typed_varinfo(
    rng::Random.AbstractRNG,
    model::Model,
    init_strategy::AbstractInitStrategy=InitFromPrior(),
)
    return typed_varinfo(untyped_varinfo(rng, model, init_strategy))
end
function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior())
    return typed_varinfo(Random.default_rng(), model, init_strategy)
end

"""
    untyped_vector_varinfo([rng, ]model[, init_strategy])

Return a VarInfo object for the given `model`, which has just a single
`VarNamedVector` as its metadata field.

# Arguments
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function untyped_vector_varinfo(vi::UntypedVarInfo)
    md = metadata_to_varnamedvector(vi.metadata)
    return VarInfo(md, copy(vi.accs))
end
function untyped_vector_varinfo(
    rng::Random.AbstractRNG,
    model::Model,
    init_strategy::AbstractInitStrategy=InitFromPrior(),
)
    return last(init!!(rng, model, VarInfo(VarNamedVector()), init_strategy))
end
function untyped_vector_varinfo(
    model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
)
    return untyped_vector_varinfo(Random.default_rng(), model, init_strategy)
end

"""
    typed_vector_varinfo([rng, ]model[, init_strategy])

Return a VarInfo object for the given `model`, which has a NamedTuple of
`VarNamedVector`s as its metadata field.

# Arguments
- `rng::Random.AbstractRNG`: The random number generator to use during model evaluation
- `model::Model`: The model for which to create the varinfo object
- `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`.
"""
function typed_vector_varinfo(vi::NTVarInfo)
    md = map(metadata_to_varnamedvector, vi.metadata)
    return VarInfo(md, copy(vi.accs))
end
function typed_vector_varinfo(vi::UntypedVectorVarInfo)
    new_metas = group_by_symbol(vi.metadata)
    nt = NamedTuple(new_metas)
    return VarInfo(nt, copy(vi.accs))
end
function typed_vector_varinfo(
    rng::Random.AbstractRNG,
    model::Model,
    init_strategy::AbstractInitStrategy=InitFromPrior(),
)
    return typed_vector_varinfo(untyped_vector_varinfo(rng, model, init_strategy))
end
function typed_vector_varinfo(
    model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()
)
    return typed_vector_varinfo(Random.default_rng(), model, init_strategy)
end

"""
    vector_length(varinfo::VarInfo)

Return the length of the vector representation of `varinfo`.
"""
vector_length(varinfo::VarInfo) = length(varinfo.metadata)
vector_length(varinfo::NTVarInfo) = sum(length, varinfo.metadata)
vector_length(md::Metadata) = sum(length, md.ranges)

# Replace the values in `vi` with those from the flat vector `x`, keeping all other
# metadata (varnames, ranges, dists, transformation flags) intact.
function unflatten(vi::VarInfo, x::AbstractVector)
    md = unflatten_metadata(vi.metadata, x)
    return VarInfo(md, vi.accs)
end

# We would call this `unflatten` if not for `unflatten` having a method for NamedTuples in
# utils.jl.
# Split `x` into consecutive chunks, one per metadata field, and unflatten each chunk
# into the corresponding field. Generated so that the offsets are computed per-field
# without dynamic dispatch over the NamedTuple.
@generated function unflatten_metadata(
    metadata::NamedTuple{names}, x::AbstractVector
) where {names}
    exprs = []
    offset = :(0)
    for f in names
        mdf = :(metadata.$f)
        len = :(sum(length, $mdf.ranges))
        push!(exprs, :($f = unflatten_metadata($mdf, x[($offset + 1):($offset + $len)])))
        offset = :($offset + $len)
    end
    length(exprs) == 0 && return :(NamedTuple())
    return :($(exprs...),)
end

# For a plain Metadata only the values vector is replaced; ranges stay valid because
# `x` must have the same layout as the old `vals`.
function unflatten_metadata(md::Metadata, x::AbstractVector)
    return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.is_transformed)
end

unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)

####
#### Internal functions
####

"""
    Metadata()

Construct an empty type unstable instance of `Metadata`.
"""
function Metadata()
    vals = Vector{Real}()
    is_transformed = BitVector()

    return Metadata(
        Dict{VarName,Int}(),
        Vector{VarName}(),
        Vector{UnitRange{Int}}(),
        vals,
        Vector{Distribution}(),
        is_transformed,
    )
end

"""
    empty!(meta::Metadata)

Empty the fields of `meta`.

This is useful when using a sampling algorithm that assumes an empty `meta`, e.g. `SMC`.
"""
function empty!(meta::Metadata)
    empty!(meta.idcs)
    empty!(meta.vns)
    empty!(meta.ranges)
    empty!(meta.vals)
    empty!(meta.dists)
    empty!(meta.is_transformed)
    return meta
end

# Removes the first element of a NamedTuple. The pairs in a NamedTuple are ordered, so
# this is well-defined.
if VERSION < v"1.1"
    _tail(nt::NamedTuple{names}) where {names} = NamedTuple{Base.tail(names)}(nt)
else
    _tail(nt::NamedTuple) = Base.tail(nt)
end

# Restrict `varinfo` to the variables subsumed by `vns`, copying the accumulators.
function subset(varinfo::VarInfo, vns::AbstractVector{<:VarName})
    metadata = subset(varinfo.metadata, vns)
    return VarInfo(metadata, map(copy, getaccs(varinfo)))
end

function subset(metadata::NamedTuple, vns::AbstractVector{<:VarName})
    vns_syms = Set(unique(map(getsym, vns)))
    syms = filter(Base.Fix2(in, vns_syms), keys(metadata))
    metadatas = map(syms) do sym
        subset(getfield(metadata, sym), filter(==(sym) ∘ getsym, vns))
    end
    return NamedTuple{syms}(metadatas)
end

# The above method is type unstable since we don't know which symbols are in `vns`.
# In the below special case, when all `vns` have the same symbol, we can write a type stable
# version.

@generated function subset(
    metadata::NamedTuple{names}, vns::AbstractVector{<:VarName{sym}}
) where {names,sym}
    return if (sym in names)
        # TODO(mhauru) Note that this could still generate an empty metadata object if none
        # of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
        # emptiness would make this type unstable again.
        :((; $sym=subset(metadata.$sym, vns)))
    else
        :(NamedTuple{}())
    end
end

function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:VarName}
    # TODO: Should we error if `vns` contains a variable that is not in `metadata`?
    # Find all the vns in metadata that are subsumed by one of the given vns.
    vns = filter(vn -> any(subsumes(vn_given, vn) for vn_given in vns_given), metadata.vns)
    indices_for_vns = map(Base.Fix1(getindex, metadata.idcs), vns)
    indices = if isempty(vns)
        Dict{VarName,Int}()
    else
        Dict(vn => i for (i, vn) in enumerate(vns))
    end
    # Construct new `vals` and `ranges`.
    vals_original = metadata.vals
    ranges_original = metadata.ranges
    # Allocate the new `vals` and `ranges`.
    vals = similar(metadata.vals, sum(length, ranges_original[indices_for_vns]; init=0))
    ranges = similar(ranges_original, length(vns))
    # The new range `r` for `vns[i]` is offset by `offset` and
    # has the same length as the original range `r_original`.
    # The new `indices` (from above) ensures ordering according to `vns`.
    # NOTE: This means that the order of the variables in `vns` defines the order
    # in the resulting `varinfo`! This can have performance implications, e.g.
    # if in the model we have something like
    #
    #     for i = 1:N
    #         x[i] ~ Normal()
    #     end
    #
    # and we then we do
    #
    #     subset(varinfo, [@varname(x[i]) for i in shuffle(keys(varinfo))])
    #
    # the resulting `varinfo` will have `vals` ordered differently from the
    # original `varinfo`, which can have performance implications.
    offset = 0
    for (idx, idx_original) in enumerate(indices_for_vns)
        r_original = ranges_original[idx_original]
        r = (offset + 1):(offset + length(r_original))
        vals[r] = vals_original[r_original]
        ranges[idx] = r
        offset = r[end]
    end

    dists = metadata.dists[indices_for_vns]
    is_transformed = metadata.is_transformed[indices_for_vns]
    return Metadata(indices, vns, ranges, vals, dists, is_transformed)
end

function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
    return _merge(varinfo_left, varinfo_right)
end

# NOTE: on clashes the right-hand varinfo wins, both for variable values (see
# `merge_metadata`) and for the accumulators, which are taken from `varinfo_right`.
function _merge(varinfo_left::VarInfo, varinfo_right::VarInfo)
    metadata = merge_metadata(varinfo_left.metadata, varinfo_right.metadata)
    accs = map(copy, getaccs(varinfo_right))
    return VarInfo(metadata, accs)
end

function merge_metadata(vnv_left::VarNamedVector, vnv_right::VarNamedVector)
    return merge(vnv_left, vnv_right)
end

@generated function merge_metadata(
    metadata_left::NamedTuple{names_left}, metadata_right::NamedTuple{names_right}
) where {names_left,names_right}
    names = Expr(:tuple)
    vals = Expr(:tuple)
    # Loop over `names_left` first because we want to preserve the order of the variables.
    for sym in names_left
        push!(names.args, QuoteNode(sym))
        if sym in names_right
            push!(vals.args, :(merge_metadata(metadata_left.$sym, metadata_right.$sym)))
        else
            push!(vals.args, :(metadata_left.$sym))
        end
    end
    # Loop over remaining variables in `names_right`.
    names_right_only = filter(∉(names_left), names_right)
    for sym in names_right_only
        push!(names.args, QuoteNode(sym))
        push!(vals.args, :(metadata_right.$sym))
    end

    return :(NamedTuple{$names}($vals))
end

function merge_metadata(metadata_left::Metadata, metadata_right::Metadata)
    # Extract the varnames.
    vns_left = metadata_left.vns
    vns_right = metadata_right.vns
    vns_both = union(vns_left, vns_right)

    # Determine `eltype` of `vals`.
    T_left = eltype(metadata_left.vals)
    T_right = eltype(metadata_right.vals)
    T = promote_type(T_left, T_right)
    # TODO: Is this necessary?
    if !(T <: Real)
        T = Real
    end

    # Determine `eltype` of `dists`.
    D_left = eltype(metadata_left.dists)
    D_right = eltype(metadata_right.dists)
    D = promote_type(D_left, D_right)
    # TODO: Is this necessary?
    if !(D <: Distribution)
        D = Distribution
    end

    # Initialize required fields for `metadata`.
    vns = VarName[]
    idcs = Dict{VarName,Int}()
    ranges = Vector{UnitRange{Int}}()
    vals = T[]
    dists = D[]
    transformed = BitVector()

    # Range offset.
    offset = 0

    for (idx, vn) in enumerate(vns_both)
        idcs[vn] = idx
        push!(vns, vn)
        # For variables present in both, the right-hand metadata takes precedence.
        metadata_for_vn = vn in vns_right ? metadata_right : metadata_left

        val = getindex_internal(metadata_for_vn, vn)
        append!(vals, val)
        r = (offset + 1):(offset + length(val))
        push!(ranges, r)
        offset = r[end]
        dist = getdist(metadata_for_vn, vn)
        push!(dists, dist)
        push!(transformed, is_transformed(metadata_for_vn, vn))
    end

    return Metadata(idcs, vns, ranges, vals, dists, transformed)
end

const VarView = Union{Int,UnitRange,Vector{Int}}

"""
    setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}})

Set the value of `vi.vals[vview]` to `val`.
"""
setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val

"""
    getmetadata(vi::VarInfo, vn::VarName)

Return the metadata in `vi` that belongs to `vn`.
"""
getmetadata(vi::VarInfo, vn::VarName) = vi.metadata
getmetadata(vi::NTVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn))

"""
    getidx(vi::VarInfo, vn::VarName)

Return the index of `vn` in the metadata of `vi` corresponding to `vn`.
"""
getidx(vi::VarInfo, vn::VarName) = getidx(getmetadata(vi, vn), vn)
getidx(md::Metadata, vn::VarName) = md.idcs[vn]

"""
    getrange(vi::VarInfo, vn::VarName)

Return the index range of `vn` in the metadata of `vi`.
"""
getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn)
getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)]

"""
    setrange!(vi::VarInfo, vn::VarName, range)

Set the index range of `vn` in the metadata of `vi` to `range`.
"""
setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range)
setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range

"""
    getdist(vi::VarInfo, vn::VarName)

Return the distribution from which `vn` was sampled in `vi`.
"""
getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn)
getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)]
# TODO(mhauru) Remove this once the old Gibbs sampler stuff is gone.
# `VarNamedVector` stores transforms rather than distributions, so `getdist` is
# deliberately unsupported for it.
function getdist(::VarNamedVector, ::VarName)
    throw(ErrorException("getdist does not exist for VarNamedVector"))
end

# Return the raw (internal, vectorised) value(s) stored for `vn` / `vns`.
getindex_internal(vi::VarInfo, vn::VarName) = getindex_internal(getmetadata(vi, vn), vn)
# TODO(torfjelde): Use `view` instead of `getindex`. Requires addressing type-stability issues though,
# since then we might be returning a `SubArray` rather than an `Array`, which is typically
# what a bijector would result in, even if the input is a view (`SubArray`).
# TODO(torfjelde): An alternative is to implement `view` directly instead.
getindex_internal(md::Metadata, vn::VarName) = getindex(md.vals, getrange(md, vn))
function getindex_internal(vi::VarInfo, vns::Vector{<:VarName})
    return mapreduce(Base.Fix1(getindex_internal, vi), vcat, vns)
end
getindex_internal(vi::VarInfo, ::Colon) = getindex_internal(vi.metadata, Colon())
# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference.
# See for example https://github.com/JuliaLang/julia/pull/46381.
function getindex_internal(vi::NTVarInfo, ::Colon)
    return reduce(vcat, map(Base.Fix2(getindex_internal, Colon()), vi.metadata))
end
function getindex_internal(vi::VarInfo{NamedTuple{(),Tuple{}}}, ::Colon)
    return float(Real)[]
end
function getindex_internal(md::Metadata, ::Colon)
    # `init` keeps the eltype of `md.vals` even when `md.vns` is empty.
    return mapreduce(
        Base.Fix1(getindex_internal, md), vcat, md.vns; init=similar(md.vals, 0)
    )
end

"""
    setval!(vi::VarInfo, val, vn::VarName)

Set the value(s) of `vn` in the metadata of `vi` to `val`.

The values may or may not be transformed to Euclidean space.
"""
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
    return md.vals[getrange(md, vn)] = val
end
function setval!(md::Metadata, val, vn::VarName)
    # Non-vector values are vectorised before being written into the flat storage.
    return md.vals[getrange(md, vn)] = tovec(val)
end

# Set the "is transformed" (linked) flag for `vn`. The `!!` methods may return a new
# object (for the NamedTuple-typed case) or mutate in place (for Metadata).
function set_transformed!!(vi::NTVarInfo, val::Bool, vn::VarName)
    md = set_transformed!!(getmetadata(vi, vn), val, vn)
    return Accessors.@set vi.metadata[getsym(vn)] = md
end

function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName)
    md = set_transformed!!(getmetadata(vi, vn), val, vn)
    return VarInfo(md, vi.accs)
end

function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName)
    metadata.is_transformed[getidx(metadata, vn)] = val
    return metadata
end

# Set the flag for every variable in `vi`.
function set_transformed!!(vi::VarInfo, val::Bool)
    for vn in keys(vi)
        vi = set_transformed!!(vi, val, vn)
    end

    return vi
end

set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false)
# HACK: This is necessary to make something like `link!!(transformation, vi, model)`
# work properly, which will transform the variables according to `transformation`
# and then call `set_transformed!!(vi, transformation)`. An alternative would be to add
# the `transformation` to the `VarInfo` object, but at the moment doesn't seem
# worth it as `VarInfo` has its own way of handling transformations.
set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true)

"""
    syms(vi::VarInfo)

Returns a tuple of the unique symbols of random variables in `vi`.
"""
syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols
syms(vi::NTVarInfo) = keys(vi.metadata)

# Indices of all variables, either flat (untyped) or per-symbol (typed).
_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs)
_getidcs(vi::NTVarInfo) = _getidcs(vi.metadata)

@generated function _getidcs(metadata::NamedTuple{names}) where {names}
    exprs = []
    for f in names
        push!(exprs, :($f = findinds(metadata.$f)))
    end
    length(exprs) == 0 && return :(NamedTuple())
    return :($(exprs...),)
end

@inline findinds(f_meta::Metadata) = eachindex(f_meta.vns)
findinds(vnv::VarNamedVector) = 1:length(vnv.varnames)

"""
    all_varnames_grouped_by_symbol(vi::NTVarInfo)

Return a `NamedTuple` of the variables in `vi` grouped by symbol.
"""
all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(vi.metadata)

@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names}
    expr = Expr(:tuple)
    for f in names
        push!(expr.args, :($f = keys(md.$f)))
    end
    return expr
end

####
#### APIs for typed and untyped VarInfo
####

# Empty all metadata (in place) and reset the accumulators.
function BangBang.empty!!(vi::VarInfo)
    _empty!(vi.metadata)
    vi = resetaccs!!(vi)
    return vi
end

_empty!(metadata) = empty!(metadata)
@generated function _empty!(metadata::NamedTuple{names}) where {names}
    expr = Expr(:block)
    for f in names
        push!(expr.args, :(empty!(metadata.$f)))
    end
    return expr
end

# `keys`
Base.keys(md::Metadata) = md.vns
Base.keys(vi::VarInfo) = Base.keys(vi.metadata)

# HACK: Necessary to avoid returning `Any[]` which won't dispatch correctly
# on other methods in the codebase which requires `Vector{<:VarName}`.
Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[]
# Concatenate the keys of every per-symbol metadata; generated to avoid dynamic
# dispatch over the NamedTuple fields.
@generated function Base.keys(vi::NTVarInfo{<:NamedTuple{names}}) where {names}
    expr = Expr(:call)
    push!(expr.args, :vcat)

    for n in names
        push!(expr.args, :(keys(vi.metadata.$n)))
    end

    return expr
end

is_transformed(vi::VarInfo, vn::VarName) = is_transformed(getmetadata(vi, vn), vn)
is_transformed(md::Metadata, vn::VarName) = md.is_transformed[getidx(md, vn)]

getaccs(vi::VarInfo) = vi.accs
setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs

# Need to introduce the _isempty to avoid type piracy of isempty(::NamedTuple).
isempty(vi::VarInfo) = _isempty(vi.metadata)
_isempty(metadata::Metadata) = isempty(metadata.idcs)
_isempty(vnv::VarNamedVector) = isempty(vnv)
@generated function _isempty(metadata::NamedTuple{names}) where {names}
    return Expr(:&&, (:(_isempty(metadata.$f)) for f in names)...)
end

function link!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
    vns = all_varnames_grouped_by_symbol(vi)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _link(model, vi, vns)
    vi = _link!!(vi, vns)
    return vi
end

function link!!(::DynamicTransformation, vi::VarInfo, model::Model)
    vns = keys(vi)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _link(model, vi, vns)
    vi = _link!!(vi, vns)
    return vi
end

function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model)
    # By default this will simply evaluate the model with `DynamicTransformationContext`,
    # and so we need to specialize to avoid this.
    return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model)
end

function link!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _link(model, vi, vns)
    vi = _link!!(vi, vns)
    return vi
end

function link!!(
    t::DynamicTransformation,
    vi::ThreadSafeVarInfo{<:VarInfo},
    vns::VarNameTuple,
    model::Model,
)
    # By default this will simply evaluate the model with `DynamicTransformationContext`,
    # and so we need to specialize to avoid this.
    return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model)
end

function _link!!(vi::UntypedVarInfo, vns)
    # TODO: Change to a lazy iterator over `vns`
    # NOTE: linked-ness is only checked on the first variable; the whole varinfo is
    # assumed to be uniformly linked or unlinked.
    if ~is_transformed(vi, vns[1])
        for vn in vns
            f = internal_to_linked_internal_transform(vi, vn)
            vi = _inner_transform!(vi, vn, f)
            vi = set_transformed!!(vi, true, vn)
        end
        return vi
    else
        @warn("[DynamicPPL] attempt to link a linked vi")
    end
end

# If we try to _link!! a NTVarInfo with a Tuple of VarNames, first convert it to a
# NamedTuple that matches the structure of the NTVarInfo.
function _link!!(vi::NTVarInfo, vns::VarNameTuple)
    return _link!!(vi, group_varnames_by_symbol(vns))
end

function _link!!(vi::NTVarInfo, vns::NamedTuple)
    return _link!!(vi.metadata, vi, vns)
end

"""
    filter_subsumed(filter_vns, filtered_vns)

Return the subset of `filtered_vns` that are subsumed by any variable in `filter_vns`.
"""
function filter_subsumed(filter_vns, filtered_vns)
    return filter(x -> any(subsumes(y, x) for y in filter_vns), filtered_vns)
end

# Per-symbol linking for a typed VarInfo; emits one `quote` block per metadata field
# that appears in `varnames`, skipping fields that are not targeted.
@generated function _link!!(
    ::NamedTuple{metadata_names}, vi, varnames::NamedTuple{vns_names}
) where {metadata_names,vns_names}
    expr = Expr(:block)
    for f in metadata_names
        if !(f in vns_names)
            continue
        end
        push!(
            expr.args,
            quote
                f_vns = vi.metadata.$f.vns
                f_vns = filter_subsumed(varnames.$f, f_vns)
                if !isempty(f_vns)
                    if !is_transformed(vi, f_vns[1])
                        # Iterate over all `f_vns` and transform
                        for vn in f_vns
                            f = internal_to_linked_internal_transform(vi, vn)
                            vi = _inner_transform!(vi, vn, f)
                            vi = set_transformed!!(vi, true, vn)
                        end
                    else
                        @warn("[DynamicPPL] attempt to link a linked vi")
                    end
                end
            end,
        )
    end
    push!(expr.args, :(return vi))
    return expr
end

function invlink!!(::DynamicTransformation, vi::NTVarInfo, model::Model)
    vns = all_varnames_grouped_by_symbol(vi)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _invlink(model, vi, vns)
    vi = _invlink!!(vi, vns)
    return vi
end

function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model)
    vns = keys(vi)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _invlink(model, vi, vns)
    vi = _invlink!!(vi, vns)
    return vi
end

function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model)
    # By default this will simply evaluate the model with `DynamicTransformationContext`,
    # and so we need to specialize to avoid this.
    return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model)
end

function invlink!!(::DynamicTransformation, vi::VarInfo, vns::VarNameTuple, model::Model)
    # If we're working with a `VarNamedVector`, we always use immutable.
    has_varnamedvector(vi) && return _invlink(model, vi, vns)
    vi = _invlink!!(vi, vns)
    return vi
end

function invlink!!(
    ::DynamicTransformation,
    vi::ThreadSafeVarInfo{<:VarInfo},
    vns::VarNameTuple,
    model::Model,
)
    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
    # we need to specialize to avoid this.
    # NOTE(review): unlike the `link!!` counterpart above, the transformation argument
    # is not forwarded here (`invlink!!(vi.varinfo, vns, model)` vs
    # `link!!(t, vi.varinfo, vns, model)`) — presumably a 3-arg `invlink!!` method
    # exists that fills in the default transformation; confirm intentional.
    return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model)
end

function maybe_invlink_before_eval!!(vi::VarInfo, model::Model)
    # Because `VarInfo` does not contain any information about what the transformation
    # is, other than whether or not it has actually been transformed, the best we can do
    # is just assume that `default_transformation` is the correct one if
    # `is_transformed(vi)`.
    t = is_transformed(vi) ? default_transformation(model, vi) : NoTransformation()
    return maybe_invlink_before_eval!!(t, vi, model)
end

function _invlink!!(vi::UntypedVarInfo, vns)
    # NOTE: linked-ness is only checked on the first variable, mirroring `_link!!`.
    if is_transformed(vi, vns[1])
        for vn in vns
            f = linked_internal_to_internal_transform(vi, vn)
            vi = _inner_transform!(vi, vn, f)
            vi = set_transformed!!(vi, false, vn)
        end
        return vi
    else
        @warn("[DynamicPPL] attempt to invlink an invlinked vi")
    end
end

# If we try to _invlink!! a NTVarInfo with a Tuple of VarNames, first convert it to a
# NamedTuple that matches the structure of the NTVarInfo.
function _invlink!!(vi::NTVarInfo, vns::VarNameTuple)
    return _invlink!!(vi.metadata, vi, group_varnames_by_symbol(vns))
end

function _invlink!!(vi::NTVarInfo, vns::NamedTuple)
    return _invlink!!(vi.metadata, vi, vns)
end

# Per-symbol invlinking for a typed VarInfo; mirrors the generated `_link!!` above.
@generated function _invlink!!(
    ::NamedTuple{metadata_names}, vi, vns::NamedTuple{vns_names}
) where {metadata_names,vns_names}
    expr = Expr(:block)
    for f in metadata_names
        if !(f in vns_names)
            continue
        end

        push!(
            expr.args,
            quote
                f_vns = vi.metadata.$f.vns
                f_vns = filter_subsumed(vns.$f, f_vns)
                # NOTE(review): the generated `_link!!` guards this with
                # `!isempty(f_vns)` before indexing `f_vns[1]`; that guard is absent
                # here, so an empty filtered set would throw a BoundsError — confirm
                # whether this path can be reached with empty `f_vns`.
                if is_transformed(vi, f_vns[1])
                    # Iterate over all `f_vns` and transform
                    for vn in f_vns
                        f = linked_internal_to_internal_transform(vi, vn)
                        vi = _inner_transform!(vi, vn, f)
                        vi = set_transformed!!(vi, false, vn)
                    end
                else
                    @warn("[DynamicPPL] attempt to invlink an invlinked vi")
                end
            end,
        )
    end
    push!(expr.args, :(return vi))
    return expr
end

# Apply the transform `f` to the internal value of `vn`, writing the result (and its
# possibly different length) back into the metadata, and accumulating the
# log-abs-det-Jacobian if a LogJacobian accumulator is present.
function _inner_transform!(vi::VarInfo, vn::VarName, f)
    return _inner_transform!(getmetadata(vi, vn), vi, vn, f)
end

function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f)
    # TODO: Use inplace versions to avoid allocations
    yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn))
    # Determine the new range.
    start = first(getrange(md, vn))
    # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`.
    setrange!(md, vn, start:(start + length(yvec) - 1))
    # Set the new value.
    setval!(md, yvec, vn)
    if hasacc(vi, Val(:LogJacobian))
        vi = acclogjac!!(vi, logjac)
    end
    return vi
end

function link(::DynamicTransformation, vi::NTVarInfo, model::Model)
    return _link(model, vi, all_varnames_grouped_by_symbol(vi))
end

function link(::DynamicTransformation, varinfo::VarInfo, model::Model)
    return _link(model, varinfo, keys(varinfo))
end

function link(::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model)
    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
    # we need to specialize to avoid this.
    return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, model)
end

function link(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model)
    return _link(model, varinfo, vns)
end

function link(
    ::DynamicTransformation,
    varinfo::ThreadSafeVarInfo{<:VarInfo},
    vns::VarNameTuple,
    model::Model,
)
    # By default this will simply evaluate the model with `DynamicTransformationContext`,
    # and so we need to specialize to avoid this.
    return Accessors.@set varinfo.varinfo = link(varinfo.varinfo, vns, model)
end

# Out-of-place linking: work on a deep copy and return a new VarInfo with linked
# metadata, accumulating the log-Jacobian if an accumulator is present.
function _link(model::Model, varinfo::VarInfo, vns)
    varinfo = deepcopy(varinfo)
    md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns)
    new_varinfo = VarInfo(md, varinfo.accs)
    if hasacc(new_varinfo, Val(:LogJacobian))
        new_varinfo = acclogjac!!(new_varinfo, logjac)
    end
    return new_varinfo
end

# If we try to _link a NTVarInfo with a Tuple of VarNames, first convert it to a
# NamedTuple that matches the structure of the NTVarInfo.
function _link(model::Model, varinfo::NTVarInfo, vns::VarNameTuple)
    return _link(model, varinfo, group_varnames_by_symbol(vns))
end

function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple)
    varinfo = deepcopy(varinfo)
    md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns)
    new_varinfo = VarInfo(md, varinfo.accs)
    if hasacc(new_varinfo, Val(:LogJacobian))
        new_varinfo = acclogjac!!(new_varinfo, logjac)
    end
    return new_varinfo
end

# Link each targeted per-symbol metadata and sum the log-Jacobian contributions;
# untargeted fields are passed through unchanged.
@generated function _link_metadata!(
    model::Model,
    varinfo::VarInfo,
    metadata::NamedTuple{metadata_names},
    vns::NamedTuple{vns_names},
) where {metadata_names,vns_names}
    expr = quote
        cumulative_logjac = zero(LogProbType)
    end
    mds = Expr(:tuple)
    for f in metadata_names
        if f in vns_names
            push!(
                mds.args,
                quote
                    begin
                        md, logjac = _link_metadata!!(model, varinfo, metadata.$f, vns.$f)
                        cumulative_logjac += logjac
                        md
                    end
                end,
            )
        else
            push!(mds.args, :(metadata.$f))
        end
    end

    push!(
        expr.args,
        quote
            NamedTuple{$metadata_names}($mds), cumulative_logjac
        end,
    )
    return expr
end

function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns)
    vns = metadata.vns
    cumulative_logjac = zero(LogProbType)

    # Construct the new transformed values, and keep track of their lengths.
    vals_new = map(vns) do vn
        # Return early if we're already in unconstrained space.
        # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
        if is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns)
            return metadata.vals[getrange(metadata, vn)]
        end

        # Transform to unconstrained (linked) space.
        x = getindex_internal(metadata, vn)
        dist = getdist(metadata, vn)
        f = internal_to_linked_internal_transform(varinfo, vn, dist)
        y, logjac = with_logabsdet_jacobian(f, x)
        # Vectorize value.
        yvec = tovec(y)
        # Accumulate the log-abs-det jacobian correction.
        cumulative_logjac += logjac
        # Mark as transformed.
        set_transformed!!(varinfo, true, vn)
        # Return the vectorized transformed value.
        return yvec
    end

    # Determine new ranges.
    ranges_new = similar(metadata.ranges)
    offset = 0
    for (i, v) in enumerate(vals_new)
        r_start, r_end = offset + 1, length(v) + offset
        offset = r_end
        ranges_new[i] = r_start:r_end
    end

    # Now we just create a new metadata with the new `vals` and `ranges`.
    return Metadata(
        metadata.idcs,
        metadata.vns,
        ranges_new,
        reduce(vcat, vals_new),
        metadata.dists,
        metadata.is_transformed,
    ),
    cumulative_logjac
end

function _link_metadata!!(
    model::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns
)
    vns = target_vns === nothing ? keys(metadata) : target_vns
    dists = extract_priors(model, varinfo)
    cumulative_logjac = zero(LogProbType)
    for vn in vns
        # First transform from however the variable is stored in vnv to the model
        # representation.
        transform_to_orig = gettransform(metadata, vn)
        val_old = getindex_internal(metadata, vn)
        val_orig, logjac1 = with_logabsdet_jacobian(transform_to_orig, val_old)
        # Then transform from the model representation to the linked representation.
        transform_from_linked = from_linked_vec_transform(dists[vn])
        transform_to_linked = inverse(transform_from_linked)
        val_new, logjac2 = with_logabsdet_jacobian(transform_to_linked, val_orig)
        # TODO(mhauru) We are calling a !! function but ignoring the return value.
        # Fix this when attending to issue #653.
        cumulative_logjac += logjac1 + logjac2
        metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked)
        set_transformed!(metadata, true, vn)
    end
    # Linking can often change the sizes of variables, causing inactive elements. We don't
    # want to keep them around, since typically linking is done once and then the VarInfo
    # is evaluated multiple times. Hence we contiguify here.
    metadata = contiguify!(metadata)
    return metadata, cumulative_logjac
end

function invlink(::DynamicTransformation, vi::NTVarInfo, model::Model)
    return _invlink(model, vi, all_varnames_grouped_by_symbol(vi))
end

function invlink(::DynamicTransformation, vi::VarInfo, model::Model)
    return _invlink(model, vi, keys(vi))
end

function invlink(
    ::DynamicTransformation, varinfo::ThreadSafeVarInfo{<:VarInfo}, model::Model
)
    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
    # we need to specialize to avoid this.
    return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, model)
end

function invlink(::DynamicTransformation, varinfo::VarInfo, vns::VarNameTuple, model::Model)
    return _invlink(model, varinfo, vns)
end

function invlink(
    ::DynamicTransformation,
    varinfo::ThreadSafeVarInfo{<:VarInfo},
    vns::VarNameTuple,
    model::Model,
)
    # By default this will simply evaluate the model with `DynamicTransformationContext`, and so
    # we need to specialize to avoid this.
    return Accessors.@set varinfo.varinfo = invlink(varinfo.varinfo, vns, model)
end

# Out-of-place invlinking: work on a deep copy and return a new VarInfo with
# unlinked metadata.
function _invlink(model::Model, varinfo::VarInfo, vns)
    varinfo = deepcopy(varinfo)
    md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns)
    new_varinfo = VarInfo(md, varinfo.accs)
    if hasacc(new_varinfo, Val(:LogJacobian))
        # Mildly confusing: we need to _add_ the logjac of the inverse transform,
        # because we are trying to remove the logjac of the forward transform
        # that was previously accumulated when linking.
        new_varinfo = acclogjac!!(new_varinfo, inv_logjac)
    end
    return new_varinfo
end

# If we try to _invlink a NTVarInfo with a Tuple of VarNames, first convert it to a
# NamedTuple that matches the structure of the NTVarInfo.
-function _invlink(model::Model, varinfo::NTVarInfo, vns::VarNameTuple) - return _invlink(model, varinfo, group_varnames_by_symbol(vns)) -end - -function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) - varinfo = deepcopy(varinfo) - md, inv_logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) - new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogJacobian)) - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - new_varinfo = acclogjac!!(new_varinfo, inv_logjac) - end - return new_varinfo -end - -@generated function _invlink_metadata!( - model::Model, - varinfo::VarInfo, - metadata::NamedTuple{metadata_names}, - vns::NamedTuple{vns_names}, -) where {metadata_names,vns_names} - expr = quote - cumulative_inv_logjac = zero(LogProbType) - end - mds = Expr(:tuple) - for f in metadata_names - if (f in vns_names) - push!( - mds.args, - quote - begin - md, inv_logjac = _invlink_metadata!!( - model, varinfo, metadata.$f, vns.$f - ) - cumulative_inv_logjac += inv_logjac - md - end - end, - ) - else - push!(mds.args, :(metadata.$f)) - end - end - - push!( - expr.args, - quote - (NamedTuple{$metadata_names}($mds), cumulative_inv_logjac) - end, - ) - return expr -end - -function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) - vns = metadata.vns - cumulative_inv_logjac = zero(LogProbType) - - # Construct the new transformed values, and keep track of their lengths. - vals_new = map(vns) do vn - # Return early if we're already in constrained space OR if we're not - # supposed to touch this `vn`. - # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. - if !is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) - return metadata.vals[getrange(metadata, vn)] - end - - # Transform to constrained space. 
- y = getindex_internal(varinfo, vn) - dist = getdist(varinfo, vn) - f = from_linked_internal_transform(varinfo, vn, dist) - x, inv_logjac = with_logabsdet_jacobian(f, y) - # Vectorize value. - xvec = tovec(x) - # Accumulate the log-abs-det jacobian correction. - cumulative_inv_logjac += inv_logjac - # Mark as no longer transformed. - set_transformed!!(varinfo, false, vn) - # Return the vectorized transformed value. - return xvec - end - - # Determine new ranges. - ranges_new = similar(metadata.ranges) - offset = 0 - for (i, v) in enumerate(vals_new) - r_start, r_end = offset + 1, length(v) + offset - offset = r_end - ranges_new[i] = r_start:r_end - end - - # Now we just create a new metadata with the new `vals` and `ranges`. - return Metadata( - metadata.idcs, - metadata.vns, - ranges_new, - reduce(vcat, vals_new), - metadata.dists, - metadata.is_transformed, - ), - cumulative_inv_logjac -end - -function _invlink_metadata!!( - ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns -) - vns = target_vns === nothing ? keys(metadata) : target_vns - cumulative_inv_logjac = zero(LogProbType) - for vn in vns - transform = gettransform(metadata, vn) - old_val = getindex_internal(metadata, vn) - new_val, inv_logjac = with_logabsdet_jacobian(transform, old_val) - # TODO(mhauru) We are calling a !! function but ignoring the return value. - cumulative_inv_logjac += inv_logjac - new_transform = from_vec_transform(new_val) - metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) - set_transformed!(metadata, false, vn) - end - # Linking can often change the sizes of variables, causing inactive elements. We don't - # want to keep them around, since typically linking is done once and then the VarInfo - # is evaluated multiple times. Hence we contiguify here. 
- metadata = contiguify!(metadata) - return metadata, cumulative_inv_logjac -end - -# TODO(mhauru) The treatment of the case when some variables are transformed and others are -# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed` -# returned whether the first variable was linked. For NTVarInfo we did an OR over the first -# variables under each symbol. We now more consistently use OR, but I'm not convinced this -# is really the right thing to do. -""" - is_transformed(vi::VarInfo) - -Check whether `vi` is in the transformed space. - -Turing's Hamiltonian samplers use the `link` and `invlink` functions from -[Bijectors.jl](https://github.com/TuringLang/Bijectors.jl) to map a constrained variable -(for example, one bounded to the space `[0, 1]`) from its constrained space to the set of -real numbers. `is_transformed` checks if the number is in the constrained space or the real -space. - -If some but only some of the variables in `vi` are transformed, this function will return -`true`. This behavior will likely change in the future. -""" -function is_transformed(vi::VarInfo) - return any(is_transformed(vi, vn) for vn in keys(vi)) -end - -# The default getindex & setindex!() for get & set values -# NOTE: vi[vn] will always transform the variable to its original space and Julia type -function getindex(vi::VarInfo, vn::VarName) - return from_maybe_linked_internal_transform(vi, vn)(getindex_internal(vi, vn)) -end - -function getindex(vi::VarInfo, vn::VarName, dist::Distribution) - @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getindex_internal(vi, vn) - return from_maybe_linked_internal(vi, vn, dist, val) -end - -function getindex(vi::VarInfo, vns::Vector{<:VarName}) - vals = map(vn -> getindex(vi, vn), vns) - - et = eltype(vals) - # This will catch type unstable cases, where vals has mixed types. 
- if !isconcretetype(et) - throw(ArgumentError("All variables must have the same type.")) - end - - if et <: Vector - all_of_equal_dimension = all(x -> length(x) == length(vals[1]), vals) - if !all_of_equal_dimension - throw(ArgumentError("All variables must have the same dimension.")) - end - end - - # TODO(mhauru) I'm not very pleased with the return type varying like this, even though - # this should be type stable. - vec_vals = reduce(vcat, vals) - if et <: Vector - # The individual variables are multivariate, and thus we return the values as a - # matrix. - return reshape(vec_vals, (:, length(vns))) - else - # The individual variables are univariate, and thus we return a vector of scalars. - return vec_vals - end -end - -function getindex(vi::VarInfo, vns::Vector{<:VarName}, dist::Distribution) - @assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - vals_linked = mapreduce(vcat, vns) do vn - getindex(vi, vn, dist) - end - return recombine(dist, vals_linked, length(vns)) -end - -# Recursively builds a tuple of the `vals` of all the symbols -@generated function _getindex(metadata, ranges::NamedTuple{names}) where {names} - expr = Expr(:tuple) - for f in names - push!(expr.args, :(metadata.$f.vals[ranges.$f])) - end - return expr -end - -# TODO(mhauru) I think the below implementation of setindex! is a mistake. It should be -# called setindex_internal! since it directly writes to the `vals` field of the metadata. -""" - setindex!(vi::VarInfo, val, vn::VarName) - -Set the current value(s) of the random variable `vn` in `vi` to `val`. - -The value(s) may or may not be transformed to Euclidean space. 
-""" -setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) -function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) - setindex!(vi, val, vn) - return vi -end - -@inline function findvns(vi, f_vns) - if length(f_vns) == 0 - throw("Unidentified error, please report this error in an issue.") - end - return map(vn -> vi[vn], f_vns) -end - -Base.haskey(metadata::Metadata, vn::VarName) = haskey(metadata.idcs, vn) - -""" - haskey(vi::VarInfo, vn::VarName) - -Check whether `vn` has a value in `vi`. -""" -Base.haskey(vi::VarInfo, vn::VarName) = haskey(getmetadata(vi, vn), vn) -function Base.haskey(vi::NTVarInfo, vn::VarName) - md_haskey = map(vi.metadata) do metadata - haskey(metadata, vn) - end - return any(md_haskey) -end - -function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) - lines = Tuple{String,Any}[ - ("VarNames", vi.metadata.vns), - ("Range", vi.metadata.ranges), - ("Vals", vi.metadata.vals), - ] - for accname in acckeys(vi) - push!(lines, (string(accname), getacc(vi, Val(accname)))) - end - push!(lines, ("is_transformed", vi.metadata.is_transformed)) - max_name_length = maximum(map(length ∘ first, lines)) - fmt = Printf.Format("%-$(max_name_length)s") - vi_str = ( - """ - /======================================================================= - | VarInfo - |----------------------------------------------------------------------- - """ * - prod( - map(lines) do (name, value) - """ - | $(Printf.format(fmt, name)) : $(value) - """ - end, - ) * - """ - \\======================================================================= - """ - ) - return print(io, vi_str) -end - -const _MAX_VARS_SHOWN = 4 - -function _show_varnames(io::IO, vi) - md = vi.metadata - vns = keys(md) - - vns_by_name = Dict{Symbol,Vector{VarName}}() - for vn in vns - group = get!(() -> Vector{VarName}(), vns_by_name, getsym(vn)) - push!(group, vn) - end - - L = length(vns_by_name) - if L == 0 - print(io, "0 variables, dimension 0") - else - (L == 1) ? 
print(io, "1 variable (") : print(io, L, " variables (") - join(io, Iterators.take(keys(vns_by_name), _MAX_VARS_SHOWN), ", ") - (L > _MAX_VARS_SHOWN) && print(io, ", ...") - print(io, "), dimension ", length(md.vals)) - end -end - -function Base.show(io::IO, vi::UntypedVarInfo) - print(io, "VarInfo (") - _show_varnames(io, vi) - print(io, "; accumulators: ") - # TODO(mhauru) This uses "text/plain" because we are doing quite a condensed repretation - # of vi anyway. However, technically `show(io, x)` should give full details of x and - # preferably output valid Julia code. - show(io, MIME"text/plain"(), getaccs(vi)) - return print(io, ")") -end - -""" - push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) - -Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`, mutating if it makes sense. -""" -function BangBang.push!!(vi::VarInfo, vn::VarName, val, dist::Distribution) - @assert ~(vn in keys(vi)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist" - md = push!!(getmetadata(vi, vn), vn, val, dist) - return VarInfo(md, vi.accs) -end - -function BangBang.push!!(vi::NTVarInfo, vn::VarName, val, dist::Distribution) - @assert ~(haskey(vi, vn)) "[push!!] attempt to add an existing variable $(getsym(vn)) ($(vn)) to NTVarInfo of syms $(syms(vi)) with dist=$dist" - sym = getsym(vn) - meta = if ~haskey(vi.metadata, sym) - # The NamedTuple doesn't have an entry for this variable, let's add one. - _new_submetadata(vi, vn, val, dist) - else - push!!(getmetadata(vi, vn), vn, val, dist) - end - vi = Accessors.@set vi.metadata[sym] = meta - return vi -end - -""" - _new_submetadata(vi::VarInfo{NamedTuple{Names,SubMetas}}, args...) where {Names,SubMetas} - -Create a new sub-metadata for an NTVarInfo. The type is chosen by the types of existing -SubMetas. 
-""" -@generated function _new_submetadata( - vi::VarInfo{NamedTuple{Names,SubMetas}}, vn, r, dist -) where {Names,SubMetas} - has_vnv = any(s -> s <: VarNamedVector, SubMetas.parameters) - return if has_vnv - :(return _new_vnv_submetadata(vn, r, dist)) - else - :(return _new_metadata_submetadata(vn, r, dist)) - end -end - -_new_vnv_submetadata(vn, r, _) = VarNamedVector([vn], [r]) - -function _new_metadata_submetadata(vn, r, dist) - val = tovec(r) - return Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) -end - -function Base.push!(vi::UntypedVectorVarInfo, pair::Pair, args...) - vn, val = pair - return push!(vi, vn, val, args...) -end - -# TODO(mhauru) push! can't be implemented in-place for NTVarInfo if the symbol doesn't -# exist in the NTVarInfo already. We could implement it in the cases where it it does -# exist, but that feels a bit pointless. I think we should rather rely on `push!!`. - -function Base.push!(meta::Metadata, vn, r, dist) - val = tovec(r) - meta.idcs[vn] = length(meta.idcs) + 1 - push!(meta.vns, vn) - l = length(meta.vals) - n = length(val) - push!(meta.ranges, (l + 1):(l + n)) - append!(meta.vals, val) - push!(meta.dists, dist) - push!(meta.is_transformed, false) - return meta -end - -function BangBang.push!!(meta::Metadata, vn, r, dist) - push!(meta, vn, r, dist) - return meta -end - -function Base.delete!(vi::VarInfo, vn::VarName) - delete!(getmetadata(vi, vn), vn) - return vi -end - -####################################### -# Rand & replaying method for VarInfo # -####################################### - -# TODO: Maybe rename or something? -""" - _apply!(kernel!, vi::VarInfo, values, keys) - -Calls `kernel!(vi, vn, values, keys)` for every `vn` in `vi`. 
-""" -function _apply!(kernel!, vi::VarInfoOrThreadSafeVarInfo, values, keys) - keys_strings = map(string, collect_maybe(keys)) - num_indices_seen = 0 - - for vn in Base.keys(vi) - indices_found = kernel!(vi, vn, values, keys_strings) - if indices_found !== nothing - num_indices_seen += length(indices_found) - end - end - - if length(keys) > num_indices_seen - # Some keys have not been seen, i.e. attempted to set variables which - # we were not able to locate in `vi`. - # Find the ones we missed so we can warn the user. - unused_keys = _find_missing_keys(vi, keys_strings) - @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" - end - - return vi -end - -function _apply!(kernel!, vi::NTVarInfo, values, keys) - return _typed_apply!(kernel!, vi, vi.metadata, values, collect_maybe(keys)) -end - -@generated function _typed_apply!( - kernel!, vi::NTVarInfo, metadata::NamedTuple{names}, values, keys -) where {names} - updates = map(names) do n - quote - for vn in Base.keys(metadata.$n) - indices_found = kernel!(vi, vn, values, keys_strings) - if indices_found !== nothing - num_indices_seen += length(indices_found) - end - end - end - end - - return quote - keys_strings = map(string, keys) - num_indices_seen = 0 - - $(updates...) - - if length(keys) > num_indices_seen - # Some keys have not been seen, i.e. attempted to set variables which - # we were not able to locate in `vi`. - # Find the ones we missed so we can warn the user. - unused_keys = _find_missing_keys(vi, keys_strings) - @warn "the following keys were not found in `vi`, and thus `kernel!` was not applied to these: $(unused_keys)" - end - - return vi - end -end - -function _find_missing_keys(vi::VarInfoOrThreadSafeVarInfo, keys) - string_vns = map(string, collect_maybe(Base.keys(vi))) - # If `key` isn't subsumed by any element of `string_vns`, it is not present in `vi`. 
- missing_keys = filter(keys) do key - !any(Base.Fix2(subsumes_string, key), string_vns) - end - - return missing_keys -end - -values_as(vi::VarInfo) = vi.metadata -values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) -function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) - iter = values_from_metadata(vi.metadata) - return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) -end -function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata)) -end - -function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} - iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) -end - -function values_as( - vi::VarInfo{<:NamedTuple{names}}, ::Type{D} -) where {names,D<:AbstractDict} - iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) - return ConstructionBase.constructorof(D)(iter) -end - -values_as(vi::UntypedVectorVarInfo, args...) = values_as(vi.metadata, args...) -values_as(vi::UntypedVectorVarInfo, T::Type{Vector}) = values_as(vi.metadata, T) - -function values_from_metadata(md::Metadata) - return ( - # `copy` to avoid accidentally mutation of internal representation. - vn => copy( - from_internal_transform(md, vn, getdist(md, vn))(getindex_internal(md, vn)) - ) for vn in md.vns - ) -end - -values_from_metadata(md::VarNamedVector) = pairs(md) - -# Transforming from internal representation to distribution representation. -# Without `dist` argument: base on `dist` extracted from self. 
-function from_internal_transform(vi::VarInfo, vn::VarName) - return from_internal_transform(getmetadata(vi, vn), vn) -end -function from_internal_transform(md::Metadata, vn::VarName) - return from_internal_transform(md, vn, getdist(md, vn)) -end -function from_internal_transform(md::VarNamedVector, vn::VarName) - return gettransform(md, vn) -end -# With both `vn` and `dist` arguments: base on provided `dist`. -function from_internal_transform(vi::VarInfo, vn::VarName, dist) - return from_internal_transform(getmetadata(vi, vn), vn, dist) -end -from_internal_transform(::Metadata, ::VarName, dist) = from_vec_transform(dist) -function from_internal_transform(::VarNamedVector, ::VarName, dist) - return from_vec_transform(dist) -end - -# Without `dist` argument: base on `dist` extracted from self. -function from_linked_internal_transform(vi::VarInfo, vn::VarName) - return from_linked_internal_transform(getmetadata(vi, vn), vn) -end -function from_linked_internal_transform(md::Metadata, vn::VarName) - return from_linked_internal_transform(md, vn, getdist(md, vn)) -end -function from_linked_internal_transform(md::VarNamedVector, vn::VarName) - return gettransform(md, vn) -end -# With both `vn` and `dist` arguments: base on provided `dist`. -function from_linked_internal_transform(vi::VarInfo, vn::VarName, dist) - # Dispatch to metadata in case this alters the behavior. 
- return from_linked_internal_transform(getmetadata(vi, vn), vn, dist) -end -function from_linked_internal_transform(::Metadata, ::VarName, dist) - return from_linked_vec_transform(dist) -end -function from_linked_internal_transform(::VarNamedVector, ::VarName, dist) - return from_linked_vec_transform(dist) -end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl deleted file mode 100644 index e5d2f2c2e..000000000 --- a/src/varnamedvector.jl +++ /dev/null @@ -1,1674 +0,0 @@ -const CHECK_CONSISTENCY_DEFAULT = true - -""" - VarNamedVector - -A container that stores values in a vectorised form, but indexable by variable names. - -A `VarNamedVector` can be thought of as an ordered mapping from `VarName`s to pairs of -`(internal_value, transform)`. Here `internal_value` is a vectorised value for the variable -and `transform` is a function such that `transform(internal_value)` is the "original" value -of the variable, the one that the user sees. For instance, if the variable has a matrix -value, `internal_value` could bea flattened `Vector` of its elements, and `transform` would -be a `reshape` call. - -`transform` may implement simply vectorisation, but it may do more. Most importantly, it may -implement linking, where the internal storage of a random variable is in a form where all -values in Euclidean space are valid. This is useful for sampling, because the sampler can -make changes to `internal_value` without worrying about constraints on the space of -the random variable. - -The way to access this storage format directly is through the functions `getindex_internal` -and `setindex_internal`. The `transform` argument for `setindex_internal` is optional, by -default it is either the identity, or the existing transform if a value already exists for -this `VarName`. - -`VarNamedVector` also provides a `Dict`-like interface that hides away the internal -vectorisation. This can be accessed with `getindex` and `setindex!`. 
`setindex!` only takes -the value, the transform is automatically set to be a simple vectorisation. The only notable -deviation from the behavior of a `Dict` is that `setindex!` will throw an error if one tries -to set a new value for a variable that lives in a different "space" than the old one (e.g. -is of a different type or size). This is because `setindex!` does not change the transform -of a variable, e.g. preserve linking, and thus the new value must be compatible with the old -transform. - -For now, a third value is in fact stored for each `VarName`: a boolean indicating whether -the variable has been transformed to unconstrained Euclidean space or not. This is only in -place temporarily due to the needs of our old Gibbs sampler. - -Internally, `VarNamedVector` stores the values of all variables in a single contiguous -vector. This makes some operations more efficient, and means that one can access the entire -contents of the internal storage quickly with `getindex_internal(vnv, :)`. The other fields -of `VarNamedVector` are mostly used to keep track of which part of the internal storage -belongs to which `VarName`. - -All constructors accept a keyword argument `check_consistency::Bool=true` that controls -whether to run checks like the number of values matching the number of variables. Some of -these checks can be expensive, so if you are confident in the input, you may want to turn -`check_consistency` off for performance. - -# Fields - -$(FIELDS) - -# Extended help - -The values for different variables are internally all stored in a single vector. 
For -instance, -```jldoctest varnamedvector-struct -julia> using DynamicPPL: ReshapeTransform, VarNamedVector, @varname, setindex!!, update!!, getindex_internal - -julia> vnv = VarNamedVector(); - -julia> vnv = setindex!!(vnv, [0.0, 0.0, 0.0, 0.0], @varname(x)); - -julia> vnv = setindex!!(vnv, reshape(1:6, (2,3)), @varname(y)); - -julia> vnv.vals -10-element Vector{Real}: - 0.0 - 0.0 - 0.0 - 0.0 - 1 - 2 - 3 - 4 - 5 - 6 -``` - -The `varnames`, `ranges`, and `varname_to_index` fields keep track of which value belongs to -which variable. The `transforms` field stores the transformations that needed to transform -the vectorised internal storage back to its original form: - -```jldoctest varnamedvector-struct -julia> vnv.transforms[vnv.varname_to_index[@varname(y)]] == DynamicPPL.ReshapeTransform((6,), (2,3)) -true -``` - -If a variable is updated with a new value that is of a smaller dimension than the old -value, rather than resizing `vnv.vals`, some elements in `vnv.vals` are marked as inactive. - -```jldoctest varnamedvector-struct -julia> vnv = update!!(vnv, [46.0, 48.0], @varname(x)); - -julia> vnv.vals -10-element Vector{Real}: - 46.0 - 48.0 - 0.0 - 0.0 - 1 - 2 - 3 - 4 - 5 - 6 - -julia> println(vnv.num_inactive); -Dict(1 => 2) -``` - -This helps avoid unnecessary memory allocations for values that repeatedly change dimension. -The user does not have to worry about the inactive entries as long as they use functions -like `setindex!` and `getindex!` rather than directly accessing `vnv.vals`. 
- -```jldoctest varnamedvector-struct -julia> vnv[@varname(x)] -2-element Vector{Real}: - 46.0 - 48.0 - -julia> getindex_internal(vnv, :) -8-element Vector{Real}: - 46.0 - 48.0 - 1 - 2 - 3 - 4 - 5 - 6 -``` -""" -struct VarNamedVector{ - K<:VarName,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T} -} - """ - mapping from a `VarName` to its integer index in `varnames`, `ranges` and `transforms` - """ - varname_to_index::Dict{K,Int} - - """ - vector of `VarNames` for the variables, where `varnames[varname_to_index[vn]] == vn` - """ - varnames::KVec - - """ - vector of index ranges in `vals` corresponding to `varnames`; each `VarName` `vn` has - a single index or a set of contiguous indices, such that the values of `vn` can be found - at `vals[ranges[varname_to_index[vn]]]` - """ - ranges::Vector{UnitRange{Int}} - - """ - vector of values of all variables; the value(s) of `vn` is/are - `vals[ranges[varname_to_index[vn]]]` - """ - vals::VVec - - """ - vector of transformations, so that `transforms[varname_to_index[vn]]` is a callable - that transforms the value of `vn` back to its original space, undoing any linking and - vectorisation - """ - transforms::TVec - - """ - vector of booleans indicating whether a variable has been explicitly transformed to - unconstrained Euclidean space, i.e. whether its domain is all of `ℝ^ⁿ`. If - `is_unconstrained[varname_to_index[vn]]` is true, it guarantees that the variable - `vn` is not constrained. However, the converse does not hold: if `is_unconstrained` - is false, the variable `vn` may still happen to be unconstrained, e.g. if its - original distribution is itself unconstrained (like a normal distribution). - """ - is_unconstrained::BitVector - - """ - mapping from a variable index to the number of inactive entries for that variable. - Inactive entries are elements in `vals` that are not part of the value of any variable. 
- They arise when a variable is set to a new value with a different dimension, in-place. - Inactive entries always come after the last active entry for the given variable. - See the extended help with `??VarNamedVector` for more details. - """ - num_inactive::Dict{Int,Int} - - function VarNamedVector( - varname_to_index, - varnames::KVec, - ranges, - vals::VVec, - transforms::TVec, - is_unconstrained=fill!(BitVector(undef, length(varnames)), 0), - num_inactive=Dict{Int,Int}(); - check_consistency::Bool=CHECK_CONSISTENCY_DEFAULT, - ) where {K,V,T,KVec<:AbstractVector{K},VVec<:AbstractVector{V},TVec<:AbstractVector{T}} - if check_consistency - if length(varnames) != length(ranges) || - length(varnames) != length(transforms) || - length(varnames) != length(is_unconstrained) || - length(varnames) != length(varname_to_index) - msg = ( - "Inputs to VarNamedVector have inconsistent lengths. " * - "Got lengths varnames: $(length(varnames)), " * - "ranges: $(length(ranges)), " * - "transforms: $(length(transforms)), " * - "is_unconstrained: $(length(is_unconstrained)), " * - "varname_to_index: $(length(varname_to_index))." - ) - throw(ArgumentError(msg)) - end - - num_vals = mapreduce(length, (+), ranges; init=0) + sum(values(num_inactive)) - if num_vals != length(vals) - msg = ( - "The total number of elements in `vals` ($(length(vals))) does not " * - "match the sum of the lengths of the ranges and the number of " * - "inactive entries ($(num_vals))." - ) - throw(ArgumentError(msg)) - end - - if Set(values(varname_to_index)) != Set(axes(varnames, 1)) - msg = ( - "The set of values of `varname_to_index` is not the set of valid " * - "indices for `varnames`." - ) - throw(ArgumentError(msg)) - end - - if !issubset(Set(keys(num_inactive)), Set(values(varname_to_index))) - msg = ( - "The keys of `num_inactive` are not a subset of the values of " * - "`varname_to_index`." - ) - throw(ArgumentError(msg)) - end - - # Check that the varnames don't overlap. 
The time cost is quadratic in number of - # variables. If this ever becomes an issue, we should be able to go down to at - # least N log N by sorting based on subsumes-order. - for vn1 in keys(varname_to_index) - for vn2 in keys(varname_to_index) - vn1 === vn2 && continue - if subsumes(vn1, vn2) - msg = ( - "Variables in a VarNamedVector should not subsume each " * - "other, but $vn1 subsumes $vn2." - ) - throw(ArgumentError(msg)) - end - end - end - - # We could also have a test to check that the ranges don't overlap, but that - # sounds unlikely to occur, and implementing it in linear time would require a - # tiny bit of thought. - end - - return new{K,V,T,KVec,VVec,TVec}( - varname_to_index, - varnames, - ranges, - vals, - transforms, - is_unconstrained, - num_inactive, - ) - end -end - -function VarNamedVector{K,V,T}() where {K,V,T} - return VarNamedVector( - Dict{K,Int}(), K[], UnitRange{Int}[], V[], T[]; check_consistency=false - ) -end - -VarNamedVector() = VarNamedVector{Union{},Union{},Union{}}() -function VarNamedVector(xs::Pair...; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector(OrderedDict(xs...); check_consistency=check_consistency) -end -function VarNamedVector(x::AbstractDict; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector(keys(x), values(x); check_consistency=check_consistency) -end -function VarNamedVector(varnames, vals; check_consistency=CHECK_CONSISTENCY_DEFAULT) - return VarNamedVector( - collect_maybe(varnames), collect_maybe(vals); check_consistency=check_consistency - ) -end -function VarNamedVector( - varnames::AbstractVector, - orig_vals::AbstractVector, - transforms=fill(identity, length(varnames)); - check_consistency=CHECK_CONSISTENCY_DEFAULT, -) - if isempty(varnames) && isempty(orig_vals) && isempty(transforms) - return VarNamedVector{eltype(varnames),eltype(orig_vals),eltype(transforms)}() - end - # Convert `vals` into a vector of vectors. 
- vals_vecs = map(tovec, orig_vals) - transforms = map( - (t, val) -> _compose_no_identity(t, from_vec_transform(val)), transforms, orig_vals - ) - # Make `varnames` have as concrete an element type as possible. - varnames = [v for v in varnames] - varname_to_index = Dict{eltype(varnames),Int}( - vn => i for (i, vn) in enumerate(varnames) - ) - vals = reduce(vcat, vals_vecs) - # Make the ranges. - ranges = Vector{UnitRange{Int}}() - offset = 0 - for x in vals_vecs - r = (offset + 1):(offset + length(x)) - push!(ranges, r) - offset = r[end] - end - - # Passing on check_consistency here seems wasteful. Wouldn't it be faster to do a - # lightweight check of the arguments of this function, and rely on the correctness - # of what this function does? However, the expensive check is whether any variable - # subsumes another, and that's the same regardless of where it's done, so the - # optimisation would be quite pointless. - return VarNamedVector( - varname_to_index, - varnames, - ranges, - vals, - transforms; - check_consistency=check_consistency, - ) -end - -function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) - return vnv_left.varname_to_index == vnv_right.varname_to_index && - vnv_left.varnames == vnv_right.varnames && - vnv_left.ranges == vnv_right.ranges && - vnv_left.vals == vnv_right.vals && - vnv_left.transforms == vnv_right.transforms && - vnv_left.is_unconstrained == vnv_right.is_unconstrained && - vnv_left.num_inactive == vnv_right.num_inactive -end - -function is_tightly_typed(vnv::VarNamedVector) - k = eltype(vnv.varnames) - v = eltype(vnv.vals) - t = eltype(vnv.transforms) - return (isconcretetype(k) || k === Union{}) && - (isconcretetype(v) || v === Union{}) && - (isconcretetype(t) || t === Union{}) -end - -getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] - -getrange(vnv::VarNamedVector, idx::Int) = vnv.ranges[idx] -getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) - 
-gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx] -gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn)) - -# TODO(mhauru) Eventually I would like to rename the is_transformed function to -# is_unconstrained, but that's significantly breaking. -""" - is_transformed(vnv::VarNamedVector, vn::VarName) - -Return a boolean for whether `vn` is guaranteed to have been transformed so that its domain -is all of Euclidean space. -""" -is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] - -""" - set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) - -Set the value for whether `vn` is guaranteed to have been transformed so that all of -Euclidean space is its domain. -""" -function set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) - return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val -end - -function set_transformed!!(vnv::VarNamedVector, val::Bool, vn::VarName) - set_transformed!(vnv, val, vn) - return vnv -end - -""" - has_inactive(vnv::VarNamedVector) - -Returns `true` if `vnv` has inactive entries. - -See also: [`num_inactive`](@ref) -""" -has_inactive(vnv::VarNamedVector) = !isempty(vnv.num_inactive) - -""" - num_inactive(vnv::VarNamedVector) - -Return the number of inactive entries in `vnv`. - -See also: [`has_inactive`](@ref), [`num_allocated`](@ref) -""" -num_inactive(vnv::VarNamedVector) = sum(values(vnv.num_inactive)) - -""" - num_inactive(vnv::VarNamedVector, vn::VarName) - -Returns the number of inactive entries for `vn` in `vnv`. -""" -num_inactive(vnv::VarNamedVector, vn::VarName) = num_inactive(vnv, getidx(vnv, vn)) -num_inactive(vnv::VarNamedVector, idx::Int) = get(vnv.num_inactive, idx, 0) - -""" - num_allocated(vnv::VarNamedVector) - num_allocated(vnv::VarNamedVector[, vn::VarName]) - num_allocated(vnv::VarNamedVector[, idx::Int]) - -Return the number of allocated entries in `vnv`, both active and inactive. 
- -If either a `VarName` or an `Int` index is specified, only count entries allocated for that -variable. - -Allocated entries take up memory in `vnv.vals`, but, if inactive, may not currently hold any -meaningful data. One can remove them with [`contiguify!`](@ref), but doing so may cause more -memory allocations in the future if variables change dimension. -""" -num_allocated(vnv::VarNamedVector) = length(vnv.vals) -num_allocated(vnv::VarNamedVector, vn::VarName) = num_allocated(vnv, getidx(vnv, vn)) -function num_allocated(vnv::VarNamedVector, idx::Int) - return length(getrange(vnv, idx)) + num_inactive(vnv, idx) -end - -# Dictionary interface. -Base.isempty(vnv::VarNamedVector) = isempty(vnv.varnames) -Base.length(vnv::VarNamedVector) = length(vnv.varnames) -Base.keys(vnv::VarNamedVector) = vnv.varnames -Base.values(vnv::VarNamedVector) = Iterators.map(Base.Fix1(getindex, vnv), vnv.varnames) -Base.pairs(vnv::VarNamedVector) = (vn => vnv[vn] for vn in keys(vnv)) -Base.haskey(vnv::VarNamedVector, vn::VarName) = haskey(vnv.varname_to_index, vn) - -# Vector-like interface. -Base.eltype(vnv::VarNamedVector) = eltype(vnv.vals) - -""" - length_internal(vnv::VarNamedVector) - -Return the length of the internal storage vector of `vnv`, ignoring inactive entries. -""" -function length_internal(vnv::VarNamedVector) - if !has_inactive(vnv) - return length(vnv.vals) - else - return sum(length, vnv.ranges) - end -end - -# Getting and setting values - -function Base.getindex(vnv::VarNamedVector, vn::VarName) - x = getindex_internal(vnv, vn) - f = gettransform(vnv, vn) - return f(x) -end - -""" - find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) - -Find the first range in `ranges` that contains `x`. - -Throw an `ArgumentError` if `x` is not in any of the ranges. -""" -function find_containing_range(ranges::AbstractVector{<:AbstractRange}, x) - # TODO: Assume `ranges` to be sorted and contiguous, and use `searchsortedfirst` - # for a more efficient approach. 
- range_idx = findfirst(Base.Fix1(∈, x), ranges) - - # If we're out of bounds, we raise an error. - if range_idx === nothing - throw(ArgumentError("Value $x is not in any of the ranges.")) - end - - return range_idx -end - -""" - adjusted_ranges(vnv::VarNamedVector) - -Return what `vnv.ranges` would be if there were no inactive entries. -""" -function adjusted_ranges(vnv::VarNamedVector) - # Every range following inactive entries needs to be shifted. - offset = 0 - ranges_adj = similar(vnv.ranges) - for (idx, r) in enumerate(vnv.ranges) - # Remove the `offset` in `r` due to inactive entries. - ranges_adj[idx] = r .- offset - # Update `offset`. - offset += get(vnv.num_inactive, idx, 0) - end - - return ranges_adj -end - -""" - index_to_vals_index(vnv::VarNamedVector, i::Int) - -Convert an integer index that ignores inactive entries to an index that accounts for them. - -This is needed when the user wants to index `vnv` like a vector, but shouldn't have to care -about inactive entries in `vnv.vals`. -""" -function index_to_vals_index(vnv::VarNamedVector, i::Int) - # If we don't have any inactive entries, there's nothing to do. - has_inactive(vnv) || return i - - # Get the adjusted ranges. - ranges_adj = adjusted_ranges(vnv) - # Determine the adjusted range that the index corresponds to. - r_idx = find_containing_range(ranges_adj, i) - r = vnv.ranges[r_idx] - # Determine how much of the index `i` is used to get to this range. - i_used = r_idx == 1 ? 0 : sum(length, ranges_adj[1:(r_idx - 1)]) - # Use remainder to index into `r`. - i_remainder = i - i_used - return r[i_remainder] -end - -""" - getindex_internal(vnv::VarNamedVector, vn::VarName) - -Like `getindex`, but returns the values as they are stored in `vnv`, without transforming. -""" -getindex_internal(vnv::VarNamedVector, vn::VarName) = vnv.vals[getrange(vnv, vn)] - -""" - getindex_internal(vnv::VarNamedVector, i::Int) - -Gets the `i`th element of the internal storage vector, ignoring inactive entries. 
-""" -getindex_internal(vnv::VarNamedVector, i::Int) = vnv.vals[index_to_vals_index(vnv, i)] - -function getindex_internal(vnv::VarNamedVector, ::Colon) - return if has_inactive(vnv) - mapreduce(Base.Fix1(getindex, vnv.vals), vcat, vnv.ranges) - else - vnv.vals - end -end - -function Base.setindex!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - return update!(vnv, val, vn) - else - return insert!(vnv, val, vn) - end -end - -""" - reset!(vnv::VarNamedVector, val, vn::VarName) - -Reset the value of `vn` in `vnv` to `val`. - -This differs from `setindex!` in that it will always change the transform of the variable -to be the default vectorisation transform. This undoes any possible linking. - -# Examples - -```jldoctest varnamedvector-reset -julia> using DynamicPPL: VarNamedVector, @varname, reset! - -julia> vnv = VarNamedVector{VarName,Any,Any}(); - -julia> vnv[@varname(x)] = reshape(1:9, (3, 3)); - -julia> setindex!(vnv, 2.0, @varname(x)) -ERROR: An error occurred while assigning the value 2.0 to variable x. If you are changing the type or size of a variable you'll need to call reset! -[...] - -julia> reset!(vnv, 2.0, @varname(x)); - -julia> vnv[@varname(x)] -2.0 -``` -""" -function reset!(vnv::VarNamedVector, val, vn::VarName) - f = from_vec_transform(val) - retval = setindex_internal!(vnv, tovec(val), vn, f) - set_transformed!(vnv, false, vn) - return retval -end - -""" - update!(vnv::VarNamedVector, val, vn::VarName) - -Update the value of `vn` in `vnv` to `val`. - -Like `setindex!`, but errors if the key `vn` doesn't exist. -""" -function update!(vnv::VarNamedVector, val, vn::VarName) - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - f = inverse(gettransform(vnv, vn)) - internal_val = try - f(val) - catch - error( - "An error occurred while assigning the value $val to variable $vn. 
" * - "If you are changing the type or size of a variable you'll need to call " * - "reset!", - ) - end - return setindex_internal!(vnv, internal_val, vn) -end - -""" - insert!(vnv::VarNamedVector, val, vn::VarName) - -Add a variable with given value to `vnv`. - -Like `setindex!`, but errors if the key `vn` already exists. -""" -function Base.insert!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - throw("Variable $vn already exists in VarNamedVector.") - end - return reset!(vnv, val, vn) -end - -""" - push!(vnv::VarNamedVector, pair::Pair) - -Add a variable with given value to `vnv`. Pair should be a `VarName` and a value. -""" -function Base.push!(vnv::VarNamedVector, pair::Pair) - vn, val = pair - # TODO(mhauru) Or should this rather call `reset!`? It would be more inline with what - # Dict does, but could also cause confusion. - return setindex!(vnv, val, vn) -end - -""" - setindex_internal!(vnv::VarNamedVector, val, i::Int) - -Sets the `i`th element of the internal storage vector, ignoring inactive entries. -""" -function setindex_internal!(vnv::VarNamedVector, val, i::Int) - return vnv.vals[index_to_vals_index(vnv, i)] = val -end - -""" - setindex_internal!(vnv::VarNamedVector, val, vn::VarName[, transform]) - -Like `setindex!`, but sets the values as they are stored internally in `vnv`. - -Optionally can set the transformation, such that `transform(val)` is the original value of -the variable. By default, the transform is the identity if creating a new entry in `vnv`, or -the existing transform if updating an existing entry. -""" -function setindex_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if haskey(vnv, vn) - return update_internal!(vnv, val, vn, transform) - else - return insert_internal!(vnv, val, vn, transform) - end -end - -""" - insert_internal!(vnv::VarNamedVector, val::AbstractVector, vn::VarName[, transform]) - -Add a variable with given value to `vnv`. 
- -Like `setindex_internal!`, but errors if the key `vn` already exists. - -`transform` should be a function that converts `val` to the original representation. By -default it's `identity`. -""" -function insert_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if transform === nothing - transform = identity - end - haskey(vnv, vn) && throw(ArgumentError("variable name $vn already exists")) - # NOTE: We need to compute the `nextrange` BEFORE we start mutating the underlying - # storage. - r_new = nextrange(vnv, val) - vnv.varname_to_index[vn] = length(vnv.varname_to_index) + 1 - push!(vnv.varnames, vn) - push!(vnv.ranges, r_new) - append!(vnv.vals, val) - push!(vnv.transforms, transform) - push!(vnv.is_unconstrained, false) - return nothing -end - -""" - update_internal!(vnv::VarNamedVector, vn::VarName, val::AbstractVector[, transform]) - -Update an existing entry for `vn` in `vnv` with the value `val`. - -Like `setindex_internal!`, but errors if the key `vn` doesn't exist. - -`transform` should be a function that converts `val` to the original representation. By -default it's the same as the old transform for `vn`. -""" -function update_internal!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - # Here we update an existing entry. - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - idx = getidx(vnv, vn) - # Extract the old range. - r_old = getrange(vnv, idx) - start_old, end_old = first(r_old), last(r_old) - n_old = length(r_old) - # Compute the new range. - n_new = length(val) - start_new = start_old - end_new = start_old + n_new - 1 - r_new = start_new:end_new - - #= - Suppose we currently have the following: - - | x | x | o | o | o | y | y | y | <- Current entries - - where 'O' denotes an inactive entry, and we're going to - update the variable `x` to be of size `k` instead of 2. - - We then have a few different scenarios: - 1. 
`k > 5`: All inactive entries become active + need to shift `y` to the right. - E.g. if `k = 7`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | x | x | y | y | y | <- New entries - - 2. `k = 5`: All inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | x | x | y | y | y | <- New entries - - 3. `k < 5`: Some inactive entries become active, some remain inactive. - E.g. if `k = 3`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | x | o | o | y | y | y | <- New entries - - 4. `k = 2`: No inactive entries become active. - Then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | x | o | o | o | y | y | y | <- New entries - - 5. `k < 2`: More entries become inactive. - E.g. if `k = 1`, then - - | x | x | o | o | o | y | y | y | <- Current entries - | x | o | o | o | o | y | y | y | <- New entries - =# - - # Compute the allocated space for `vn`. - had_inactive = haskey(vnv.num_inactive, idx) - n_allocated = had_inactive ? n_old + vnv.num_inactive[idx] : n_old - - if n_new > n_allocated - # Then we need to grow the underlying vector. - n_extra = n_new - n_allocated - # Allocate. - resize!(vnv.vals, length(vnv.vals) + n_extra) - # Shift current values. - shift_right!(vnv.vals, end_old + 1, n_extra) - # No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - # Update the ranges for all variables after this one. - shift_subsequent_ranges_by!(vnv, idx, n_extra) - elseif n_new == n_allocated - # => No more inactive entries. - had_inactive && delete!(vnv.num_inactive, idx) - else - # `n_new < n_allocated` - # => Need to update the number of inactive entries. - vnv.num_inactive[idx] = n_allocated - n_new - end - - # Update the range for this variable. - vnv.ranges[idx] = r_new - # Update the value. - vnv.vals[r_new] = val - if transform !== nothing - # Update the transform. 
- vnv.transforms[idx] = transform - end - - # TODO: Should we maybe sweep over inactive ranges and re-contiguify - # if the total number of inactive elements is "large" in some sense? - - return nothing -end - -function Base.push!(vnv::VarNamedVector, vn, val, dist) - f = from_vec_transform(dist) - return setindex_internal!(vnv, tovec(val), vn, f) -end - -function BangBang.push!!(vnv::VarNamedVector, vn, val, dist) - f = from_vec_transform(dist) - return setindex_internal!!(vnv, tovec(val), vn, f) -end - -# BangBang versions of the above functions. -# The only difference is that update_internal!! and insert_internal!! check whether the -# container types of the VarNamedVector vector need to be expanded to accommodate the new -# values. If so, they create a new instance, otherwise they mutate in place. All the others -# functions, e.g. setindex!!, setindex_internal!!, etc., are carbon copies of the ! versions -# with every ! call replaced with a !! call. - -""" - loosen_types!!(vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew}) - -Loosen the types of `vnv` to allow varname type `KNew` and transformation type `TransNew`. - -If `KNew` is a subtype of `K` and `TransNew` is a subtype of the element type of the -`TTrans` then this is a no-op and `vnv` is returned as is. Otherwise a new `VarNamedVector` -is returned with the same data but more abstract types, so that variables of type `KNew` and -transformations of type `TransNew` can be pushed to it. Some of the underlying storage is -shared between `vnv` and the return value, and thus mutating one may affect the other. - -# See also -[`tighten_types!!`](@ref) - -# Examples - -```jldoctest varnamedvector-loosen-types -julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! 
- -julia> vnv = VarNamedVector(@varname(x) => [1.0]); - -julia> y_trans(x) = reshape(x, (2, 2)); - -julia> setindex_internal!(vnv, collect(1:4), @varname(y), y_trans) -ERROR: MethodError: Cannot `convert` an object of type -[...] - -julia> vnv_loose = DynamicPPL.loosen_types!!( - vnv, typeof(@varname(y)), Float64, typeof(y_trans) - ); - -julia> setindex_internal!(vnv_loose, collect(1:4), @varname(y), y_trans) - -julia> vnv_loose[@varname(y)] -2×2 Matrix{Float64}: - 1.0 3.0 - 2.0 4.0 -``` -""" -function loosen_types!!( - vnv::VarNamedVector, ::Type{KNew}, ::Type{VNew}, ::Type{TNew} -) where {KNew,VNew,TNew} - K = eltype(vnv.varnames) - V = eltype(vnv.vals) - T = eltype(vnv.transforms) - if KNew <: K && VNew <: V && TNew <: T - return vnv - else - # We could use promote_type here, instead of typejoin. However, that would e.g. - # cause Ints to be converted to Float64s, since - # promote_type(Int, Float64) == Float64, which can cause problems. See - # https://github.com/TuringLang/DynamicPPL.jl/pull/1098#discussion_r2472636188. - # Base.promote_typejoin would be like typejoin, but creates Unions out of Nothing - # and Missing, rather than falling back on Any. However, it's not exported. - vn_type = typejoin(K, KNew) - val_type = typejoin(V, VNew) - transform_type = typejoin(T, TNew) - # This function would work the same way if the first if statement a few lines above - # was skipped, and we only checked for the below condition. However, the first one - # is constant propagated away at compile time (at least on Julia v1.11.7), whereas - # this one isn't. Hence we keep both for performance. 
- return if vn_type == K && val_type == V && transform_type == T - vnv - elseif isempty(vnv) - VarNamedVector( - Dict{vn_type,Int}(), - Vector{vn_type}(), - UnitRange{Int}[], - Vector{val_type}(), - Vector{transform_type}(), - BitVector(), - Dict{Int,Int}(); - check_consistency=false, - ) - else - # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but - # then here always revert to Vector. - VarNamedVector( - Dict{vn_type,Int}(vnv.varname_to_index), - Vector{vn_type}(vnv.varnames), - vnv.ranges, - Vector{val_type}(vnv.vals), - Vector{transform_type}(vnv.transforms), - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) - end - end -end - -""" - tighten_types!!(vnv::VarNamedVector) - -Return a `VarNamedVector` like `vnv` with the most concrete types possible. - -This function either returns `vnv` itself or new `VarNamedVector` with the same values in -it, but with the element types of various containers made as concrete as possible. - -For instance, if `vnv` has its vector of transforms have eltype `Any`, but all the -transforms are actually identity transformations, this function will return a new -`VarNamedVector` with the transforms vector having eltype `typeof(identity)`. - -This is a lot like the reverse of [`loosen_types!!`](@ref). Like with `loosen_types!!`, the -return value may share some of its underlying storage with `vnv`, and thus mutating one may -affect the other. - -# See also -[`loosen_types!!`](@ref) - -# Examples - -```jldoctest varnamedvector-tighten-types -julia> using DynamicPPL: VarNamedVector, @varname, loosen_types!!, setindex_internal! 
- -julia> vnv = VarNamedVector(@varname(x) => Real[23], @varname(y) => randn(2,2)); - -julia> vnv = delete!(vnv, @varname(y)); - -julia> eltype(vnv) -Real - -julia> vnv.transforms -1-element Vector{Any}: - identity (generic function with 1 method) - -julia> vnv_tight = DynamicPPL.tighten_types!!(vnv); - -julia> eltype(vnv_tight) == Int -true - -julia> vnv_tight.transforms -1-element Vector{typeof(identity)}: - identity (generic function with 1 method) -``` -""" -function tighten_types!!(vnv::VarNamedVector) - return if is_tightly_typed(vnv) - # There can not be anything to tighten, so short-circuit. - vnv - elseif isempty(vnv) - VarNamedVector() - else - VarNamedVector( - Dict(vnv.varname_to_index...), - [x for x in vnv.varnames], - vnv.ranges, - [x for x in vnv.vals], - [x for x in vnv.transforms], - vnv.is_unconstrained, - vnv.num_inactive; - check_consistency=false, - ) - end -end - -function BangBang.setindex!!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - return update!!(vnv, val, vn) - else - return insert!!(vnv, val, vn) - end -end - -function reset!!(vnv::VarNamedVector, val, vn::VarName) - f = from_vec_transform(val) - vnv = setindex_internal!!(vnv, tovec(val), vn, f) - vnv = set_transformed!!(vnv, false, vn) - return vnv -end - -function update!!(vnv::VarNamedVector, val, vn::VarName) - if !haskey(vnv, vn) - throw(KeyError(vn)) - end - f = inverse(gettransform(vnv, vn)) - internal_val = try - f(val) - catch - error( - "An error occurred while assigning the value $val to variable $vn. 
" * - "If you are changing the type or size of a variable you'll need to either " * - "`delete!` it first or use `setindex_internal!`", - ) - end - return setindex_internal!!(vnv, internal_val, vn) -end - -function insert!!(vnv::VarNamedVector, val, vn::VarName) - if haskey(vnv, vn) - throw("Variable $vn already exists in VarNamedVector.") - end - return reset!!(vnv, val, vn) -end - -function setindex_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if haskey(vnv, vn) - return update_internal!!(vnv, val, vn, transform) - else - return insert_internal!!(vnv, val, vn, transform) - end -end - -function insert_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - if transform === nothing - transform = identity - end - vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) - insert_internal!(vnv, val, vn, transform) - vnv = tighten_types!!(vnv) - return vnv -end - -function update_internal!!( - vnv::VarNamedVector, val::AbstractVector, vn::VarName, transform=nothing -) - transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform - vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) - update_internal!(vnv, val, vn, transform) - vnv = tighten_types!!(vnv) - return vnv -end - -function BangBang.push!!(vnv::VarNamedVector, pair::Pair) - vn, val = pair - return setindex!!(vnv, val, vn) -end - -function Base.empty!(vnv::VarNamedVector) - # TODO: Or should the semantics be different, e.g. keeping `varnames`? - empty!(vnv.varname_to_index) - empty!(vnv.varnames) - empty!(vnv.ranges) - empty!(vnv.vals) - empty!(vnv.transforms) - empty!(vnv.is_unconstrained) - empty!(vnv.num_inactive) - return nothing -end -BangBang.empty!!(vnv::VarNamedVector) = (empty!(vnv); return vnv) - -""" - replace_raw_storage(vnv::VarNamedVector, vals::AbstractVector) - -Replace the values in `vnv` with `vals`, as they are stored internally. 
- -This is useful when we want to update the entire underlying vector of values in one go or if -we want to change the how the values are stored, e.g. alter the `eltype`. - -!!! warning - This replaces the raw underlying values, and so care should be taken when using this - function. For example, if `vnv` has any inactive entries, then the provided `vals` - should also contain the inactive entries to avoid unexpected behavior. - -# Examples - -```jldoctest varnamedvector-replace-raw-storage -julia> using DynamicPPL: VarNamedVector, replace_raw_storage - -julia> vnv = VarNamedVector(@varname(x) => [1.0]); - -julia> replace_raw_storage(vnv, [2.0])[@varname(x)] == [2.0] -true -``` - -This is also useful when we want to differentiate wrt. the values using automatic -differentiation, e.g. ForwardDiff.jl. - -```jldoctest varnamedvector-replace-raw-storage -julia> using ForwardDiff: ForwardDiff - -julia> f(x) = sum(abs2, replace_raw_storage(vnv, x)[@varname(x)]) -f (generic function with 1 method) - -julia> ForwardDiff.gradient(f, [1.0]) -1-element Vector{Float64}: - 2.0 -``` -""" -replace_raw_storage(vnv::VarNamedVector, vals) = Accessors.@set vnv.vals = vals - -vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) - -""" - unflatten(vnv::VarNamedVector, vals::AbstractVector) - -Return a new instance of `vnv` with the values of `vals` assigned to the variables. - -This assumes that `vals` have been transformed by the same transformations that that the -values in `vnv` have been transformed by. However, unlike [`replace_raw_storage`](@ref), -`unflatten` does account for inactive entries in `vnv`, so that the user does not have to -care about them. - -This is in a sense the reverse operation of `vnv[:]`. - -The return value may share memory with the input `vnv`, and thus one can not be mutated -safely without affecting the other. - -Unflatten recontiguifies the internal storage, getting rid of any inactive entries. 
- -# Examples - -```jldoctest varnamedvector-unflatten -julia> using DynamicPPL: VarNamedVector, unflatten - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); - -julia> unflatten(vnv, vnv[:]) == vnv -true -""" -function unflatten(vnv::VarNamedVector, vals::AbstractVector) - if length(vals) != vector_length(vnv) - throw( - ArgumentError( - "Length of `vals` ($(length(vals))) does not match the length of " * - "`vnv` ($(vector_length(vnv))).", - ), - ) - end - new_ranges = vnv.ranges - num_inactive = vnv.num_inactive - if has_inactive(vnv) - new_ranges = recontiguify_ranges!(new_ranges) - num_inactive = Dict{Int,Int}() - end - return VarNamedVector( - vnv.varname_to_index, - vnv.varnames, - new_ranges, - vals, - vnv.transforms, - vnv.is_unconstrained, - num_inactive; - check_consistency=false, - ) -end - -function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) - # Return early if possible. - isempty(left_vnv) && return deepcopy(right_vnv) - isempty(right_vnv) && return deepcopy(left_vnv) - - # Determine varnames. - vns_left = left_vnv.varnames - vns_right = right_vnv.varnames - vns_both = union(vns_left, vns_right) - - # Check that varnames do not subsume each other. - for vn_left in vns_left - for vn_right in vns_right - vn_left == vn_right && continue - # TODO(mhauru) Subsumation doesn't actually need to be a showstopper. For - # instance, if right has a value for `x` and left has a value for `x[1]`, then - # right will take precedence anyway, and we could merge. However, that requires - # some extra logic that hasn't been done yet. - if subsumes(vn_left, vn_right) - throw( - ArgumentError( - "Cannot merge VarNamedVectors: variable name $vn_left " * - "subsumes $vn_right.", - ), - ) - elseif subsumes(vn_right, vn_left) - throw( - ArgumentError( - "Cannot merge VarNamedVectors: variable name $vn_right " * - "subsumes $vn_left.", - ), - ) - end - end - end - - # Determine `eltype` of `vals`. 
- T_left = eltype(left_vnv.vals) - T_right = eltype(right_vnv.vals) - T = typejoin(T_left, T_right) - - # Determine `eltype` of `varnames`. - V_left = eltype(left_vnv.varnames) - V_right = eltype(right_vnv.varnames) - V = typejoin(V_left, V_right) - if !(V <: VarName) - V = VarName - end - - # Determine `eltype` of `transforms`. - F_left = eltype(left_vnv.transforms) - F_right = eltype(right_vnv.transforms) - F = typejoin(F_left, F_right) - - # Allocate. - varname_to_index = Dict{V,Int}() - ranges = UnitRange{Int}[] - vals = T[] - transforms = F[] - is_unconstrained = BitVector(undef, length(vns_both)) - - # Range offset. - offset = 0 - - for (idx, vn) in enumerate(vns_both) - varname_to_index[vn] = idx - # Extract the necessary information from `left` or `right`. - if vn in vns_left && !(vn in vns_right) - # `vn` is only in `left`. - val = getindex_internal(left_vnv, vn) - f = gettransform(left_vnv, vn) - is_unconstrained[idx] = is_transformed(left_vnv, vn) - else - # `vn` is either in both or just `right`. - # Note that in a `merge` the right value has precedence. - val = getindex_internal(right_vnv, vn) - f = gettransform(right_vnv, vn) - is_unconstrained[idx] = is_transformed(right_vnv, vn) - end - n = length(val) - r = (offset + 1):(offset + n) - # Update. - append!(vals, val) - push!(ranges, r) - push!(transforms, f) - # Increment `offset`. - offset += n - end - - return VarNamedVector( - varname_to_index, - vns_both, - ranges, - vals, - transforms, - is_unconstrained; - check_consistency=false, - ) -end - -""" - subset(vnv::VarNamedVector, vns::AbstractVector{<:VarName}) - -Return a new `VarNamedVector` containing the values from `vnv` for variables in `vns`. - -Which variables to include is determined by the `VarName`'s `subsumes` relation, meaning -that e.g. `subset(vnv, [@varname(x)])` will include variables like `@varname(x.a[1])`. - -Preserves the order of variables in `vnv`. 
- -# Examples - -```jldoctest varnamedvector-subset -julia> using DynamicPPL: VarNamedVector, @varname, subset - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(y) => [3.0]); - -julia> subset(vnv, [@varname(x)]) == VarNamedVector(@varname(x) => [1.0, 2.0]) -true - -julia> subset(vnv, [@varname(x[2])]) == VarNamedVector(@varname(x[2]) => [2.0]) -true -""" -function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) - vnv_new = similar(vnv) - # Return early if possible. - isempty(vnv) && return vnv_new - - for vn in vnv.varnames - if any(subsumes(vn_given, vn) for vn_given in vns_given) - insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn)) - set_transformed!(vnv_new, is_transformed(vnv, vn), vn) - end - end - - return tighten_types!!(vnv_new) -end - -""" - similar(vnv::VarNamedVector) - -Return a new `VarNamedVector` with the same structure as `vnv`, but with empty values. - -In this respect `vnv` behaves more like a dictionary than an array: `similar(vnv)` will -be entirely empty, rather than have `undef` values in it. - -# Examples - -```julia-doctest-varnamedvector-similar -julia> using DynamicPPL: VarNamedVector, @varname, similar - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0], @varname(x[3]) => [3.0]); - -julia> similar(vnv) == VarNamedVector{VarName{:x}, Float64}() -true -""" -function Base.similar(vnv::VarNamedVector) - # NOTE: Whether or not we should empty the underlying containers or not - # is somewhat ambiguous. For example, `similar(vnv.varname_to_index)` will - # result in an empty `AbstractDict`, while the vectors, e.g. `vnv.ranges`, - # will result in non-empty vectors but with entries as `undef`. But it's - # much easier to write the rest of the code assuming that `undef` is not - # present, and so for now we empty the underlying containers, thus differing - # from the behavior of `similar` for `AbstractArray`s. 
- return VarNamedVector( - empty(vnv.varname_to_index), - similar(vnv.varnames, 0), - similar(vnv.ranges, 0), - similar(vnv.vals, 0), - similar(vnv.transforms, 0), - BitVector(), - empty(vnv.num_inactive); - check_consistency=false, - ) -end - -""" - is_contiguous(vnv::VarNamedVector) - -Returns `true` if the underlying data of `vnv` is stored in a contiguous array. - -This is equivalent to negating [`has_inactive(vnv)`](@ref). -""" -is_contiguous(vnv::VarNamedVector) = !has_inactive(vnv) - -""" - nextrange(vnv::VarNamedVector, x) - -Return the range of `length(x)` from the end of current data in `vnv`. -""" -function nextrange(vnv::VarNamedVector, x) - offset = length(vnv.vals) - return (offset + 1):(offset + length(x)) -end - -""" - shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - -Shifts the elements of `x` starting from index `start` by `n` to the right. -""" -function shift_right!(x::AbstractVector{<:Real}, start::Int, n::Int) - x[(start + n):end] = x[start:(end - n)] - return x -end - -""" - shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) - -Shifts the ranges of variables in `vnv` starting from index `idx` by `n`. -""" -function shift_subsequent_ranges_by!(vnv::VarNamedVector, idx::Int, n) - for i in (idx + 1):length(vnv.ranges) - vnv.ranges[i] = vnv.ranges[i] .+ n - end - return nothing -end - -# set!! is the function defined in utils.jl that tries to do fancy stuff with optics when -# setting the value of a generic container using a VarName. We can bypass all that because -# VarNamedVector handles VarNames natively. However, it's semantics are slightly different -# from setindex!'s: It allows resetting variables that already have a value with values of -# a different type/size. 
-set!!(vnv::VarNamedVector, vn::VarName, val) = reset!!(vnv, val, vn) - -function setval!(vnv::VarNamedVector, val, vn::VarName) - return setindex_internal!(vnv, tovec(val), vn) -end - -function recontiguify_ranges!(ranges::AbstractVector{<:AbstractRange}) - offset = 0 - for i in 1:length(ranges) - r_old = ranges[i] - ranges[i] = (offset + 1):(offset + length(r_old)) - offset += length(r_old) - end - - return ranges -end - -""" - contiguify!(vnv::VarNamedVector) - -Re-contiguify the underlying vector and shrink if possible. - -# Examples - -```jldoctest varnamedvector-contiguify -julia> using DynamicPPL: VarNamedVector, @varname, contiguify!, update!, has_inactive - -julia> vnv = VarNamedVector(@varname(x) => [1.0, 2.0, 3.0], @varname(y) => [3.0]); - -julia> update!(vnv, [23.0, 24.0], @varname(x)); - -julia> has_inactive(vnv) -true - -julia> length(vnv.vals) -4 - -julia> contiguify!(vnv); - -julia> has_inactive(vnv) -false - -julia> length(vnv.vals) -3 - -julia> vnv[@varname(x)] # All the values are still there. -2-element Vector{Float64}: - 23.0 - 24.0 -``` -""" -function contiguify!(vnv::VarNamedVector) - if !has_inactive(vnv) - return vnv - end - # Extract the re-contiguified values. - # NOTE: We need to do this before we update the ranges. - old_vals = copy(vnv.vals) - old_ranges = copy(vnv.ranges) - # And then we re-contiguify the ranges. - recontiguify_ranges!(vnv.ranges) - # Clear the inactive ranges. - empty!(vnv.num_inactive) - # Now we update the values. - for (old_range, new_range) in zip(old_ranges, vnv.ranges) - vnv.vals[new_range] = old_vals[old_range] - end - # And (potentially) shrink the underlying vector. - resize!(vnv.vals, vnv.ranges[end][end]) - # The rest should be left as is. - return vnv -end - -""" - group_by_symbol(vnv::VarNamedVector) - -Return a dictionary mapping symbols to `VarNamedVector`s with varnames containing that -symbol. 
- -# Examples - -```jldoctest varnamedvector-group-by-symbol -julia> using DynamicPPL: VarNamedVector, @varname, group_by_symbol - -julia> vnv = VarNamedVector(@varname(x) => [1.0], @varname(y) => [2.0], @varname(x[1]) => [3.0]); - -julia> d = group_by_symbol(vnv); - -julia> collect(keys(d)) -[Symbol("x"), Symbol("y")] - -julia> d[@varname(x)] == VarNamedVector(@varname(x) => [1.0], @varname(x[1]) => [3.0]) -true - -julia> d[@varname(y)] == VarNamedVector(@varname(y) => [2.0]) -true -""" -function group_by_symbol(vnv::VarNamedVector) - symbols = unique(map(getsym, vnv.varnames)) - nt_vals = map(s -> tighten_types!!(subset(vnv, [VarName{s}()])), symbols) - return OrderedDict(zip(symbols, nt_vals)) -end - -""" - shift_index_left!(vnv::VarNamedVector, idx::Int) - -Shift the index `idx` to the left by one and update the relevant fields. - -This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a -helper function for [`shift_subsequent_indices_left!`](@ref). - -!!! warning - This does not check if index we're shifting to is already occupied. -""" -function shift_index_left!(vnv::VarNamedVector, idx::Int) - # Shift the index in the lookup table. - vn = vnv.varnames[idx] - vnv.varname_to_index[vn] = idx - 1 - # Shift the index in the inactive ranges. - if haskey(vnv.num_inactive, idx) - # Done in increasing order => don't need to worry about - # potentially shifting the same index twice. - vnv.num_inactive[idx - 1] = pop!(vnv.num_inactive, idx) - end -end - -""" - shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) - -Shift the indices for all variables after `idx` to the left by one and update the relevant - fields. - -This only affects `vnv.varname_to_index` and `vnv.num_inactive` and is only valid as a -helper function for [`delete!`](@ref). -""" -function shift_subsequent_indices_left!(vnv::VarNamedVector, idx::Int) - # Shift the indices for all variables after `idx`. 
- for idx_to_shift in (idx + 1):length(vnv.varnames) - shift_index_left!(vnv, idx_to_shift) - end -end - -function Base.delete!(vnv::VarNamedVector, vn::VarName) - # Error if we don't have the variable. - !haskey(vnv, vn) && throw(ArgumentError("variable name $vn does not exist")) - - # Get the index of the variable. - idx = getidx(vnv, vn) - - # Delete the values. - r_start = first(getrange(vnv, idx)) - n_allocated = num_allocated(vnv, idx) - # NOTE: `deleteat!` also results in a `resize!` so we don't need to do that. - deleteat!(vnv.vals, r_start:(r_start + n_allocated - 1)) - - # Delete `vn` from the lookup table. - delete!(vnv.varname_to_index, vn) - - # Delete any inactive ranges corresponding to `vn`. - haskey(vnv.num_inactive, idx) && delete!(vnv.num_inactive, idx) - - # Re-adjust the indices for varnames occuring after `vn` so - # that they point to the correct indices after the deletions below. - shift_subsequent_indices_left!(vnv, idx) - - # Re-adjust the ranges for varnames occuring after `vn`. - shift_subsequent_ranges_by!(vnv, idx, -n_allocated) - - # Delete references from vector fields, thus shifting the indices of - # varnames occuring after `vn` by one to the left, as we adjusted for above. - deleteat!(vnv.varnames, idx) - deleteat!(vnv.ranges, idx) - deleteat!(vnv.transforms, idx) - - return vnv -end - -""" - delete!!(vnv::VarNamedVector, vn::VarName) - -Like `delete!!`, but tightens the element types of the returned `VarNamedVector`. - -# See also: -[`tighten_types!!`](@ref) -""" -BangBang.delete!!(vnv::VarNamedVector, vn::VarName) = tighten_types!!(delete!(vnv, vn)) - -""" - values_as(vnv::VarNamedVector[, T]) - -Return the values/realizations in `vnv` as type `T`, if implemented. - -If no type `T` is provided, return values as stored in `vnv`. 
- -# Examples - -```jldoctest -julia> using DynamicPPL: VarNamedVector - -julia> vnv = VarNamedVector(@varname(x) => 1, @varname(y) => [2.0]); - -julia> values_as(vnv) == [1.0, 2.0] -true - -julia> values_as(vnv, Vector{Float32}) == Vector{Float32}([1.0, 2.0]) -true - -julia> values_as(vnv, OrderedDict) == OrderedDict(@varname(x) => 1.0, @varname(y) => [2.0]) -true - -julia> values_as(vnv, NamedTuple) == (x = 1.0, y = [2.0]) -true -``` -""" -values_as(vnv::VarNamedVector) = values_as(vnv, Vector) -values_as(vnv::VarNamedVector, ::Type{Vector}) = getindex_internal(vnv, :) -function values_as(vnv::VarNamedVector, ::Type{Vector{T}}) where {T} - return convert(Vector{T}, values_as(vnv, Vector)) -end -function values_as(vnv::VarNamedVector, ::Type{NamedTuple}) - return NamedTuple(zip(map(Symbol, keys(vnv)), values(vnv))) -end -function values_as(vnv::VarNamedVector, ::Type{D}) where {D<:AbstractDict} - return ConstructionBase.constructorof(D)(pairs(vnv)) -end - -# See the docstring of `getvalue` for the semantics of `hasvalue` and `getvalue`, and how -# they differ from `haskey` and `getindex`. They can be found in AbstractPPL.jl. - -# TODO(mhauru) This is tricky to implement in the general case, and the below implementation -# only covers some simple cases. It's probably sufficient in most situations though. -function hasvalue(vnv::VarNamedVector, vn::VarName) - haskey(vnv, vn) && return true - any(subsumes(vn, k) for k in keys(vnv)) && return true - # Handle the easy case where the right symbol isn't even present. - !any(k -> getsym(k) == getsym(vn), keys(vnv)) && return false - - optic = getoptic(vn) - if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic - # If vn is of the form @varname(somesymbol[someindex]), we check whether we store - # @varname(somesymbol) and can index into it with someindex. If we rather have a - # composed optic with the last part being an index lens, we do a similar check but - # stripping out the last index lens part. 
If these pass, the answer is definitely - # "yes". If not, we still don't know for sure. - # TODO(mhauru) What about casese where vnv stores both @varname(x) and - # @varname(x[1]) or @varname(x.a)? Those should probably be banned, but currently - # aren't. - head, tail = if optic isa Accessors.ComposedOptic - decomp_optic = Accessors.decompose(optic) - first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) - else - optic, identity - end - parent_varname = VarName{getsym(vn)}(tail) - if haskey(vnv, parent_varname) - valvec = getindex(vnv, parent_varname) - return canview(head, valvec) - end - end - throw(ErrorException("hasvalue has not been fully implemented for this VarName: $(vn)")) -end - -# TODO(mhauru) Like hasvalue, this is only partially implemented. -function getvalue(vnv::VarNamedVector, vn::VarName) - !hasvalue(vnv, vn) && throw(KeyError(vn)) - haskey(vnv, vn) && getindex(vnv, vn) - - subsumed_keys = filter(k -> subsumes(vn, k), keys(vnv)) - if length(subsumed_keys) > 0 - # TODO(mhauru) What happens if getindex returns e.g. matrices, and we vcat them? - return mapreduce(k -> getindex(vnv, k), vcat, subsumed_keys) - end - - optic = getoptic(vn) - # See hasvalue for some comments on the logic of this if block. - if optic isa Accessors.IndexLens || optic isa Accessors.ComposedOptic - head, tail = if optic isa Accessors.ComposedOptic - decomp_optic = Accessors.decompose(optic) - first(decomp_optic), Accessors.compose(decomp_optic[2:end]...) 
- else - optic, identity - end - parent_varname = VarName{getsym(vn)}(tail) - valvec = getindex(vnv, parent_varname) - return head(valvec) - end - throw(ErrorException("getvalue has not been fully implemented for this VarName: $(vn)")) -end - -Base.get(vnv::VarNamedVector, vn::VarName) = getvalue(vnv, vn) diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index 6ce1a861e..b0cafa364 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -17,6 +17,15 @@ VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators()) +function VNTVarInfo(values::Union{NamedTuple,AbstractDict}) + vi = VarInfo() + for (k, v) in pairs(values) + vn = k isa Symbol ? VarName{k}() : k + vi = setindex!!(vi, v, vn) + end + return vi +end + function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) return VNTVarInfo(Random.default_rng(), model, init_strategy) end diff --git a/test/model.jl b/test/model.jl index 7c5dc2fcc..281eaaad4 100644 --- a/test/model.jl +++ b/test/model.jl @@ -25,9 +25,6 @@ function innermost_distribution_type(d::Distributions.Product) return dists[1] end -is_type_stable_varinfo(::DynamicPPL.AbstractVarInfo) = false -is_type_stable_varinfo(varinfo::DynamicPPL.VNTVarInfo) = true - const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @testset "model.jl" begin @@ -221,7 +218,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() @test !any(map(x -> x isa DynamicPPL.AbstractVarInfo, call_retval)) end - @testset "Dynamic constraints, Metadata" begin + @testset "Dynamic constraints" begin model = DynamicPPL.TestUtils.demo_dynamic_constraint() vi = VarInfo(model) vi = link!!(vi, model) @@ -415,10 +412,7 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end vns = DynamicPPL.TestUtils.varnames(model) example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = filter( - is_type_stable_varinfo, - 
DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), - ) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos @test begin @inferred(DynamicPPL.evaluate!!(model, varinfo)) diff --git a/test/runtests.jl b/test/runtests.jl index 23dda437b..6521f1e4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,9 +52,7 @@ include("test_util.jl") include("accumulators.jl") include("compiler.jl") include("varnamedtuple.jl") - # include("varnamedvector.jl") include("varinfo.jl") - # include("simple_varinfo.jl") include("model.jl") include("distribution_wrappers.jl") include("linking.jl") diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl deleted file mode 100644 index 2c0e21bec..000000000 --- a/test/simple_varinfo.jl +++ /dev/null @@ -1,345 +0,0 @@ -@testset "simple_varinfo.jl" begin - @testset "constructor & indexing" begin - @testset "NamedTuple" begin - svi = SimpleVarInfo(; m=1.0) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(; m=[1.0]) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(; m=(a=[1.0],)) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo{Float32}(; m=1.0) - @test getlogjoint(svi) isa Float32 - - svi = SimpleVarInfo((m=1.0,)) - svi = accloglikelihood!!(svi, 1.0) - @test getlogjoint(svi) == 1.0 - end - - @testset "Dict" begin - svi = SimpleVarInfo(OrderedDict(@varname(m) => 1.0)) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(OrderedDict(@varname(m) => [1.0])) - @test 
getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(OrderedDict(@varname(m) => (a=[1.0],))) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - - svi = SimpleVarInfo(OrderedDict(@varname(m.a) => [1.0])) - # Now we only have a variable `m.a` which is subsumed by `m`, - # but we can't guarantee that we have the "entire" `m`. - @test !haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - end - - @testset "VarNamedVector" begin - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => 1.0)) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test !haskey(svi, @varname(m[1])) - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m) => [1.0])) - @test getlogjoint(svi) == 0.0 - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m[1])) - @test !haskey(svi, @varname(m[2])) - @test svi[@varname(m)][1] == svi[@varname(m[1])] - - svi = SimpleVarInfo(push!!(DynamicPPL.VarNamedVector(), @varname(m.a) => [1.0])) - @test haskey(svi, @varname(m)) - @test haskey(svi, @varname(m.a)) - @test haskey(svi, @varname(m.a[1])) - @test !haskey(svi, @varname(m.a[2])) - @test !haskey(svi, @varname(m.a.b)) - # The implementation of haskey and getvalue fo VarNamedVector is incomplete, the - # next test is here to remind of us that. - svi = SimpleVarInfo( - push!!(DynamicPPL.VarNamedVector(), @varname(m.a.b) => [1.0]) - ) - @test_broken !haskey(svi, @varname(m.a.b.c.d)) - end - end - - @testset "link!! & invlink!! 
on $(nameof(model))" for model in - DynamicPPL.TestUtils.ALL_MODELS - values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$name" for (name, vi) in ( - ("SVI{Dict}", SimpleVarInfo(OrderedDict{VarName,Any}())), - ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), - ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), - ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), - ) - if name == "SVI{NamedTuple}" && - model.f === DynamicPPL.TestUtils.demo_one_variable_multiple_constraints - # TODO(mhauru) There's a bug in SimpleVarInfo{<:NamedTuple} for cases where - # a variable set with IndexLenses changes dimension under linking. This - # makes the link!! call crash. The below call to @test just marks the fact - # that there's something broken here. - @test false broken = true - continue - end - for vn in DynamicPPL.TestUtils.varnames(model) - vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) - end - vi = last(DynamicPPL.evaluate!!(model, vi)) - - # Calculate ground truth - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_constrained... - ) - - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_unlinked = getlogjoint(vi_linked) - lp_linked = getlogjoint_internal(vi_linked) - @test lp_linked ≈ lp_linked_true - @test lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_linked) ≈ lp_unlinked - - # `invlink!!` - vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_unlinked = getlogjoint(vi_invlinked) - also_lp_unlinked = getlogjoint_internal(vi_invlinked) - @test lp_unlinked ≈ lp_unlinked_true - @test also_lp_unlinked ≈ lp_unlinked_true - @test logjoint(model, vi_invlinked) ≈ lp_unlinked - - # Should result in same values. 
- @test all( - DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_invlinked, vn)) ≈ - DynamicPPL.tovec(get(values_constrained, vn)) for - vn in DynamicPPL.TestUtils.varnames(model) - ) - end - end - - @testset "SimpleVarInfo on $(nameof(model))" for model in - DynamicPPL.TestUtils.ALL_MODELS - if model.f === DynamicPPL.TestUtils.demo_nested_colons - # TODO(mhauru) Either VarNamedVector or SimpleVarInfo has a bug that causes - # the push!! below to fail with a NamedTuple variable like what - # demo_nested_colons has. I don't want to fix it now though, because this may - # all go soon (as of 2025-12-16). - @test false broken = true - continue - end - # We might need to pre-allocate for the variable `m`, so we need - # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.rand_prior_true(model)) - svi_dict = SimpleVarInfo(VarInfo(model), Dict) - vnv = DynamicPPL.VarNamedVector() - for (k, v) in pairs(DynamicPPL.TestUtils.rand_prior_true(model)) - vnv = push!!(vnv, VarName{k}() => v) - end - svi_vnv = SimpleVarInfo(vnv) - - @testset "$name" for (name, svi) in ( - ("NamedTuple", svi_nt), - ("Dict", svi_dict), - ("VarNamedVector", svi_vnv), - # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. - # DynamicPPL.set_transformed!!(deepcopy(svi_nt), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_dict), true), - # DynamicPPL.set_transformed!!(deepcopy(svi_vnv), true), - ) - # Random seed is set in each `@testset`, so we need to sample - # a new realization for `m` here. - retval = model() - - ### Sampling ### - # Sample a new varinfo! - _, svi_new = DynamicPPL.init!!(model, svi) - - # Realization for `m` should be different wp. 1. - for vn in DynamicPPL.TestUtils.varnames(model) - @test svi_new[vn] != get(retval, vn) - end - - # Logjoint should be non-zero wp. 1. 
- @test getlogjoint(svi_new) != 0 - - ### Evaluation ### - values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - if DynamicPPL.is_transformed(svi) - _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, values_eval_constrained... - ) - # Make sure that these two computation paths provide the same - # transformed values. - @test values_eval == _values_prior - else - logpri_true = DynamicPPL.TestUtils.logprior_true( - model, values_eval_constrained... - ) - logπ_true = DynamicPPL.TestUtils.logjoint_true( - model, values_eval_constrained... - ) - values_eval = values_eval_constrained - end - - # No logabsdet-jacobian correction needed for the likelihood. - loglik_true = DynamicPPL.TestUtils.loglikelihood_true( - model, values_eval_constrained... - ) - - # Update the realizations in `svi_new`. - svi_eval = svi_new - for vn in DynamicPPL.TestUtils.varnames(model) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) - end - - # Reset the logp accumulators. - svi_eval = DynamicPPL.resetaccs!!(svi_eval) - - # Compute `logjoint` using the varinfo. - logπ = logjoint(model, svi_eval) - logpri = logprior(model, svi_eval) - loglik = loglikelihood(model, svi_eval) - - # Values should not have changed. - for vn in DynamicPPL.TestUtils.varnames(model) - # TODO(mhauru) Workaround for - # https://github.com/JuliaLang/LinearAlgebra.jl/pull/1404 - # Remove once the fix is all Julia versions we support. - val = get(values_eval, vn) - if val isa Cholesky - @test svi_eval[vn].L == val.L - else - @test svi_eval[vn] == val - end - end - - # Compare log-probability computations. - @test logpri ≈ logpri_true - @test loglik ≈ loglik_true - @test logπ ≈ logπ_true - end - end - - @testset "Dynamic constraints" begin - model = DynamicPPL.TestUtils.demo_dynamic_constraint() - - # Initialize. 
- svi_nt = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) - svi_nt = last(DynamicPPL.init!!(model, svi_nt)) - svi_vnv = DynamicPPL.set_transformed!!( - SimpleVarInfo(DynamicPPL.VarNamedVector()), true - ) - svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) - - for svi in (svi_nt, svi_vnv) - # Sample with large variations in unconstrained space. - for i in 1:10 - for vn in keys(svi) - svi = DynamicPPL.setindex!!(svi, 10 * randn(), vn) - end - retval, svi = DynamicPPL.evaluate!!(model, svi) - @test retval.m == svi[@varname(m)] # `m` is unconstrained - @test retval.x ≠ svi[@varname(x)] # `x` is constrained depending on `m` - - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.m, retval.x - ) - - # Realizations from model should all be equal to the unconstrained realization. - for vn in DynamicPPL.TestUtils.varnames(model) - @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 - end - - # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogjoint_internal(svi) - # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 - @test lp ≈ lp_true atol = 1.2e-5 - end - end - end - - @testset "Static transformation" begin - model = DynamicPPL.TestUtils.demo_static_transformation() - - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, DynamicPPL.TestUtils.rand_prior_true(model), [@varname(s), @varname(m)] - ) - @testset "$(short_varinfo_name(vi))" for vi in varinfos - # Initialize varinfo and link. - vi_linked = DynamicPPL.link!!(vi, model) - - # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. - @test !DynamicPPL.is_transformed( - DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) - ) - - # Resulting varinfo should no longer be transformed. 
- vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) - @test !DynamicPPL.is_transformed(vi_result) - - # Set the values to something that is out of domain if we're in constrained space. - for vn in keys(vi) - vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) - end - - # NOTE: Evaluating a linked VarInfo, **specifically when the transformation - # is static**, will result in an invlinked VarInfo. This is because of - # `maybe_invlink_before_eval!`, which only invlinks if the transformation - # is static. (src/abstract_varinfo.jl) - retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ - DynamicPPL.tovec(retval.s) # `s` is unconstrained in original - @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) - ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result - - # `m` should not be transformed. - @test vi_linked[@varname(m)] == retval.m - @test vi_unlinked_again[@varname(m)] == retval.m - - # Get ground truths - retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( - model, retval.s, retval.m - ) - lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) - - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ - DynamicPPL.tovec(retval_unconstrained.s) - @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ - DynamicPPL.tovec(retval_unconstrained.m) - - # The unlinked varinfo should hold the unlinked logp. - lp_unlinked = getlogjoint(vi_unlinked_again) - @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true - end - end -end diff --git a/test/test_util.jl b/test/test_util.jl index 821b1e0db..9f6939adf 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -16,29 +16,6 @@ Return string representing a short description of `vi`. 
function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -# function short_varinfo_name(vi::DynamicPPL.NTVarInfo) -# return if DynamicPPL.has_varnamedvector(vi) -# "TypedVectorVarInfo" -# else -# "TypedVarInfo" -# end -# end -# short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" -# short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" -function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) - return "SimpleVarInfo{<:NamedTuple,<:Ref}" -end -function short_varinfo_name(::SimpleVarInfo{<:OrderedDict,<:Ref}) - return "SimpleVarInfo{<:OrderedDict,<:Ref}" -end -# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector,<:Ref}) -# return "SimpleVarInfo{<:VarNamedVector,<:Ref}" -# end -short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}" -short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}" -# function short_varinfo_name(::SimpleVarInfo{<:DynamicPPL.VarNamedVector}) -# return "SimpleVarInfo{<:VarNamedVector}" -# end function short_varinfo_name(::DynamicPPL.VNTVarInfo) return "VNTVarInfo" end diff --git a/test/varinfo.jl b/test/varinfo.jl index 1d01a0cf8..8ae0535c7 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,17 +1,6 @@ function check_varinfo_keys(varinfo, vns) - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} - # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, - # since `keys(varinfo_merged)` only contains `VarName` with `identity`. - # So we just check that the original keys are present. - for vn in vns - # Should have all the original keys. - @test haskey(varinfo, vn) - end - else - vns_varinfo = keys(varinfo) - # Should be equivalent. 
- @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) - end + vns_varinfo = keys(varinfo) + @test union(vns_varinfo, vns) == intersect(vns_varinfo, vns) end @testset "varinfo.jl" begin @@ -446,13 +435,9 @@ end varinfos = DynamicPPL.TestUtils.setup_varinfos( model, model(), vns; include_threadsafe=true ) - varinfos_standard = filter(Base.Fix2(isa, VarInfo), varinfos) - varinfos_simple = filter( - Base.Fix2(isa, DynamicPPL.SimpleOrThreadSafeSimple), varinfos - ) # `VarInfo` supports subsetting using, basically, arbitrary varnames. - vns_supported_standard = [ + vns_supported = [ [@varname(s)], [@varname(m)], [@varname(x[1])], @@ -477,25 +462,10 @@ end [@varname(s), @varname(m), @varname(x[1]), @varname(x[2])], ] - # `SimpleVarInfo` only supports subsetting using the varnames as they appear - # in the model. - vns_supported_simple = filter(∈(vns), vns_supported_standard) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos # All variables. check_varinfo_keys(varinfo, vns) - # Added a `convert` to make the naming of the testsets a bit more readable. - # `SimpleVarInfo{<:NamedTuple}` only supports subsetting with "simple" varnames, - ## i.e. `VarName{sym}()` without any indexing, etc. 
- vns_supported = - if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple && - values_as(varinfo) isa NamedTuple - vns_supported_simple - else - vns_supported_standard - end - @testset ("$(convert(Vector{VarName}, vns_subset)) empty") for vns_subset in vns_supported varinfo_subset = subset(varinfo, VarName[]) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl deleted file mode 100644 index 9a4ef12c3..000000000 --- a/test/varnamedvector.jl +++ /dev/null @@ -1,711 +0,0 @@ -replace_sym(vn::VarName, sym_new::Symbol) = VarName{sym_new}(vn.lens) - -increase_size_for_test(x::Real) = [x] -increase_size_for_test(x::AbstractArray) = repeat(x, 2) - -decrease_size_for_test(x::Real) = x -decrease_size_for_test(x::AbstractVector) = first(x) -decrease_size_for_test(x::AbstractArray) = first(eachslice(x; dims=1)) - -function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.varnames)) - # If the container is concrete, we need to make sure that the varname types match. - # E.g. if `vnv.varnames` has `eltype` `VarName{:x, IndexLens{Tuple{Int64}}}` then - # we need `vn` to also be of this type. - # => If the varname types don't match, we need to relax the container type. 
- return any(keys(vnv)) do vn_present - typeof(vn_present) !== typeof(val) - end - end - - return false -end -function need_varnames_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_varnames_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - if isconcretetype(eltype(vnv.vals)) - return promote_type(eltype(vnv.vals), eltype(val)) != eltype(vnv.vals) - end - - return false -end -function need_values_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_values_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return if isconcretetype(eltype(vnv.transforms)) - # If the container is concrete, we need to make sure that the sizes match. - # => If the sizes don't match, we need to relax the container type. - any(keys(vnv)) do vn_present - size(vnv[vn_present]) != size(val) - end - elseif eltype(vnv.transforms) !== Any - # If it's not concrete AND it's not `Any`, then we should just make it `Any`. - true - else - # Otherwise, it's `Any`, so we don't need to relax the container type. - false - end -end -function need_transforms_relaxation(vnv::DynamicPPL.VarNamedVector, vns, vals) - return any(need_transforms_relaxation(vnv, vn, val) for (vn, val) in zip(vns, vals)) -end - -""" - relax_container_types(vnv::VarNamedVector, vn::VarName, val) - relax_container_types(vnv::VarNamedVector, vns, val) - -Relax the container types of `vnv` if necessary to accommodate `vn` and `val`. - -This attempts to avoid unnecessary container type relaxations by checking whether -the container types of `vnv` are already compatible with `vn` and `val`. - -# Notes -For example, if `vn` is not compatible with the current keys in `vnv`, then -the underlying types will be changed to `VarName` to accommodate `vn`. 
- -Similarly: -- If `val` is not compatible with the current values in `vnv`, then - the underlying value type will be changed to `Real`. -- If `val` requires a transformation that is not compatible with the current - transformations type in `vnv`, then the underlying transformation type will - be changed to `Any`. -""" -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) - return relax_container_types(vnv, [vn], [val]) -end -function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals) - if need_varnames_relaxation(vnv, vns, vals) - varname_to_index_new = convert(Dict{VarName,Int}, vnv.varname_to_index) - varnames_new = convert(Vector{VarName}, vnv.varnames) - else - varname_to_index_new = vnv.varname_to_index - varnames_new = vnv.varnames - end - - transforms_new = if need_transforms_relaxation(vnv, vns, vals) - convert(Vector{Any}, vnv.transforms) - else - vnv.transforms - end - - vals_new = if need_values_relaxation(vnv, vns, vals) - convert(Vector{Real}, vnv.vals) - else - vnv.vals - end - - return DynamicPPL.VarNamedVector( - varname_to_index_new, - varnames_new, - vnv.ranges, - vals_new, - transforms_new, - vnv.is_unconstrained, - vnv.num_inactive, - ) -end - -@testset "VarNamedVector" begin - # Test element-related operations: - # - `getindex` - # - `setindex!` - # - `push!` - # - `update!` - # - `insert!` - # - `reset!` - # - `_internal!` versions of the above - # - !! 
versions of the above - # - # And these are all be tested for different types of values: - # - scalar - # - vector - # - matrix - - # Test operations on `VarNamedVector`: - # - `empty!` - # - `iterate` - # - `convert` to - # - `AbstractDict` - test_pairs = OrderedDict( - @varname(x[1]) => rand(), - @varname(x[2]) => rand(2), - @varname(x[3]) => rand(2, 3), - @varname(y[1]) => rand(), - @varname(y[2]) => rand(2), - @varname(y[3]) => rand(2, 3), - @varname(z[1]) => rand(1:10), - @varname(z[2]) => rand(1:10, 2), - @varname(z[3]) => rand(1:10, 2, 3), - ) - test_vns = collect(keys(test_pairs)) - test_vals = collect(values(test_pairs)) - - @testset "constructor: no args" begin - # Empty. - vnv = DynamicPPL.VarNamedVector() - @test isempty(vnv) - @test eltype(vnv) == Union{} - - # Empty with types. - vnv = DynamicPPL.VarNamedVector{VarName,Float64,typeof(identity)}() - @test isempty(vnv) - @test eltype(vnv) == Float64 - end - - test_varnames_iter = combinations(test_vns, 2) - @testset "$(vn_left) and $(vn_right)" for (vn_left, vn_right) in test_varnames_iter - val_left = test_pairs[vn_left] - val_right = test_pairs[vn_right] - vnv_base = DynamicPPL.VarNamedVector([vn_left, vn_right], [val_left, val_right]) - - # We'll need the transformations later. - # TODO: Should we test other transformations than just `ReshapeTransform`? - from_vec_left = DynamicPPL.from_vec_transform(val_left) - from_vec_right = DynamicPPL.from_vec_transform(val_right) - to_vec_left = inverse(from_vec_left) - to_vec_right = inverse(from_vec_right) - - # Compare to alternative constructors. - vnv_from_dict = DynamicPPL.VarNamedVector( - OrderedDict(vn_left => val_left, vn_right => val_right) - ) - @test vnv_base == vnv_from_dict - - # We want the types of fields such as `varnames` and `transforms` to specialize - # whenever possible + some functionality, e.g. `push!`, is only sensible - # if the underlying containers can support it. 
- # Expected behavior - should_have_restricted_varname_type = typeof(vn_left) == typeof(vn_right) - should_have_restricted_transform_type = size(val_left) == size(val_right) - # Actual behavior - has_restricted_transform_type = isconcretetype(eltype(vnv_base.transforms)) - has_restricted_varname_type = isconcretetype(eltype(vnv_base.varnames)) - - @testset "type specialization" begin - @test !should_have_restricted_varname_type || has_restricted_varname_type - @test !should_have_restricted_transform_type || has_restricted_transform_type - end - - @test eltype(vnv_base) == promote_type(eltype(val_left), eltype(val_right)) - @test DynamicPPL.length_internal(vnv_base) == length(val_left) + length(val_right) - @test length(vnv_base) == 2 - - @test !isempty(vnv_base) - - @testset "empty!" begin - vnv = deepcopy(vnv_base) - empty!(vnv) - @test isempty(vnv) - end - - @testset "similar" begin - vnv = similar(vnv_base) - @test isempty(vnv) - @test typeof(vnv) == typeof(vnv_base) - end - - @testset "getindex" begin - # With `VarName` index. - @test vnv_base[vn_left] == val_left - @test vnv_base[vn_right] == val_right - end - - @testset "getindex_internal" begin - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_left) == - to_vec_left(val_left) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, vn_right) == - to_vec_right(val_right) - end - - @testset "getindex_internal with Ints" begin - for (i, val) in enumerate(to_vec_left(val_left)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, i) == val - end - offset = length(to_vec_left(val_left)) - for (i, val) in enumerate(to_vec_right(val_right)) - @test DynamicPPL.DynamicPPL.getindex_internal(vnv_base, offset + i) == val - end - end - - @testset "update!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update!!" 
begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!" begin - vnv = deepcopy(vnv_base) - DynamicPPL.update_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.update_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "update_internal!!" begin - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.update_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.update_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "delete!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - @test !haskey(vnv, vn_left) - @test haskey(vnv, vn_right) - delete!(vnv, vn_right) - @test !haskey(vnv, vn_right) - end - - @testset "insert!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert!!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert!!(vnv, val_left .+ 100, vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert!!(vnv, val_right .+ 100, vn_right) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!" 
begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - DynamicPPL.insert_internal!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.insert_internal!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "insert_internal!!" begin - vnv = deepcopy(vnv_base) - delete!(vnv, vn_left) - delete!(vnv, vn_right) - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_left(val_left .+ 100), vn_left, from_vec_left - ) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.insert_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right, from_vec_right - ) - @test vnv[vn_right] == val_right .+ 100 - end - - @testset "merge" begin - # When there are no inactive entries, `merge` on itself result in the same. - @test merge(vnv_base, vnv_base) == vnv_base - - # Merging with empty should result in the same. - @test merge(vnv_base, similar(vnv_base)) == vnv_base - @test merge(similar(vnv_base), vnv_base) == vnv_base - - # With differences. - vnv_left_only = deepcopy(vnv_base) - delete!(vnv_left_only, vn_right) - vnv_right_only = deepcopy(vnv_base) - delete!(vnv_right_only, vn_left) - - # `(x,)` and `(x, y)` should be `(x, y)`. - @test merge(vnv_left_only, vnv_base) == vnv_base - # `(x, y)` and `(x,)` should be `(x, y)`. - @test merge(vnv_base, vnv_left_only) == vnv_base - # `(x, y)` and `(y,)` should be `(x, y)`. - @test merge(vnv_base, vnv_right_only) == vnv_base - # `(y,)` and `(x, y)` should be `(y, x)`. - vnv_merged = merge(vnv_right_only, vnv_base) - @test vnv_merged != vnv_base - @test collect(keys(vnv_merged)) == [vn_right, vn_left] - end - - @testset "push!" 
begin - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - vnv_copy = deepcopy(vnv) - push!(vnv, (vn => val)) - @test vnv[vn] == val - end - end - - @testset "setindex_internal!" begin - # Not setting the transformation. - vnv = deepcopy(vnv_base) - DynamicPPL.setindex_internal!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - DynamicPPL.setindex_internal!(vnv, to_vec_right(val_right .+ 100), vn_right) - @test vnv[vn_right] == val_right .+ 100 - - # Explicitly setting the transformation. - increment(x) = x .+ 10 - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.loosen_types!!( - vnv, typeof(vn_left), eltype(vnv), typeof(increment) - ) - DynamicPPL.setindex_internal!( - vnv, to_vec_left(val_left .+ 100), vn_left, increment - ) - @test vnv[vn_left] == to_vec_left(val_left .+ 110) - - vnv = DynamicPPL.loosen_types!!( - vnv, typeof(vn_right), eltype(vnv), typeof(increment) - ) - DynamicPPL.setindex_internal!( - vnv, to_vec_right(val_right .+ 100), vn_right, increment - ) - @test vnv[vn_right] == to_vec_right(val_right .+ 110) - - # Adding new values. - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - from_vec_vn = DynamicPPL.from_vec_transform(val) - to_vec_vn = inverse(from_vec_vn) - DynamicPPL.setindex_internal!(vnv, to_vec_vn(val), vn, from_vec_vn) - @test vnv[vn] == val - end - end - - @testset "setindex_internal! with Ints" begin - vnv = deepcopy(vnv_base) - for i in 1:DynamicPPL.length_internal(vnv_base) - DynamicPPL.setindex_internal!(vnv, i, i) - end - for i in 1:DynamicPPL.length_internal(vnv_base) - @test DynamicPPL.getindex_internal(vnv, i) == i - end - end - - @testset "setindex_internal!!" begin - # Not setting the transformation. 
- vnv = deepcopy(vnv_base) - vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_left(val_left .+ 100), vn_left) - @test vnv[vn_left] == val_left .+ 100 - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right - ) - @test vnv[vn_right] == val_right .+ 100 - - # Explicitly setting the transformation. - # Note that unlike with setindex_internal!, we don't need loosen_types!! here. - increment(x) = x .+ 10 - vnv = deepcopy(vnv_base) - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_left(val_left .+ 100), vn_left, increment - ) - @test vnv[vn_left] == to_vec_left(val_left .+ 110) - - vnv = DynamicPPL.setindex_internal!!( - vnv, to_vec_right(val_right .+ 100), vn_right, increment - ) - @test vnv[vn_right] == to_vec_right(val_right .+ 110) - - # Adding new values. - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - from_vec_vn = DynamicPPL.from_vec_transform(val) - to_vec_vn = inverse(from_vec_vn) - vnv = DynamicPPL.setindex_internal!!(vnv, to_vec_vn(val), vn, from_vec_vn) - @test vnv[vn] == val - end - end - - @testset "setindex! and reset!" begin - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn" for vn in test_vns - val = test_pairs[vn] - expected_length = if haskey(vnv, vn) - # If it's already present, the resulting length will be unchanged. - DynamicPPL.length_internal(vnv) - else - DynamicPPL.length_internal(vnv) + length(val) - end - - vnv[vn] = val .+ 1 - x = DynamicPPL.getindex_internal(vnv, :) - @test vnv[vn] == val .+ 1 - @test DynamicPPL.length_internal(vnv) == expected_length - @test length(x) == DynamicPPL.length_internal(vnv) - @test all( - DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) - ) - - # There should be no redundant values in the underlying vector. 
- @test !DynamicPPL.has_inactive(vnv) - end - - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn (increased size)" for vn in test_vns - val_original = test_pairs[vn] - val = increase_size_for_test(val_original) - vn_already_present = haskey(vnv, vn) - expected_length = if vn_already_present - # If it's already present, the resulting length will be altered. - DynamicPPL.length_internal(vnv) + length(val) - length(val_original) - else - DynamicPPL.length_internal(vnv) + length(val) - end - - # Have to use reset!, because setindex! doesn't support decreasing size. - DynamicPPL.reset!(vnv, val .+ 1, vn) - x = DynamicPPL.getindex_internal(vnv, :) - @test vnv[vn] == val .+ 1 - @test DynamicPPL.length_internal(vnv) == expected_length - @test length(x) == DynamicPPL.length_internal(vnv) - @test all( - DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) - ) - end - - vnv = relax_container_types(deepcopy(vnv_base), test_vns, test_vals) - @testset "$vn (decreased size)" for vn in test_vns - val_original = test_pairs[vn] - val = decrease_size_for_test(val_original) - vn_already_present = haskey(vnv, vn) - expected_length = if vn_already_present - # If it's already present, the resulting length will be altered. - DynamicPPL.length_internal(vnv) + length(val) - length(val_original) - else - DynamicPPL.length_internal(vnv) + length(val) - end - - # Have to use reset!, because setindex! doesn't support decreasing size. 
- DynamicPPL.reset!(vnv, val .+ 1, vn) - x = DynamicPPL.getindex_internal(vnv, :) - @test vnv[vn] == val .+ 1 - @test DynamicPPL.length_internal(vnv) == expected_length - @test length(x) == DynamicPPL.length_internal(vnv) - @test all( - DynamicPPL.getindex_internal(vnv, i) == x[i] for i in eachindex(x) - ) - end - end - end - - @testset "growing and shrinking" begin - @testset "deterministic" begin - n = 5 - vn = @varname(x) - vnv = DynamicPPL.VarNamedVector(Dict(vn => [true])) - @test !DynamicPPL.has_inactive(vnv) - # Growing should not create inactive ranges. - for i in 1:n - x = fill(true, i) - DynamicPPL.update_internal!(vnv, x, vn, identity) - @test !DynamicPPL.has_inactive(vnv) - end - - # Same size should not create inactive ranges. - x = fill(true, n) - DynamicPPL.update_internal!(vnv, x, vn, identity) - @test !DynamicPPL.has_inactive(vnv) - - # Shrinking should create inactive ranges. - for i in (n - 1):-1:1 - x = fill(true, i) - DynamicPPL.update_internal!(vnv, x, vn, identity) - @test DynamicPPL.has_inactive(vnv) - @test DynamicPPL.num_inactive(vnv, vn) == n - i - end - end - - @testset "random" begin - n = 5 - vn = @varname(x) - vnv = DynamicPPL.VarNamedVector(Dict(vn => [true])) - @test !DynamicPPL.has_inactive(vnv) - - # Insert a bunch of random-length vectors. - for i in 1:100 - x = fill(true, rand(1:n)) - DynamicPPL.update!(vnv, x, vn) - end - # Should never be allocating more than `n` elements. - @test DynamicPPL.num_allocated(vnv, vn) ≤ n - - # If we compaticfy, then it should always be the same size as just inserted. 
- for i in 1:10 - x = fill(true, rand(1:n)) - DynamicPPL.update!(vnv, x, vn) - DynamicPPL.contiguify!(vnv) - @test DynamicPPL.num_allocated(vnv, vn) == length(x) - end - end - end - - @testset "subset" begin - vnv = DynamicPPL.VarNamedVector(test_pairs) - @test subset(vnv, test_vns) == vnv - @test subset(vnv, VarName[]) == DynamicPPL.VarNamedVector() - @test merge(subset(vnv, test_vns[1:3]), subset(vnv, test_vns[4:end])) == vnv - - # Test that subset preserves transformations and unconstrainedness. - vn = @varname(t[1]) - vns = vcat(test_vns, [vn]) - vnv = DynamicPPL.setindex_internal!!(vnv, [2.0], vn, x -> x .^ 2) - DynamicPPL.set_transformed!(vnv, true, @varname(t[1])) - @test vnv[@varname(t[1])] == [4.0] - @test is_transformed(vnv, @varname(t[1])) - @test subset(vnv, vns) == vnv - end - - @testset "loosen and tighten types" begin - """ - test_tightenability(vnv::VarNamedVector) - - Test that tighten_types!! is a no-op on `vnv`. - """ - function test_tightenability(vnv::DynamicPPL.VarNamedVector) - @test vnv == DynamicPPL.tighten_types!!(deepcopy(vnv)) - # TODO(mhauru) We would like to check something more stringent here, namely that - # the operation is compiled to a direct no-op, with no instructions at all. I - # don't know how to do that though, so for now we just check that it doesn't - # allocate. - @allocations(DynamicPPL.tighten_types!!(vnv)) == 0 - return nothing - end - - vn = @varname(a[1]) - # Test that tighten_types!! is a no-op on an empty VarNamedVector. - vnv = DynamicPPL.VarNamedVector() - @test DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - # Also check that it literally returns the same object, and both tighten and loosen - # are type stable. - @test vnv === DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - # Likewise for a VarNamedVector with something pushed into it. 
- vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - @test DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - @test vnv === DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - # Likewise for a VarNamedVector with abstract element-types, when that is needed for - # the current contents because mixed types have been pushed into it. However, this - # time, since the types are only as tight as they can be, but not actually concrete, - # tighten_types!! can't be type stable. - vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - vnv = setindex!!(vnv, 2, @varname(b)) - @test !DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - # Likewise when first mixed types are pushed, but then deleted. - vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - vnv = setindex!!(vnv, 2, @varname(b)) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = delete!!(vnv, vn) - @test DynamicPPL.is_tightly_typed(vnv) - test_tightenability(vnv) - @test vnv === DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.tighten_types!!(vnv) - @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) - - # Test that loosen_types!! does really loosen them and that tighten_types!! reverts - # that. - vnv = DynamicPPL.VarNamedVector() - vnv = setindex!!(vnv, 1.0, vn) - @test DynamicPPL.is_tightly_typed(vnv) - k = eltype(vnv.varnames) - e = eltype(vnv.vals) - t = eltype(vnv.transforms) - # Loosen key type. 
- vnv = @inferred DynamicPPL.loosen_types!!(vnv, VarName, e, t) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = DynamicPPL.tighten_types!!(vnv) - @test DynamicPPL.is_tightly_typed(vnv) - # Loosen element type - vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, Real, t) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = DynamicPPL.tighten_types!!(vnv) - @test DynamicPPL.is_tightly_typed(vnv) - # Loosen transformation type - vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, Function) - @test !DynamicPPL.is_tightly_typed(vnv) - vnv = DynamicPPL.tighten_types!!(vnv) - @test DynamicPPL.is_tightly_typed(vnv) - # Loosening to the same types as currently should do nothing. - vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, t) - @test DynamicPPL.is_tightly_typed(vnv) - @allocations(DynamicPPL.loosen_types!!(vnv, k, e, t)) == 0 - end -end - -@testset "VarInfo + VarNamedVector" begin - models = DynamicPPL.TestUtils.ALL_MODELS - @testset "$(model.f)" for model in models - # NOTE: Need to set random seed explicitly to avoid using the same seed - # for initialization as for sampling in the inner testset below. - Random.seed!(42) - value_true = DynamicPPL.TestUtils.rand_prior_true(model) - vns = DynamicPPL.TestUtils.varnames(model) - varnames = DynamicPPL.TestUtils.varnames(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos( - model, value_true, varnames; include_threadsafe=false - ) - # Filter out those which are not based on `VarNamedVector`. - varinfos = filter(DynamicPPL.has_varnamedvector, varinfos) - # Get the true log joint. - logp_true = DynamicPPL.TestUtils.logjoint_true(model, value_true...) - - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - # Need to make sure we're using a different random seed from the - # one used in the above call to `rand_prior_true`. - Random.seed!(43) - - # Are values correct? - DynamicPPL.TestUtils.test_values(varinfo, value_true, vns) - - # Is evaluation correct? 
- varinfo_eval = last(DynamicPPL.evaluate!!(model, deepcopy(varinfo))) - # Log density should be the same. - @test getlogjoint(varinfo_eval) ≈ logp_true - # Values should be the same. - DynamicPPL.TestUtils.test_values(varinfo_eval, value_true, vns) - - # Is sampling correct? - varinfo_sample = last(DynamicPPL.init!!(model, deepcopy(varinfo))) - # Log density should be different. - @test getlogjoint(varinfo_sample) != getlogjoint(varinfo) - # Values should be different. - DynamicPPL.TestUtils.test_values( - varinfo_sample, value_true, vns; compare=!isequal - ) - end - end -end From 8ba36f6dffd09c69aaa395bf27245d18ef283de9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:10:10 +0000 Subject: [PATCH 32/56] Fix a lot of doctests --- src/abstract_varinfo.jl | 44 ++++++++++++++++++++------------------- src/model.jl | 8 +++---- src/values_as_in_model.jl | 23 ++++++++------------ src/vntvarinfo.jl | 4 ++++ 4 files changed, 40 insertions(+), 39 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 1c5159626..c4af10898 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -506,13 +506,15 @@ If no `Type` is provided, return values as stored in `varinfo`. julia> # Just use an example model to construct the `VarInfo` because we're lazy. 
vi = DynamicPPL.VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); -julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; +julia> vi = DynamicPPL.setindex!!(vi, 1.0, @varname(s)); + +julia> vi = DynamicPPL.setindex!!(vi, 2.0, @varname(m)); julia> values_as(vi, NamedTuple) (s = 1.0, m = 2.0) julia> values_as(vi, OrderedDict) -OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: +OrderedDict{Any, Any} with 2 entries: s => 1.0 m => 2.0 @@ -570,20 +572,20 @@ demo (generic function with 2 methods) julia> model = demo(); -julia> varinfo = VarInfo(model); +julia> vi = VarInfo(model); -julia> keys(varinfo) +julia> keys(vi) 4-element Vector{VarName}: s m x[1] x[2] -julia> for (i, vn) in enumerate(keys(varinfo)) - varinfo[vn] = i +julia> for (i, vn) in enumerate(keys(vi)) + vi = DynamicPPL.setindex!!(vi, Float64(i), vn) end -julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +julia> vi[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4-element Vector{Float64}: 1.0 2.0 @@ -591,59 +593,59 @@ julia> varinfo[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4.0 julia> # Extract one with only `m`. - varinfo_subset1 = subset(varinfo, [@varname(m),]); + vi_subset1 = subset(vi, [@varname(m),]); -julia> keys(varinfo_subset1) -1-element Vector{VarName{:m, typeof(identity)}}: +julia> keys(vi_subset1) +1-element Vector{VarName}: m -julia> varinfo_subset1[@varname(m)] +julia> vi_subset1[@varname(m)] 2.0 julia> # Extract one with both `s` and `x[2]`. 
- varinfo_subset2 = subset(varinfo, [@varname(s), @varname(x[2])]); + vi_subset2 = subset(vi, [@varname(s), @varname(x[2])]); -julia> keys(varinfo_subset2) +julia> keys(vi_subset2) 2-element Vector{VarName}: s x[2] -julia> varinfo_subset2[[@varname(s), @varname(x[2])]] +julia> vi_subset2[[@varname(s), @varname(x[2])]] 2-element Vector{Float64}: 1.0 4.0 ``` -`subset` is particularly useful when combined with [`merge(varinfo::AbstractVarInfo)`](@ref) +`subset` is particularly useful when combined with [`merge(vi::AbstractVarInfo)`](@ref) ```jldoctest varinfo-subset julia> # Merge the two. - varinfo_subset_merged = merge(varinfo_subset1, varinfo_subset2); + vi_subset_merged = merge(vi_subset1, vi_subset2); -julia> keys(varinfo_subset_merged) +julia> keys(vi_subset_merged) 3-element Vector{VarName}: m s x[2] -julia> varinfo_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] +julia> vi_subset_merged[[@varname(s), @varname(m), @varname(x[2])]] 3-element Vector{Float64}: 1.0 2.0 4.0 julia> # Merge the two with the original. - varinfo_merged = merge(varinfo, varinfo_subset_merged); + vi_merged = merge(vi, vi_subset_merged); -julia> keys(varinfo_merged) +julia> keys(vi_merged) 4-element Vector{VarName}: s m x[1] x[2] -julia> varinfo_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] +julia> vi_merged[[@varname(s), @varname(m), @varname(x[1]), @varname(x[2])]] 4-element Vector{Float64}: 1.0 2.0 diff --git a/src/model.jl b/src/model.jl index cd36ee44b..7d65df842 100644 --- a/src/model.jl +++ b/src/model.jl @@ -501,7 +501,7 @@ true julia> # Since we conditioned on `a.m`, it is not treated as a random variable. # However, `a.x` will still be a random variable. 
keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: @@ -513,7 +513,7 @@ Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: julia> # Now `a.x` will be sampled. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x ``` """ @@ -839,7 +839,7 @@ julia> Set(keys(fixed(cm))) == Set([@varname(a.m), @varname(x)]) true julia> keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x julia> # We can also condition on `a.m` _outside_ of the PrefixContext: @@ -851,7 +851,7 @@ Dict{VarName{:a, Accessors.PropertyLens{:m}}, Float64} with 1 entry: julia> # Now `a.x` will be sampled. keys(VarInfo(cm)) -1-element Vector{VarName{:a, Accessors.PropertyLens{:x}}}: +1-element Vector{VarName}: a.x ``` """ diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index f7440d6ff..304b99a3e 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ -107,35 +107,30 @@ julia> @model function model_changing_support() julia> model = model_changing_support(); -julia> # Construct initial type-stable `VarInfo`. +julia> # Construct initial `VarInfo`. varinfo = VarInfo(rng, model); julia> # Link it so it works in unconstrained space. - varinfo_linked = DynamicPPL.link(varinfo, model); + varinfo_linked = DynamicPPL.link!!(copy(varinfo), model); -julia> # Perform computations in unconstrained space, e.g. changing the values of `θ`. +julia> # Perform computations in unconstrained space, e.g. changing the values of `vals`. # Flip `x` so we hit the other support of `y`. - θ = [!varinfo[@varname(x)], rand(rng)]; + vals = [!varinfo[@varname(x)], rand(rng)]; julia> # Update the `VarInfo` with the new values. 
- varinfo_linked = DynamicPPL.unflatten!!(varinfo_linked, θ); + varinfo_linked = DynamicPPL.unflatten!!(varinfo_linked, vals); julia> # Determine the expected support of `y`. - lb, ub = θ[1] == 1 ? (0, 1) : (11, 12) + lb, ub = vals[1] == 1 ? (0, 1) : (11, 12) (0, 1) julia> # Approach 1: Convert back to constrained space using `invlink` and extract. - varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); + varinfo_invlinked = DynamicPPL.invlink!!(copy(varinfo_linked), model); -julia> # (×) Fails! Because `VarInfo` _saves_ the original distributions - # used in the very first model evaluation, hence the support of `y` - # is not updated even though `x` has changed. - lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub -false +julia> lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub +true julia> # Approach 2: Extract realizations using `values_as_in_model`. - # (✓) `values_as_in_model` will re-run the model and extract - # the correct realization of `y` given the new values of `x`. lb ≤ values_as_in_model(model, true, varinfo_linked)[@varname(y)] ≤ ub true ``` diff --git a/src/vntvarinfo.jl b/src/vntvarinfo.jl index b0cafa364..a7eafc460 100644 --- a/src/vntvarinfo.jl +++ b/src/vntvarinfo.jl @@ -54,6 +54,10 @@ function Base.getindex(vi::VNTVarInfo, vn::VarName) return tv.transform(tv.val) end +function Base.getindex(vi::VNTVarInfo, vns::Vector{<:VarName}) + return [getindex(vi, vn) for vn in vns] +end + function Base.getindex(vi::VNTVarInfo, vn::VarName, dist::Distribution) val = getindex_internal(vi, vn) return from_maybe_linked_internal(vi, vn, dist, val) From 1f6335db7c097ff393696723bf4c64179486e68b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:11:03 +0000 Subject: [PATCH 33/56] Rename vntvarinfo.jl to varinfo.jl --- src/DynamicPPL.jl | 2 +- src/{vntvarinfo.jl => varinfo.jl} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename src/{vntvarinfo.jl => varinfo.jl} (100%) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 
b5a77be03..d6f4025ca 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -199,7 +199,7 @@ include("accumulators.jl") include("default_accumulators.jl") include("abstract_varinfo.jl") include("threadsafe.jl") -include("vntvarinfo.jl") +include("varinfo.jl") include("onlyaccs.jl") include("compiler.jl") include("pointwise_logdensities.jl") diff --git a/src/vntvarinfo.jl b/src/varinfo.jl similarity index 100% rename from src/vntvarinfo.jl rename to src/varinfo.jl From dbcf5f646b70a965fd739f3a0accca05e483c564 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:14:30 +0000 Subject: [PATCH 34/56] Rename VNTVarInfo to VarInfo --- src/logdensityfunction.jl | 2 +- src/varinfo.jl | 135 ++++++++++++++++++-------------------- test/test_util.jl | 4 +- 3 files changed, 68 insertions(+), 73 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 4f8ac4933..17101d0d2 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -303,7 +303,7 @@ representation, along with whether each variable is linked or unlinked. This function returns a VarNamedTuple mapping all VarNames to their corresponding `RangeAndLinked`. """ -function get_ranges_and_linked(vi::VNTVarInfo) +function get_ranges_and_linked(vi::VarInfo) # TODO(mhauru) Check that the closure doesn't cause type instability here. vnt = VarNamedTuple() vnt, _ = mapreduce( diff --git a/src/varinfo.jl b/src/varinfo.jl index a7eafc460..37728dca2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,11 +1,8 @@ -struct VNTVarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo +struct VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo values::T accs::Accs end -# TODO(mhauru) Make this renaming permanent. 
-const VarInfo = VNTVarInfo - struct TransformedValue{ValType,TransformType,SizeType} val::ValType linked::Bool @@ -15,9 +12,9 @@ end VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size -VNTVarInfo() = VNTVarInfo(VarNamedTuple(), default_accumulators()) +VarInfo() = VarInfo(VarNamedTuple(), default_accumulators()) -function VNTVarInfo(values::Union{NamedTuple,AbstractDict}) +function VarInfo(values::Union{NamedTuple,AbstractDict}) vi = VarInfo() for (k, v) in pairs(values) vn = k isa Symbol ? VarName{k}() : k @@ -26,52 +23,52 @@ function VNTVarInfo(values::Union{NamedTuple,AbstractDict}) return vi end -function VNTVarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return VNTVarInfo(Random.default_rng(), model, init_strategy) +function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return VarInfo(Random.default_rng(), model, init_strategy) end -function VNTVarInfo( +function VarInfo( rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return last(init!!(rng, model, VNTVarInfo(), init_strategy)) + return last(init!!(rng, model, VarInfo(), init_strategy)) end -getaccs(vi::VNTVarInfo) = vi.accs -setaccs!!(vi::VNTVarInfo, accs::AccumulatorTuple) = VNTVarInfo(vi.values, accs) +getaccs(vi::VarInfo) = vi.accs +setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.values, accs) -transformation(::VNTVarInfo) = DynamicTransformation() +transformation(::VarInfo) = DynamicTransformation() -Base.copy(vi::VNTVarInfo) = VNTVarInfo(copy(vi.values), copy(getaccs(vi))) +Base.copy(vi::VarInfo) = VarInfo(copy(vi.values), copy(getaccs(vi))) -Base.haskey(vi::VNTVarInfo, vn::VarName) = haskey(vi.values, vn) +Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.values, vn) -Base.length(vi::VNTVarInfo) = length(vi.values) +Base.length(vi::VarInfo) = length(vi.values) -function Base.getindex(vi::VNTVarInfo, vn::VarName) +function Base.getindex(vi::VarInfo, vn::VarName) tv = 
getindex(vi.values, vn) return tv.transform(tv.val) end -function Base.getindex(vi::VNTVarInfo, vns::Vector{<:VarName}) +function Base.getindex(vi::VarInfo, vns::Vector{<:VarName}) return [getindex(vi, vn) for vn in vns] end -function Base.getindex(vi::VNTVarInfo, vn::VarName, dist::Distribution) +function Base.getindex(vi::VarInfo, vn::VarName, dist::Distribution) val = getindex_internal(vi, vn) return from_maybe_linked_internal(vi, vn, dist, val) end -Base.isempty(vi::VNTVarInfo) = isempty(vi.values) -Base.empty(vi::VNTVarInfo) = VNTVarInfo(empty(vi.values), map(reset, vi.accs)) -BangBang.empty!!(vi::VNTVarInfo) = VNTVarInfo(empty!!(vi.values), map(reset, vi.accs)) +Base.isempty(vi::VarInfo) = isempty(vi.values) +Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) +BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) -function setindex_internal!!(vi::VNTVarInfo, val, vn::VarName) +function setindex_internal!!(vi::VarInfo, val, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end # TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However, @@ -80,7 +77,7 @@ end # of doing the transformation to the caller, it'll be done even when e.g. using # OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be removed # once VAIMAcc is the only way to get values out of an evaluation. -function setindex_with_dist!!(vi::VNTVarInfo, val, dist::Distribution, vn::VarName) +function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) # Determine whether to insert a transformed value into `vi`. # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. 
If not, we @@ -98,81 +95,81 @@ function setindex_with_dist!!(vi::VNTVarInfo, val, dist::Distribution, vn::VarNa transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val) val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () tv = TransformedValue(transformed_val, insert_transformed_value, transform, val_size) - vi = VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs) + vi = VarInfo(setindex!!(vi.values, tv, vn), vi.accs) return vi, logjac end -function BangBang.setindex!!(vi::VNTVarInfo, val, vn::VarName) +function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) transform = from_vec_transform(val) transformed_val = inverse(transform)(val) tv = TransformedValue(transformed_val, false, transform, size(val)) - return VNTVarInfo(setindex!!(vi.values, tv, vn), vi.accs) + return VarInfo(setindex!!(vi.values, tv, vn), vi.accs) end -Base.keys(vi::VNTVarInfo) = keys(vi.values) -Base.values(vi::VNTVarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) +Base.keys(vi::VarInfo) = keys(vi.values) +Base.values(vi::VarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) -function set_transformed!!(vi::VNTVarInfo, linked::Bool, vn::VarName) +function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end -# VNTVarInfo does not care whether the transformation was Static or Dynamic, it just tracks +# VarInfo does not care whether the transformation was Static or Dynamic, it just tracks # whether one was applied at all. 
-function set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation, vn::VarName) +function set_transformed!!(vi::VarInfo, ::AbstractTransformation, vn::VarName) return set_transformed!!(vi, true, vn) end -set_transformed!!(vi::VNTVarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) +set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) -function set_transformed!!(vi::VNTVarInfo, ::NoTransformation, vn::VarName) +function set_transformed!!(vi::VarInfo, ::NoTransformation, vn::VarName) return set_transformed!!(vi, false, vn) end -set_transformed!!(vi::VNTVarInfo, ::NoTransformation) = set_transformed!!(vi, false) +set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false) -function set_transformed!!(vi::VNTVarInfo, linked::Bool) +function set_transformed!!(vi::VarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv TransformedValue(tv.val, linked, tv.transform, tv.size) end - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end -function getindex_internal(vi::VNTVarInfo, vn::VarName) +function getindex_internal(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) return tv.val end # TODO(mhauru) This is mimicing old behaviour, but is now wrong: The internal # representation does not have to be a Vector. -getindex_internal(vi::VNTVarInfo, ::Colon) = values_as(vi, Vector) +getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) -function is_transformed(vi::VNTVarInfo, vn::VarName) +function is_transformed(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) return tv.linked end # TODO(mhauru) Other VarInfos have something like this. Do we need it? Or should we use the # below version? 
-function from_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) +function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) end -# function from_internal_transform(vi::VNTVarInfo, vn::VarName, ::Distribution) +# function from_internal_transform(vi::VarInfo, vn::VarName, ::Distribution) # return getindex(vi.values, vn).transform # end -function from_linked_internal_transform(::VNTVarInfo, ::VarName, dist::Distribution) +function from_linked_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_linked_vec_transform(dist) end -function from_linked_internal_transform(vi::VNTVarInfo, vn::VarName) +function from_linked_internal_transform(vi::VarInfo, vn::VarName) return getindex(vi.values, vn).transform end -function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) +function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = map_pairs!!(vi.values) do pair @@ -192,18 +189,18 @@ function link!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) cumulative_logjac += logjac1 + logjac2 return new_tv end - vi = VNTVarInfo(new_values, vi.accs) + vi = VarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, cumulative_logjac) end return vi end -function link!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) +function link!!(t::DynamicTransformation, vi::VarInfo, model::Model) return link!!(t, vi, nothing, model) end -function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = map_pairs!!(vi.values) do pair @@ -224,18 +221,18 @@ function invlink!!(::DynamicTransformation, vi::VNTVarInfo, vns, model::Model) cumulative_logjac += logjac1 + logjac2 return new_tv end - 
vi = VNTVarInfo(new_values, vi.accs) + vi = VarInfo(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, cumulative_logjac) end return vi end -function invlink!!(t::DynamicTransformation, vi::VNTVarInfo, model::Model) +function invlink!!(t::DynamicTransformation, vi::VarInfo, model::Model) return invlink!!(t, vi, nothing, model) end -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model) +function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) # By default this will simply evaluate the model with `DynamicTransformationContext`, # and so we need to specialize to avoid this. return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) @@ -243,7 +240,7 @@ end function link!!( t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vi::ThreadSafeVarInfo{<:VarInfo}, vns::VarNameTuple, model::Model, ) @@ -252,9 +249,7 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function invlink!!( - t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VNTVarInfo}, model::Model -) +function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) # By default this will simply evaluate the model with `DynamicTransformationContext`, # and so we need to specialize to avoid this. return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) @@ -262,7 +257,7 @@ end function invlink!!( ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VNTVarInfo}, + vi::ThreadSafeVarInfo{<:VarInfo}, vns::VarNameTuple, model::Model, ) @@ -273,11 +268,11 @@ end # TODO(mhauru) I don't think this should return the internal values, but that's the current # convention. 
-function values_as(vi::VNTVarInfo, ::Type{Vector}) +function values_as(vi::VarInfo, ::Type{Vector}) return mapfoldl(pair -> tovec(pair.second.val), vcat, vi.values; init=Union{}[]) end -function values_as(vi::VNTVarInfo, ::Type{T}) where {T<:AbstractDict} +function values_as(vi::VarInfo, ::Type{T}) where {T<:AbstractDict} return mapfoldl(identity, function (cumulant, pair) vn, tv = pair val = tv.transform(tv.val) @@ -289,7 +284,7 @@ end # interface provided by rand(::Model). We should change that to return a VarNamedTuple # instead, and then this method (and any other values_as methods for NamedTuple) could be # removed. -function values_as(vi::VNTVarInfo, ::Type{NamedTuple}) +function values_as(vi::VarInfo, ::Type{NamedTuple}) return mapfoldl( identity, function (cumulant, pair) @@ -309,7 +304,7 @@ function untyped_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return VNTVarInfo(rng, model, init_strategy) + return VarInfo(rng, model, init_strategy) end function typed_varinfo( @@ -317,10 +312,10 @@ function typed_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return VNTVarInfo(rng, model, init_strategy) + return VarInfo(rng, model, init_strategy) end -typed_varinfo(vi::VNTVarInfo) = vi +typed_varinfo(vi::VarInfo) = vi function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) return typed_varinfo(Random.default_rng(), model, init_strategy) @@ -350,7 +345,7 @@ function get_next_chunk!(vci::VectorChunkIterator, len::Int) return chunk end -function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) +function unflatten!!(vi::VarInfo, vec::AbstractVector) # You may wonder, why have a whole struct for this, rather than just an index variable # that the mapping function would close over. I wonder too. But for some reason type # inference fails on such an index variable, turning it into a Core.Box. 
@@ -367,16 +362,16 @@ function unflatten!!(vi::VNTVarInfo, vec::AbstractVector) new_val = get_next_chunk!(vci, len) return TransformedValue(new_val, tv.linked, tv.transform, tv.size) end - return VNTVarInfo(new_values, vi.accs) + return VarInfo(new_values, vi.accs) end -function subset(varinfo::VNTVarInfo, vns) +function subset(varinfo::VarInfo, vns) new_values = subset(varinfo.values, vns) - return VNTVarInfo(new_values, map(copy, getaccs(varinfo))) + return VarInfo(new_values, map(copy, getaccs(varinfo))) end -function Base.merge(varinfo_left::VNTVarInfo, varinfo_right::VNTVarInfo) +function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) new_values = merge(varinfo_left.values, varinfo_right.values) new_accs = map(copy, getaccs(varinfo_right)) - return VNTVarInfo(new_values, new_accs) + return VarInfo(new_values, new_accs) end diff --git a/test/test_util.jl b/test/test_util.jl index 9f6939adf..8f402ad8f 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -16,8 +16,8 @@ Return string representing a short description of `vi`. 
function short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) return "threadsafe($(short_varinfo_name(vi.varinfo)))" end -function short_varinfo_name(::DynamicPPL.VNTVarInfo) - return "VNTVarInfo" +function short_varinfo_name(::DynamicPPL.VarInfo) + return "VarInfo" end # convenient functions for testing model.jl From 0edaa53e9acd366e75acd595511ecad147734684 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:43:03 +0000 Subject: [PATCH 35/56] Remove (un)typed_varinfo --- src/chains.jl | 8 -------- src/test_utils/contexts.jl | 27 +++++++++++---------------- src/varinfo.jl | 28 ---------------------------- 3 files changed, 11 insertions(+), 52 deletions(-) diff --git a/src/chains.jl b/src/chains.jl index ca653fff9..ee4312547 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -38,7 +38,6 @@ function ParamsWithStats( include_colon_eq::Bool=true, include_log_probs::Bool=true, ) - varinfo = maybe_to_typed_varinfo(varinfo) accs = if include_log_probs ( DynamicPPL.LogPriorAccumulator(), @@ -64,13 +63,6 @@ function ParamsWithStats( return ParamsWithStats(params, stats) end -# Re-evaluating the model is unconscionably slow for untyped VarInfo. It's much faster to -# convert it to a typed varinfo first, hence this method. -# https://github.com/TuringLang/Turing.jl/issues/2604 -# maybe_to_typed_varinfo(vi::UntypedVarInfo) = typed_varinfo(vi) -# maybe_to_typed_varinfo(vi::UntypedVectorVarInfo) = typed_vector_varinfo(vi) -maybe_to_typed_varinfo(vi::AbstractVarInfo) = vi - """ ParamsWithStats( varinfo::AbstractVarInfo, diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index c48d2ddfd..cceedee8c 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -36,16 +36,12 @@ function test_leaf_context(context::DynamicPPL.AbstractContext, model::DynamicPP # varinfos.) Thus we only test evaluation with VarInfos that are already # filled with values. 
@testset "evaluation" begin - # Generate a new filled untyped varinfo - _, untyped_vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) - typed_vi = DynamicPPL.typed_varinfo(untyped_vi) + # Generate a new filled varinfo + _, vi = DynamicPPL.init!!(model, DynamicPPL.VarInfo()) # Set the test context as the new leaf context new_model = DynamicPPL.setleafcontext(model, context) - # Check that evaluation works - for vi in [untyped_vi, typed_vi] - _, vi = DynamicPPL.evaluate!!(new_model, vi) - @test vi isa DynamicPPL.VarInfo - end + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo end end @@ -73,13 +69,12 @@ function test_parent_context(context::DynamicPPL.AbstractContext, model::Dynamic @testset "initialisation and evaluation" begin new_model = contextualize(model, context) - for vi in [DynamicPPL.VarInfo(), DynamicPPL.typed_varinfo(DynamicPPL.VarInfo())] - # Initialisation - _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) - @test vi isa DynamicPPL.VarInfo - # Evaluation - _, vi = DynamicPPL.evaluate!!(new_model, vi) - @test vi isa DynamicPPL.VarInfo - end + vi = DynamicPPL.VarInfo() + # Initialisation + _, vi = DynamicPPL.init!!(new_model, DynamicPPL.VarInfo()) + @test vi isa DynamicPPL.VarInfo + # Evaluation + _, vi = DynamicPPL.evaluate!!(new_model, vi) + @test vi isa DynamicPPL.VarInfo end end diff --git a/src/varinfo.jl b/src/varinfo.jl index 37728dca2..170181b80 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -297,34 +297,6 @@ function values_as(vi::VarInfo, ::Type{NamedTuple}) ) end -# TODO(mhauru) These two are now redundant, just conforming to the old interface -# temporarily. 
-function untyped_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return VarInfo(rng, model, init_strategy) -end - -function typed_varinfo( - rng::Random.AbstractRNG, - model::Model, - init_strategy::AbstractInitStrategy=InitFromPrior(), -) - return VarInfo(rng, model, init_strategy) -end - -typed_varinfo(vi::VarInfo) = vi - -function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return typed_varinfo(Random.default_rng(), model, init_strategy) -end - -function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return untyped_varinfo(Random.default_rng(), model, init_strategy) -end - """ VectorChunkIterator{T<:AbstractVector} From c2748a79a33d2234a081000479d2cfbf8369f89b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 14:47:21 +0000 Subject: [PATCH 36/56] Add docstrings to varinfo.jl --- src/varinfo.jl | 135 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 119 insertions(+), 16 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 170181b80..688c90b03 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,12 +1,67 @@ +""" + VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo + +The default implementation of `AbstractVarInfo`, storing variable values and accumulators. + +`VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, +and a tuple of accumulators. The only really noteworthy thing about it is that it stores +the values of variables vectorised as instances of `TransformedValue`. That is, it stores +each value as a vector and a transformation to be applied to that vector to get the actual +value. 
It also stores whether the transformation is such that it guarantees all real vectors +to be valid internal representations of the variable (i.e., whether the variable has been +linked), as well as the size of the actual post-transformation value. These are all fields +of [`TransformedValue`](@ref). + +Note that `setindex!!` and `getindex` on `VarInfo` deal with the actual values of variables. +To get access to the internal vectorised values, use [`getindex_internal`](@ref), +[`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). + +There's also a `VarInfo`-specific function [`setindex_with_dist!!`](@ref), which sets a +variable's value with a transformation based on the statistical distribution this value is +a sample for. + +For more details on the internal storage, see documentation of [`TransformedValue`](@ref) and +[`VarNamedTuple`](@ref). + +# Fields +$(TYPEDFIELDS) + +""" struct VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo values::T accs::Accs end +# TODO(mhauru) The policy of vectorising all values was set when the old VarInfo type was +# using a Vector as the internal storage in all cases. We should revisit this, and allow +# values to be stored "raw", since VarNamedTuple supports it. + +# TODO(mhauru) Related to the above, I think we should reconsider whether we should store +# transformations at all. We rarely use them, since they may be dynamic in a model. +# tilde_assume!! rather gets the transformation from the current distribution encountered +# during model execution. However, this would change the interface quite a lot, so I want to +# finish implementing VarInfo using VNT (mostly) respecting the old interface first. + +""" + TransformedValue{ValType,TransformType,SizeType} + +A struct for storing a variable's value in its internal (vectorised) form. + +# Fields +$(TYPEDFIELDS) +""" struct TransformedValue{ValType,TransformType,SizeType} + "The internal (vectorised) value." 
val::ValType + """Boolean indicating whether the variable is linked, i.e. the transformation maps all + real vectors to valid values.""" linked::Bool + """The transformation from internal (vectorised) to actual value. In other words, the + actual value of the variable being stored is `transform(val)`.""" transform::TransformType + """The size of the actual value after transformation. This is needed when a + TransformedValue is stored as a block in an array (see [`PartialArray`](@ref) in + `VarNamedTuples`).""" size::SizeType end @@ -41,10 +96,10 @@ setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.values, accs) transformation(::VarInfo) = DynamicTransformation() Base.copy(vi::VarInfo) = VarInfo(copy(vi.values), copy(getaccs(vi))) - Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.values, vn) - Base.length(vi::VarInfo) = length(vi.values) +Base.keys(vi::VarInfo) = keys(vi.values) +Base.values(vi::VarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) function Base.getindex(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) @@ -64,6 +119,13 @@ Base.isempty(vi::VarInfo) = isempty(vi.values) Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) +""" + setindex_internal!!(vi::VarInfo, val, vn::VarName) + +Set the internal (vectorised) value of variable `vn` in `vi` to `val`. + +This does not change the transformation or linked status of the variable. +""" function setindex_internal!!(vi::VarInfo, val, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size) @@ -73,10 +135,23 @@ end # TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. 
However, # we need `dist` to determine the linking transformation (or even just the vectorisation -# transformation, in the case of ProductNamedTupleDistribions), and if we leave the work -# of doing the transformation to the caller, it'll be done even when e.g. using -# OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be removed -# once VAIMAcc is the only way to get values out of an evaluation. +# transformation in the case of ProductNamedTupleDistribions), and if we leave the work +# of doing the transformation to the caller (tilde_assume!!), it'll be done even when e.g. +# using OnlyAccsVarInfo. Hence having this function. It should eventually hopefully be +# removed once VAIMAcc is the only way to get values out of an evaluation. +""" + setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) + +Set the value of `vn` in `vi` to `val`, applying a transformation based on `dist`. + +`val` is taken to be the actual value of the variable, and is transformed into the internal +(vectorised) representation using a transformation based on `dist`. If the variable is +linked in `vi`, or doesn't exist in `vi` but all other variables in `vi` are linked, the +linking transformation is used; otherwise, the standard vector transformation is used. + +Returns the modified `vi` together with the log absolute determinant of the Jacobian of the +transformation applied. +""" function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) # Determine whether to insert a transformed value into `vi`. # If the VarInfo alrady had a value for this variable, we will @@ -99,6 +174,14 @@ function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) return vi, logjac end +""" + setindex!!(vi::VarInfo, val, vn::VarName) + +Set the value of `vn` in `vi` to `val`. + +The transformation for `vn` is reset to be the standard vector transformation for values of +the type of `val` and linking status is set to false. 
+""" function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) transform = from_vec_transform(val) transformed_val = inverse(transform)(val) @@ -106,9 +189,13 @@ function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) return VarInfo(setindex!!(vi.values, tv, vn), vi.accs) end -Base.keys(vi::VarInfo) = keys(vi.values) -Base.values(vi::VarInfo) = mapreduce(p -> p.second.val, push!, vi.values; init=Any[]) +""" + set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) + +Set the linked status of variable `vn` in `vi` to `linked`. +This does not change the value or transformation of the variable. +""" function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) @@ -137,13 +224,16 @@ function set_transformed!!(vi::VarInfo, linked::Bool) return VarInfo(new_values, vi.accs) end +""" + getindex_internal(vi::VarInfo, vn::VarName) + +Get the internal (vectorised) value of variable `vn` in `vi`. +""" function getindex_internal(vi::VarInfo, vn::VarName) tv = getindex(vi.values, vn) return tv.val end -# TODO(mhauru) This is mimicing old behaviour, but is now wrong: The internal -# representation does not have to be a Vector. getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) function is_transformed(vi::VarInfo, vn::VarName) @@ -151,20 +241,18 @@ function is_transformed(vi::VarInfo, vn::VarName) return tv.linked end -# TODO(mhauru) Other VarInfos have something like this. Do we need it? Or should we use the -# below version? 
function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) end -# function from_internal_transform(vi::VarInfo, vn::VarName, ::Distribution) -# return getindex(vi.values, vn).transform -# end - function from_linked_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_linked_vec_transform(dist) end +function from_internal_transform(vi::VarInfo, vn::VarName) + return getindex(vi.values, vn).transform +end + function from_linked_internal_transform(vi::VarInfo, vn::VarName) return getindex(vi.values, vn).transform end @@ -337,11 +425,26 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) return VarInfo(new_values, vi.accs) end +""" + subset(varinfo::VarInfo, vns) + +Create a new `VarInfo` containing only the variables in `vns`. + +`vns` can be almost any collection of `VarName`s, e.g. a `Set`, `Vector`, or `Tuple`. +""" function subset(varinfo::VarInfo, vns) new_values = subset(varinfo.values, vns) return VarInfo(new_values, map(copy, getaccs(varinfo))) end +""" + merge(varinfo_left::VarInfo, varinfo_right::VarInfo) + +Merge two `VarInfo`s into a new `VarInfo` containing all variables from both. + +If a variable exists in both `varinfo_left` and `varinfo_right`, the value from +`varinfo_right` is used. 
+""" function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) new_values = merge(varinfo_left.values, varinfo_right.values) new_accs = map(copy, getaccs(varinfo_right)) From 6dbae236ddf5b3345c959c20498d1a802bd27e0f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 17:15:31 +0000 Subject: [PATCH 37/56] Simplify transformations --- src/DynamicPPL.jl | 1 - src/abstract_varinfo.jl | 66 +--------------- src/contexts.jl | 16 ++-- src/contexts/transformation.jl | 44 ----------- src/test_utils/contexts.jl | 2 +- src/threadsafe.jl | 55 -------------- src/varinfo.jl | 134 +++++++++++++++++---------------- 7 files changed, 81 insertions(+), 237 deletions(-) delete mode 100644 src/contexts/transformation.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d6f4025ca..5889a6915 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -188,7 +188,6 @@ using .VarNamedTuples: VarNamedTuples, VarNamedTuple, map_pairs!!, map_values!!, include("contexts.jl") include("contexts/default.jl") include("contexts/init.jl") -include("contexts/transformation.jl") include("contexts/prefix.jl") include("contexts/conditionfix.jl") # Must come after contexts/prefix.jl include("model.jl") diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index c4af10898..67ac822cd 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -32,6 +32,9 @@ in the execution of a given `Model`. This is in constrast to `StaticTransformation` which transforms all variables _before_ the execution of a given `Model`. +Different VarInfo types should implement their own methods for `link!!` and `invlink!!` for +`DynamicTransformation`. + See also: [`StaticTransformation`](@ref). """ struct DynamicTransformation <: AbstractTransformation end @@ -53,23 +56,6 @@ struct StaticTransformation{F} <: AbstractTransformation bijector::F end -""" - merge_transformations(transformation_left, transformation_right) - -Merge two transformations. 
- -The main use of this is in [`merge(::AbstractVarInfo, ::AbstractVarInfo)`](@ref). -""" -function merge_transformations(::NoTransformation, ::NoTransformation) - return NoTransformation() -end -function merge_transformations(::DynamicTransformation, ::DynamicTransformation) - return DynamicTransformation() -end -function merge_transformations(left::StaticTransformation, right::StaticTransformation) - return StaticTransformation(merge_bijectors(left.bijector, right.bijector)) -end - function merge_bijectors(left::Bijectors.NamedTransform, right::Bijectors.NamedTransform) return Bijectors.NamedTransform(merge_bijector(left.bs, right.bs)) end @@ -744,31 +730,6 @@ end function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end -function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{false}()) - vi = last(evaluate!!(model, vi)) - return set_transformed!!(vi, t) -end -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - # TODO(mhauru) This assumes that the user has defined the bijector using the same - # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user - # interface. - b = inverse(t.bijector) - x = vi[:] - y, logjac = with_logabsdet_jacobian(b, x) - # Set parameters and add the logjac term. - # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant - # that getindex(vi, vn) would apply the default link transform of the distribution. With - # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any - # transform. Neither is correct, rather the transform should be the inverse of b. 
- vi = unflatten!!(vi, y) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, logjac) - end - return set_transformed!!(vi, t) -end """ link([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) @@ -811,27 +772,6 @@ end function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end -function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{true}()) - vi = last(evaluate!!(model, vi)) - return set_transformed!!(vi, NoTransformation()) -end -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model -) - b = t.bijector - y = vi[:] - x, inv_logjac = with_logabsdet_jacobian(b, y) - - # Mildly confusing: we need to _add_ the logjac of the inverse transform, - # because we are trying to remove the logjac of the forward transform - # that was previously accumulated when linking. - vi = unflatten!!(vi, x) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, inv_logjac) - end - return set_transformed!!(vi, NoTransformation()) -end """ invlink([t::AbstractTransformation, ]vi::AbstractVarInfo, model::Model) diff --git a/src/contexts.jl b/src/contexts.jl index 46c5b8855..0eccf7b53 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -25,18 +25,18 @@ Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. 
# Examples -```jldoctest -julia> using DynamicPPL: DynamicTransformationContext, ConditionContext +```jldoctest; setup=:(using Random) +julia> using DynamicPPL: InitContext, ConditionContext julia> ctx = ConditionContext((; a = 1)); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.setchildcontext(ctx, DynamicTransformationContext{true}()); +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, InitContext(MersenneTwister(23), InitFromPrior())); julia> DynamicPPL.childcontext(ctx_prior) -DynamicTransformationContext{true}() +InitContext{MersenneTwister, InitFromPrior}(MersenneTwister(23), InitFromPrior()) ``` """ setchildcontext @@ -60,8 +60,8 @@ in which case effectively append `right` to `left`, dropping the original leaf context of `left`. # Examples -```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, DynamicTransformationContext +```jldoctest; setup=:(using Random) +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, InitContext julia> struct ParentContext{C} <: AbstractParentContext context::C @@ -77,8 +77,8 @@ julia> ctx = ParentContext(ParentContext(DefaultContext())) ParentContext(ParentContext(DefaultContext())) julia> # Replace the leaf context with another leaf. - leafcontext(setleafcontext(ctx, DynamicTransformationContext{true}())) -DynamicTransformationContext{true}() + leafcontext(setleafcontext(ctx, InitContext(MersenneTwister(23), InitFromPrior()))) +InitContext{MersenneTwister, InitFromPrior}(MersenneTwister(23), InitFromPrior()) julia> # Append another parent context. 
setleafcontext(ctx, ParentContext(DefaultContext())) diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl deleted file mode 100644 index 0914d7a79..000000000 --- a/src/contexts/transformation.jl +++ /dev/null @@ -1,44 +0,0 @@ -""" - struct DynamicTransformationContext{isinverse} <: AbstractContext - -When a model is evaluated with this context, transform the accompanying `AbstractVarInfo` to -constrained space if `isinverse` or unconstrained if `!isinverse`. - -Note that some `AbstractVarInfo` types, must notably `VarInfo`, override the -`DynamicTransformationContext` methods with more efficient implementations. -`DynamicTransformationContext` is a fallback for when we need to evaluate the model to know -how to do the transformation. -""" -struct DynamicTransformationContext{isinverse} <: AbstractContext end - -function tilde_assume!!( - ::DynamicTransformationContext{isinverse}, - right::Distribution, - vn::VarName, - vi::AbstractVarInfo, -) where {isinverse} - # vi[vn, right] always provides the value in unlinked space. - x = vi[vn, right] - - if is_transformed(vi, vn) - isinverse || @warn "Trying to link an already transformed variable ($vn)" - else - isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" - end - - transform = isinverse ? 
identity : link_transform(right) - y, logjac = with_logabsdet_jacobian(transform, x) - vi = accumulate_assume!!(vi, x, logjac, vn, right) - vi = setindex!!(vi, y, vn) - return x, vi -end - -function tilde_observe!!( - ::DynamicTransformationContext, - right::Distribution, - left, - vn::Union{VarName,Nothing}, - vi::AbstractVarInfo, -) - return tilde_observe!!(DefaultContext(), right, left, vn, vi) -end diff --git a/src/test_utils/contexts.jl b/src/test_utils/contexts.jl index cceedee8c..7182f511e 100644 --- a/src/test_utils/contexts.jl +++ b/src/test_utils/contexts.jl @@ -49,7 +49,7 @@ function test_parent_context(context::DynamicPPL.AbstractContext, model::Dynamic @testset "get/set leaf and child contexts" begin # Ensure we're using a different leaf context than the current. leafcontext_new = if DynamicPPL.leafcontext(context) isa DefaultContext - DynamicPPL.DynamicTransformationContext{false}() + DynamicPPL.InitContext(Random.MersenneTwister(1234), InitFromPrior()) else DefaultContext() end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index d83cb289d..547dd6a1e 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -82,61 +82,6 @@ function invlink!!(t::AbstractTransformation, vi::ThreadSafeVarInfo, args...) return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, args...) 
end -function link(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, model) -end - -function invlink(t::AbstractTransformation, vi::ThreadSafeVarInfo, model::Model) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, model) -end - -function link( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model -) - return Accessors.@set vi.varinfo = link(t, vi.varinfo, vns, model) -end - -function invlink( - t::AbstractTransformation, vi::ThreadSafeVarInfo, vns::VarNameTuple, model::Model -) - return Accessors.@set vi.varinfo = invlink(t, vi.varinfo, vns, model) -end - -# Need to define explicitly for `DynamicTransformation` to avoid method ambiguity. -# NOTE: We also can't just defer to the wrapped varinfo, because we need to ensure -# consistency between `vi.accs_by_thread` field and `getacc(vi.varinfo)`, which accumulates -# to define `getacc(vi)`. -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{false}()) - return set_transformed!!(last(evaluate!!(model, vi)), t) -end - -function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - model = setleafcontext(model, DynamicTransformationContext{true}()) - return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation()) -end - -function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return link!!(t, deepcopy(vi), model) -end - -function invlink(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) - return invlink!!(t, deepcopy(vi), model) -end - -# These two StaticTransformation methods needed to resolve ambiguities -function link!!( - t::StaticTransformation{<:Bijectors.Transform}, vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = link!!(t, vi.varinfo, model) -end - -function invlink!!( - t::StaticTransformation{<:Bijectors.Transform}, 
vi::ThreadSafeVarInfo, model::Model -) - return Accessors.@set vi.varinfo = invlink!!(t, vi.varinfo, model) -end - function maybe_invlink_before_eval!!(vi::ThreadSafeVarInfo, model::Model) # Defer to the wrapped `AbstractVarInfo` object. # NOTE: When computing `getacc` for `ThreadSafeVarInfo` we do include the diff --git a/src/varinfo.jl b/src/varinfo.jl index 688c90b03..3af47691b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -257,38 +257,25 @@ function from_linked_internal_transform(vi::VarInfo, vn::VarName) return getindex(vi.values, vn).transform end -function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) - dists = extract_priors(model, vi) - cumulative_logjac = zero(LogProbType) - new_values = map_pairs!!(vi.values) do pair - vn, tv = pair - if vns !== nothing && !any(x -> subsumes(x, vn), vns) - # Not one of the target variables. - return tv - end - dist = getindex(dists, vn) - vec_transform = from_vec_transform(dist) - link_transform = from_linked_vec_transform(dist) - val_untransformed, logjac1 = with_logabsdet_jacobian(vec_transform, tv.val) - val_new, logjac2 = with_logabsdet_jacobian( - inverse(link_transform), val_untransformed - ) - new_tv = TransformedValue(val_new, true, link_transform, tv.size) - cumulative_logjac += logjac1 + logjac2 - return new_tv - end - vi = VarInfo(new_values, vi.accs) - if hasacc(vi, Val(:LogJacobian)) - vi = acclogjac!!(vi, cumulative_logjac) - end - return vi -end +""" + _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link isa Bool} -function link!!(t::DynamicTransformation, vi::VarInfo, model::Model) - return link!!(t, vi, nothing, model) -end +The internal function that implements both link!! and invlink!!. -function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) +The last argument controls whether linking (true) or invlinking (false) is performed. 
If +`vns` is `nothing`, all variables in `vi` are transformed; otherwise, only the variables +in `vns` are transformed. Existing variables already in the desired state are left +unchanged. +""" +function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where {link} + @assert link isa Bool + # Note that extract_priors causes a model execution. In the past with the Metadata-based + # VarInfo we rather derived the transformations from the distributions stored in the + # VarInfo itself. However, that is not fail-safe with dynamic models, and would require + # storing the distributions in TransformedValue (which we could start doing). Instead we + # use extract_priors to get the current, correct transformations. This logic is very + # similar to what DynamicTransformation used to do, and we might replace this with a + # context that transforms each variable in turn during the execution. dists = extract_priors(model, vi) cumulative_logjac = zero(LogProbType) new_values = map_pairs!!(vi.values) do pair @@ -297,15 +284,23 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) # Not one of the target variables. return tv end - current_val = tv.val + if tv.linked == link + # Already in the desired state. 
+ return tv + end dist = getindex(dists, vn) vec_transform = from_vec_transform(dist) link_transform = from_linked_vec_transform(dist) - val_untransformed, logjac1 = with_logabsdet_jacobian(link_transform, current_val) + current_transform, new_transform = if link + (vec_transform, link_transform) + else + (link_transform, vec_transform) + end + val_untransformed, logjac1 = with_logabsdet_jacobian(current_transform, tv.val) val_new, logjac2 = with_logabsdet_jacobian( - inverse(vec_transform), val_untransformed + inverse(new_transform), val_untransformed ) - new_tv = TransformedValue(val_new, false, vec_transform, tv.size) + new_tv = TransformedValue(val_new, link, new_transform, tv.size) cumulative_logjac += logjac1 + logjac2 return new_tv end @@ -316,42 +311,51 @@ function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) return vi end -function invlink!!(t::DynamicTransformation, vi::VarInfo, model::Model) - return invlink!!(t, vi, nothing, model) +function link!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) + return _link_or_invlink!!(vi, vns, model, Val(true)) end - -function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, model) +function link!!(::DynamicTransformation, vi::VarInfo, model::Model) + return _link_or_invlink!!(vi, nothing, model, Val(true)) end - -function link!!( - t::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. 
- return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) +function invlink!!(::DynamicTransformation, vi::VarInfo, vns, model::Model) + return _link_or_invlink!!(vi, vns, model, Val(false)) +end +function invlink!!(::DynamicTransformation, vi::VarInfo, model::Model) + return _link_or_invlink!!(vi, nothing, model, Val(false)) +end + +function link!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) + # TODO(mhauru) This assumes that the user has defined the bijector using the same + # variable ordering as what `vi[:]` and `unflatten!!(vi, x)` use. This is a bad user + # interface. + b = inverse(t.bijector) + x = vi[:] + y, logjac = with_logabsdet_jacobian(b, x) + # Set parameters and add the logjac term. + # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant + # that getindex(vi, vn) would apply the default link transform of the distribution. With + # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any + # transform. Neither is correct, rather the transform should be the inverse of b. + vi = unflatten!!(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return set_transformed!!(vi, t) end -function invlink!!(t::DynamicTransformation, vi::ThreadSafeVarInfo{<:VarInfo}, model::Model) - # By default this will simply evaluate the model with `DynamicTransformationContext`, - # and so we need to specialize to avoid this. - return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(t, vi.varinfo, model) -end +function invlink!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::Model) + b = t.bijector + y = vi[:] + x, inv_logjac = with_logabsdet_jacobian(b, y) -function invlink!!( - ::DynamicTransformation, - vi::ThreadSafeVarInfo{<:VarInfo}, - vns::VarNameTuple, - model::Model, -) - # By default this will simply evaluate the model with `DynamicTransformationContext`, and so - # we need to specialize to avoid this. 
- return Accessors.@set vi.varinfo = DynamicPPL.invlink!!(vi.varinfo, vns, model) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + vi = unflatten!!(vi, x) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, inv_logjac) + end + return set_transformed!!(vi, NoTransformation()) end # TODO(mhauru) I don't think this should return the internal values, but that's the current From 2fa7333ea1a0a19fec823da2e13c69e58cfc062e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 18:19:32 +0000 Subject: [PATCH 38/56] Fix docs --- docs/src/api.md | 23 ++++++++--------------- docs/src/internals/varinfo.md | 1 - src/varinfo.jl | 3 +-- 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index a506c793e..5cd94fccd 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -343,6 +343,8 @@ AbstractVarInfo ```@docs VarInfo +DynamicPPL.TransformedValue +DynamicPPL.setindex_with_dist!! ``` One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [main Turing documentation](https://turinglang.org/docs/developers/transforms/dynamicppl/). @@ -354,11 +356,7 @@ is_transformed set_transformed!! ``` -```@docs -Base.empty! -``` - -#### `VarNamedTuple` +#### `VarNamedTuple`s `VarInfo` is only a thin wrapper around [`VarNamedTuple`](@ref), which stores arbitrary data keyed by `VarName`s. For more details on `VarNamedTuple`, see the Internals section of our documentation. @@ -366,6 +364,10 @@ For more details on `VarNamedTuple`, see the Internals section of our documentat ```@docs DynamicPPL.VarNamedTuples.VarNamedTuple DynamicPPL.VarNamedTuples.vnt_size +DynamicPPL.VarNamedTuples.apply!! +DynamicPPL.VarNamedTuples.map_pairs!! +DynamicPPL.VarNamedTuples.map_values!! 
+DynamicPPL.VarNamedTuples.PartialArray ``` ### Accumulators @@ -411,19 +413,10 @@ accloglikelihood!! ```@docs keys getindex -push!! empty!! isempty DynamicPPL.getindex_internal -DynamicPPL.setindex_internal! -DynamicPPL.update_internal! -DynamicPPL.insert_internal! -DynamicPPL.length_internal -DynamicPPL.reset! -DynamicPPL.update! -DynamicPPL.insert! -DynamicPPL.loosen_types!! -DynamicPPL.tighten_types!! +DynamicPPL.setindex_internal!! ``` ```@docs diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index c57ea1fcf..f3f100a81 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -39,7 +39,6 @@ One can access a vectorised version of a variable's value with the following vec - `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values - `setindex_internal!!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable. - `setindex_internal!!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values - - `length_internal(::VarInfo)`: return the length of the flat representation of `metadata`. The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised. diff --git a/src/varinfo.jl b/src/varinfo.jl index 3af47691b..4cda6b40f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -60,8 +60,7 @@ struct TransformedValue{ValType,TransformType,SizeType} actual value of the variable being stored is `transform(val)`.""" transform::TransformType """The size of the actual value after transformation. 
This is needed when a - TransformedValue is stored as a block in an array (see [`PartialArray`](@ref) in - `VarNamedTuples`).""" + `TransformedValue` is stored as a block in an array.""" size::SizeType end From 7cbc4a7bf6c1f5657fb8258fd8c1b460d19e8367 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 18:21:23 +0000 Subject: [PATCH 39/56] Mark some inference tests as broken on 1.10 --- test/varnamedtuple.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 1937ea189..f3f1e83e6 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -667,6 +667,9 @@ Base.size(st::SizedThing) = st.size end @testset "length" begin + # Type inference for length fails in some cases on Julia versions < 1.11 + inference_broken = VERSION < v"1.11" + vnt = VarNamedTuple() @test @inferred(length(vnt)) == 0 @@ -683,23 +686,23 @@ Base.size(st::SizedThing) = st.size @test @inferred(length(vnt)) == 3 vnt = setindex!!(vnt, -1.0, @varname(d[4])) - @test @inferred(length(vnt)) == 4 + @test @inferred(length(vnt)) == 4 broken = inference_broken vnt = setindex!!(vnt, ["a", "b"], @varname(d[1:2])) - @test @inferred(length(vnt)) == 6 + @test @inferred(length(vnt)) == 6 broken = inference_broken vnt = setindex!!(vnt, 2.0, @varname(e.f[3].g.h[2].i)) vnt = setindex!!(vnt, 3.0, @varname(e.f[3].g.h[2].j)) - @test @inferred(length(vnt)) == 8 + @test @inferred(length(vnt)) == 8 broken = inference_broken vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 2:4, 2, 1:2, 3])) - @test @inferred(length(vnt)) == 14 + @test @inferred(length(vnt)) == 14 broken = inference_broken vnt = setindex!!(vnt, SizedThing((3, 2)), @varname(x[1, 4:6, 2, 1:2, 3])) - @test @inferred(length(vnt)) == 14 + @test @inferred(length(vnt)) == 14 broken = inference_broken vnt = setindex!!(vnt, [:a, :b], @varname(y[4][3][2][1:2])) - @test @inferred(length(vnt)) == 16 + @test @inferred(length(vnt)) == 16 broken = 
inference_broken test_invariants(vnt) end @@ -917,7 +920,9 @@ Base.size(st::SizedThing) = st.size @test haskey(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4])) @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) - @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val + # Type inference fails on this one for Julia versions < 1.11 + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val broken = + VERSION < v"1.11" end @testset "map and friends" begin From b4361c04612731e0357a17ac42dc850a055e003c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:02:12 +0000 Subject: [PATCH 40/56] Polish VNT and tests --- src/varnamedtuple.jl | 41 +++++++++++++++++++---------------------- test/varnamedtuple.jl | 6 +++++- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 37158442b..0287e393b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -26,7 +26,7 @@ function _haskey end Like `setindex!!`, but special-cased for `VarNamedTuple` and `PartialArray` to recurse into nested structures. -The `allow_new` keywword argument is a performance optimisation: If it is set to +The `allow_new` keyword argument is a performance optimisation: If it is set to `Val(false)`, the function can assume that the key being set already exists in `collection`. This allows skipping some code paths, which may have a minor benefit at runtime, but more importantly, allows for better constant propagation and type stability at compile time. @@ -541,7 +541,7 @@ function _check_index_validity(pa::PartialArray, inds::NTuple{N,INDEX_TYPES}) wh end function _getindex(pa::PartialArray, inds::Vararg{INDEX_TYPES}) - # The original, non-bare inds is needed later for ArrayLikeBlock checks. + # The unmodified inds is needed later for ArrayLikeBlock checks. 
orig_inds = inds inds = _unwrap_concretized_slice.(inds) _check_index_validity(pa, inds) @@ -1237,6 +1237,7 @@ end end end +# As above but with a prefix VarName `vn`. @generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}, vn::T) where {Names,T} exs = Expr[] for name in Names @@ -1273,10 +1274,13 @@ map_values!!(func, vnt::VarNamedTuple) = map_pairs!!(pair -> func(pair.second), Apply `f` to all elements of `vnt`, and reduce the results using `op`, starting from `init`. +The order is the same as in `mapfoldl`, i.e. left-associative with `init` as the +left-most value. + `init` is a keyword argument to conform to the usual `mapreduce` interface in Base, but it is not optional. -`f` op` should accept pairs of `VarName` and value. +`f` op` should accept pairs of `varname => value`. """ function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) if init === nothing @@ -1298,41 +1302,30 @@ _mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa. @generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, init ) where {Names} - exs = Expr[] - push!( - exs, - quote - result = init - end, - ) + exs = Expr[:(result = init)] for name in Names push!( exs, - :( + quote result = _mapreduce_recursive( f, op, vnt.data.$name, VarName{$(QuoteNode(name))}(), result ) - ), + end, ) end push!(exs, :(return result)) return Expr(:block, exs...) end +# As above but with a prefix VarName `vn`. 
@generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, vn, init ) where {Names} - exs = Expr[] - push!( - exs, - quote - result = init - end, - ) + exs = Expr[:(result = init)] for name in Names push!( exs, - :( + quote result = _mapreduce_recursive( f, op, @@ -1340,7 +1333,7 @@ end AbstractPPL.prefix(VarName{$(QuoteNode(name))}(), vn), result, ) - ), + end, ) end push!(exs, :(return result)) @@ -1354,7 +1347,7 @@ function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) albs_seen = Set{ArrayLikeBlock}() @inbounds for i in CartesianIndices(pa.mask) if pa.mask[i] - val = @inbounds pa.data[i] + val = pa.data[i] is_alb = val isa ArrayLikeBlock if is_alb if val in albs_seen @@ -1370,6 +1363,10 @@ function _mapreduce_recursive(f, op, pa::PartialArray, vn, init) return result end +# TODO(mhauru) We could try to keep the return types of these more tight, rather than always +# return the same, abstract element type. Would that be better? It would be faster in some +# cases, but would be less consistent, and could result in a lot of allocations in the +# mapreduce, as the element type is gradually expanded. Base.keys(vnt::VarNamedTuple) = mapreduce(first, push!, vnt; init=VarName[]) Base.values(vnt::VarNamedTuple) = mapreduce(pair -> pair.second, push!, vnt; init=Any[]) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index f3f1e83e6..7c2a263f5 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -502,6 +502,9 @@ Base.size(st::SizedThing) = st.size vnt = setindex!!(vnt, SizedThing((3, 1, 4)), @varname(p[2, 1][2:4, 5:5, 11:14])) test_invariants(vnt) + # TODO(mhauru) I'm a bit saddened by the lack of type stability for subset: It's + # return type always infers as VarNamedTuple. Improving this would require a + # different implementation of subset. 
@test subset(vnt, VarName[]) == VarNamedTuple() @test subset(vnt, (@varname(z),)) == VarNamedTuple() @test subset(vnt, (@varname(d[4]),)) == VarNamedTuple() @@ -1025,7 +1028,8 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt_mapped, @varname(w[4][3][2, 1]))) == "b" call_counter = 0 - vnt_applied = @inferred(apply!!(f_val, vnt, @varname(a))) + vnt_applied = copy(vnt) + vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(a))) @test call_counter == 1 test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(a))) == 11 From 73e50df5d576fd76473888a4244cf88c2ac59f6f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:10:07 +0000 Subject: [PATCH 41/56] Fix broken test marking --- test/varnamedtuple.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 7c2a263f5..63ee12c5b 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -924,8 +924,7 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt, @varname(y.z[2, 2:3, 3, 2:3, 4]))) == val @test haskey(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4])) # Type inference fails on this one for Julia versions < 1.11 - @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val broken = - VERSION < v"1.11" + @test @inferred(getindex(vnt, @varname(y.z[3, 2:3, 3, 2:3, 4]))) == val end @testset "map and friends" begin @@ -1048,7 +1047,14 @@ Base.size(st::SizedThing) = st.size test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(c.d))) == [2.0] - vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) + vnt_applied = begin + # The @inferred fails on Julia 1.10. 
+ @static if VERSION < v"1.11" + apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i)) + else + @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].i))) + end + end @test call_counter == 4 test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" From 922fbb62f3c3717ddf36a68cad22f7dc2f082ac3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:46:41 +0000 Subject: [PATCH 42/56] Polish varinfo.jl --- src/varinfo.jl | 46 +++++++++++++++++++++------------------------- 1 file changed, 21 insertions(+), 25 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4cda6b40f..4dfa538c1 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -145,34 +145,32 @@ Set the value of `vn` in `vi` to `val`, applying a transformation based on `dist `val` is taken to be the actual value of the variable, and is transformed into the internal (vectorised) representation using a transformation based on `dist`. If the variable is -linked in `vi`, or doesn't exist in `vi` but all other variables in `vi` are linked, the -linking transformation is used; otherwise, the standard vector transformation is used. +currently linked in `vi`, or doesn't exist in `vi` but all other variables in `vi` are +linked, the linking transformation is used; otherwise, the standard vector transformation is +used. Returns the modified `vi` together with the log absolute determinant of the Jacobian of the transformation applied. """ function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) - # Determine whether to insert a transformed value into `vi`. - # If the VarInfo alrady had a value for this variable, we will - # keep the same linked status as in the original VarInfo. If not, we - # check the rest of the VarInfo to see if other variables are linked. - # is_transformed(vi) returns true if vi is nonempty and all variables in vi - # are linked. - insert_transformed_value = haskey(vi, vn) ? 
is_transformed(vi, vn) : is_transformed(vi) - # TODO(mhauru) We should move away from having all values vectorised by default. - # That messes with our use of unflatten though, so will require some thought. - transform = if insert_transformed_value + link = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) + transform = if link from_linked_vec_transform(dist) else from_vec_transform(dist) end transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val) + # All values for which `size` is not defined are assumed to be scalars. val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () - tv = TransformedValue(transformed_val, insert_transformed_value, transform, val_size) + tv = TransformedValue(transformed_val, link, transform, val_size) vi = VarInfo(setindex!!(vi.values, tv, vn), vi.accs) return vi, logjac end +# TODO(mhauru) The below is somewhat unsafe or incomplete: For instance, from_vec_transform +# isn't defined for NamedTuples. However, this is needed in some places where values for +# in a VarInfo are set outside the context of a `tilde_assume!!` and no distribution is +# available. Hopefully we'll get rid of this eventually. """ setindex!!(vi::VarInfo, val, vn::VarName) @@ -228,17 +226,11 @@ end Get the internal (vectorised) value of variable `vn` in `vi`. """ -function getindex_internal(vi::VarInfo, vn::VarName) - tv = getindex(vi.values, vn) - return tv.val -end - +getindex_internal(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).val +# TODO(mhauru) The below should be removed together with unflatten!!. 
getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) -function is_transformed(vi::VarInfo, vn::VarName) - tv = getindex(vi.values, vn) - return tv.linked -end +is_transformed(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).linked function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) @@ -253,6 +245,9 @@ function from_internal_transform(vi::VarInfo, vn::VarName) end function from_linked_internal_transform(vi::VarInfo, vn::VarName) + if !is_transformed(vi, vn) + error("Variable $vn is not linked; cannot get linked transformation.") + end return getindex(vi.values, vn).transform end @@ -330,11 +325,10 @@ function link!!(t::StaticTransformation{<:Bijectors.Transform}, vi::VarInfo, ::M b = inverse(t.bijector) x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - # Set parameters and add the logjac term. # TODO(mhauru) This doesn't set the transforms of `vi`. With the old Metadata that meant # that getindex(vi, vn) would apply the default link transform of the distribution. With # the new VarNamedTuple-based VarInfo it means that getindex(vi, vn) won't apply any - # transform. Neither is correct, rather the transform should be the inverse of b. + # link transform. Neither is correct, rather the transform should be the inverse of b. vi = unflatten!!(vi, y) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, logjac) @@ -417,7 +411,7 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) old_val = tv.val if !(old_val isa AbstractVector) error( - "Can not unflatten a VarInfo for which existing values are not vectors:" * + "Can't unflatten a VarInfo for which existing values are not vectors:" * " Got value of type $(typeof(old_val)).", ) end @@ -445,6 +439,8 @@ end Merge two `VarInfo`s into a new `VarInfo` containing all variables from both. +The accumulators are taken exclusively from `varinfo_right`. 
+ If a variable exists in both `varinfo_left` and `varinfo_right`, the value from `varinfo_right` is used. """ From 66c79709942abbccc38fe01d158668b04719f96d Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:56:09 +0000 Subject: [PATCH 43/56] Polish internal docs --- docs/src/internals/varinfo.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index f3f100a81..6d87e5edc 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -14,7 +14,7 @@ It contains `values` takes care of storing information related to values of individual random variables, while `accs` keeps track of information that we keep accumulating in the course of evaluating through a model. Variables are regonised by their `VarName`. -We want to work with `VarName` rather than something like `Symbol` or `String` as `VarName` contains additional structural information. +We want to work with `VarName`s rather than something like `Symbol` or `String` as `VarName` contains additional structural information. For instance, a `Symbol("x[1]")` can be a result of either `var"x[1]" ~ Normal()` or `x[1] ~ Normal()`; these scenarios are disambiguated by `VarName`. `VarName`s also allow things such as setting values for `x[1]` and `x[2]` and getting a value for `x` as a whole. @@ -24,7 +24,6 @@ To ensure that `VarInfo` is simple and intuitive to work with we want it to repl - `haskey(::VarInfo)`: check if a particular `VarName` is present. - `getindex(::VarInfo, ::VarName)`: return the realization corresponding to a particular `VarName`. - `setindex!!(::VarInfo, val, ::VarName)`: set the realization corresponding to a particular `VarName`. - - `delete!!(::VarInfo, ::VarName)`: delete the realization corresponding to a particular `VarName`. - `empty!!(::VarInfo)`: delete all data. - `merge(::VarInfo, ::VarInfo)`: merge two containers according to similar rules as `Dict`. 
@@ -36,15 +35,12 @@ One can access a vectorised version of a variable's value with the following vec - `getindex_internal(::VarInfo, ::VarName)`: get the flattened value of a single variable. - `getindex_internal(::VarInfo, ::Colon)`: get the flattened values of all variables. - - `getindex_internal(::VarInfo, i::Int)`: get `i`th value of the flattened vector of all values - `setindex_internal!!(::VarInfo, ::AbstractVector, ::VarName)`: set the flattened value of a variable. - - `setindex_internal!!(::VarInfo, val, i::Int)`: set the `i`th value of the flattened vector of all values The functions have `_internal` in their name because internally `VarInfo` always stores values as vectorised. Moreover, a link transformation can be applied to a `VarInfo` with `link!!` (and reversed with `invlink!!`), which applies a reversible transformation to the internal storage format of a variable that makes the range of the random variable cover all of Euclidean space. `getindex_internal` and `setindex_internal!` give direct access to the vectorised value after such a transformation, which is what samplers often need to be able sample in unconstrained space. -One can also manually set a transformation by giving `setindex_internal!!` a fourth, optional argument, that is a function that maps internally stored value to the actual value of the variable. 
Finally, we want want the underlying storage to have a few performance-related properties: From 51fdcbec611911be498d4a6499c5d6f09ae24cf1 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 13 Jan 2026 19:58:59 +0000 Subject: [PATCH 44/56] More broken inference tests on v1.10 --- test/varnamedtuple.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/varnamedtuple.jl b/test/varnamedtuple.jl index 63ee12c5b..abd1406e6 100644 --- a/test/varnamedtuple.jl +++ b/test/varnamedtuple.jl @@ -1060,7 +1060,14 @@ Base.size(st::SizedThing) = st.size @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].j))) == 5.0 - vnt_applied = @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))) + vnt_applied = begin + # The @inferred fails on Julia 1.10. + @static if VERSION < v"1.11" + apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j)) + else + @inferred(apply!!(f_val, vnt_applied, @varname(e.f[3].g.h[2].j))) + end + end @test call_counter == 5 test_invariants(vnt_applied; skip=(:parseeval,)) @test @inferred(getindex(vnt_applied, @varname(e.f[3].g.h[2].i))) == "ab" From 07a13c4532c53da6363ed239c36d2cfb045300ce Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 15:11:09 +0000 Subject: [PATCH 45/56] Export VarNamedTuple and its functions --- src/DynamicPPL.jl | 4 ++++ src/values_as_in_model.jl | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 5889a6915..9961125e2 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -46,6 +46,10 @@ import Base: # VarInfo export AbstractVarInfo, VarInfo, + VarNamedTuple, + map_pairs!!, + map_values!!, + apply!!, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, diff --git a/src/values_as_in_model.jl b/src/values_as_in_model.jl index 304b99a3e..9ee622424 100644 --- a/src/values_as_in_model.jl +++ b/src/values_as_in_model.jl @@ 
-111,7 +111,7 @@ julia> # Construct initial `VarInfo`. varinfo = VarInfo(rng, model); julia> # Link it so it works in unconstrained space. - varinfo_linked = DynamicPPL.link!!(copy(varinfo), model); + varinfo_linked = DynamicPPL.link(varinfo, model); julia> # Perform computations in unconstrained space, e.g. changing the values of `vals`. # Flip `x` so we hit the other support of `y`. @@ -125,7 +125,7 @@ julia> # Determine the expected support of `y`. (0, 1) julia> # Approach 1: Convert back to constrained space using `invlink` and extract. - varinfo_invlinked = DynamicPPL.invlink!!(copy(varinfo_linked), model); + varinfo_invlinked = DynamicPPL.invlink(varinfo_linked, model); julia> lb ≤ first(varinfo_invlinked[@varname(y)]) ≤ ub true From 92dd490a0839b8a7b703d8efc838f6eb70e3c537 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 15:34:46 +0000 Subject: [PATCH 46/56] Add HISTORY.md entry on the new VarInfo --- HISTORY.md | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index bb40b8464..d6c13f7a6 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,6 +2,68 @@ ## 0.40 +### `VarNamedTuple` + +DynamicPPL now exports a new type, called `VarNamedTuple`, which stores values keyed by `VarName`s. +With it are exported a few new functions for using it: `map_values!!`, `map_pairs!!`, `apply!!`. +Our documentation's Internals section now has a page about `VarNamedTuple`, how it works, and what it's good for. + +`VarNamedTuple` is now used internally in many different parts: in `VarInfo`, in `values_as_in_model`, in `LogDensityFunction`, etc. +Almost all of the below changes are a consequence of switching over to using `VarNamedTuple` for various features internally. + +### Overhaul of `VarInfo` + +DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types. 
+Previously, there were many versions of them: `VarInfo`, both "typed" and "untyped", and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends. +These have all been replaced by a rewritten implementation of `VarInfo`. +While the basics of the `VarInfo` interface remain the same, this brings with it many changes: + +#### No more multiple `AbstractVarInfo` types + +`SimpleVarInfo`, `untyped_varinfo`, `typed_varinfo`, and many other constructors, some exported, some not, have been removed. +The remaining one is `VarInfo(...)`, which can take a model or a collection of values. +See the docstring for details. + +Some related types and functions, that weren't exported but may have been used by some, have also been removed, most notably `VarNamedVector` and its associated functions like `loosen_types!!` and `tighten_types!!`. + +#### Setting and getting values + +Previously the various `AbstractVarInfo` types had a multitude of functions for setting values: +`push!!`, `push!`, `setindex!`, `update!`, `update_internal!`, `insert_internal!`, `reset!`, etc. +These have all been replaced by three functions: + + - `setindex!!` is the one to use for simply setting a variable in `VarInfo` to a known value. It works regardless of whether the variable already exists. + - `setindex_internal!!` is the one to use for setting the internal, vectorised representation of a variable. See the docstring for details. + - `setindex_with_dist!!` is to be used when you want to set a value, but choose the internal representation based on which distribution this value is a sample for. + +The order of the arguments for some of these functions has also changed, and now more closely matches the usual convention for `setindex!!`. + +Note that `setindex!` (with a single `!`) is not defined, and thus you can't do `varinfo[varname] = new_value`. 
+ +`unflatten` works as before, but has been renamed to `unflatten!!`, since it may mutate the first argument and aliases memory with the second argument (it uses views rather than copies of the input vector). + +#### Linking is now safer + +`link!!` and `invlink!!` on `VarInfo` used to assume that the prior distribution of a variable didn't change from one execution to another (as it does in e.g. `truncated(dist; lower=x)` where `x` is a random variable). +This is no longer the case. +Linking should thus be safer to do. +The cost to pay is that calls to `link!!` and `invlink!!` (and the non-mutating versions) now trigger a model evaluation, to determine the correct priors to use. + +#### Other miscellanea + + - The `Experimental` module had functions like `Experimental.determine_suitable_varinfo` for determining which `AbstractVarInfo` type was suitable for a given model. This is now redundant and has been removed. + - `Bijectors.bijector(::Model)`, which creates a bijector from the vectorised variable space of the model to the linked variable space of the model, now has slightly different optional arguments. See the docstring for details. + - `NamedDist` no longer allows variable names with `Colon`s in them, such as `x[:]`. + +There are probably also changes to the `VarInfo` interface that we've neglected to document here, since the overhaul of `VarInfo` has been quite complete. +If anything related to `VarInfo` is behaving unexpectedly, e.g. the arguments or return type of a function seem to have changed, please check the docstring, which should be comprehensive. + +#### Performance benefits + +The purpose of this overhaul of `VarInfo` is code simplification and performance benefits. + +TODO(mhauru) Add some basic summary of what has gotten faster by how much. + ### Changes to indexing random variables with square brackets 0.40 internally reimplements how DynamicPPL handles random variables like `x[1]`, `x.y[2,2]`, and `x[:,1:4,5]`, i.e. 
ones that use indexing with square brackets. From 06f6c1efdb3f566a0f1e49162e7e75c5d5396fb9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 15:55:24 +0000 Subject: [PATCH 47/56] Apply suggestions from code review Co-authored-by: Penelope Yong --- src/varinfo.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4dfa538c1..06623ca25 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -12,9 +12,9 @@ to be valid internal representations of the variable (i.e., whether the variable linked), as well as the size of the actual post-transformation value. These are all fields of [`TransformedValue`](@ref). -Note that `setindex!!` and `getindex` on `VarInfo` deal with the actual values of variables. -To get access to the internal vectorised values, use [`getindex_internal`](@ref), -[`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). +Note that `setindex!!` and `getindex` on `VarInfo` take and return values in the support of +the original distribution. To get access to the internal vectorised values, use +[`getindex_internal`](@ref), [`setindex_internal!!`](@ref), and [`unflatten!!`](@ref). 
There's also a `VarInfo`-specific function [`setindex_with_dist!!`](@ref), which sets a variable's value with a transformation based on the statistical distribution this value is @@ -105,7 +105,7 @@ function Base.getindex(vi::VarInfo, vn::VarName) return tv.transform(tv.val) end -function Base.getindex(vi::VarInfo, vns::Vector{<:VarName}) +function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) return [getindex(vi, vn) for vn in vns] end From 4f893bc1665169df66de4ae088fb81edc8bb81c8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 15:57:57 +0000 Subject: [PATCH 48/56] Use SkipSizeCheck rather than Val(:pass) --- src/varnamedtuple.jl | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 0287e393b..802b8222b 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -102,6 +102,13 @@ const INDEX_TYPES = Union{Integer,AbstractUnitRange,Colon,AbstractPPL.Concretize _unwrap_concretized_slice(cs::AbstractPPL.ConcretizedSlice) = cs.range _unwrap_concretized_slice(x::Union{Integer,AbstractUnitRange,Colon}) = x +""" + SkipSizeCheck() + +A special return value for `vnt_size` indicating that size checks should be skipped. +""" +struct SkipSizeCheck end + """ vnt_size(x) @@ -111,7 +118,7 @@ By default, this falls back onto `Base.size`, but can be overloaded for custom t This notion of type is used to determine whether a value can be set into a `PartialArray` as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details. -A special return value of `Val(:pass)` indicates that the size check should be skipped. +A special return value of `SkipSizeCheck()` indicates that the size check should be skipped. """ vnt_size(x) = size(x) @@ -301,7 +308,7 @@ _internal_size(pa::PartialArray, args...) = size(pa.data, args...) # be stored as a PartialArray wrapped in an ArrayLikeBlock, stored in another PartialArray. 
# Note that this bypasses _any_ size checks, so that e.g. @varname(x[1:3][1,15]) is also a # valid key. -vnt_size(pa::PartialArray) = Val(:pass) +vnt_size(::PartialArray) = SkipSizeCheck() function Base.copy(pa::PartialArray) # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively @@ -686,7 +693,7 @@ function _setindex!!(pa::PartialArray, value, inds::Vararg{INDEX_TYPES}) new_data = pa.data if _needs_arraylikeblock(value, inds...) inds_size = reduce((x, y) -> tuple(x..., y...), map(size, inds)) - if vnt_size(value) !== Val(:pass) && vnt_size(value) != inds_size + if !(vnt_size(value) isa SkipSizeCheck) && vnt_size(value) != inds_size throw( DimensionMismatch( "Assigned value has size $(vnt_size(value)), which does not match " * @@ -1216,7 +1223,7 @@ function _map_recursive!!(func, alb::ArrayLikeBlock, vn) new_block = _map_recursive!!(func, alb.block, vn) sz_new = vnt_size(new_block) sz_old = vnt_size(alb.block) - if sz_new !== Val(:pass) && sz_old !== Val(:pass) && sz_new != sz_old + if !(sz_new isa SkipSizeCheck) && !(sz_old isa SkipSizeCheck) && sz_new != sz_old throw( DimensionMismatch( "map_pairs!! can't change the size of an ArrayLikeBlock. 
Tried to change " * From fdb1373c78a68b96069e2c1c2f5b9096a82a0a57 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 16:01:27 +0000 Subject: [PATCH 49/56] Remove getindex with dist argument --- src/varinfo.jl | 5 ----- test/linking.jl | 6 +++--- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 06623ca25..3e026648b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -109,11 +109,6 @@ function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) return [getindex(vi, vn) for vn in vns] end -function Base.getindex(vi::VarInfo, vn::VarName, dist::Distribution) - val = getindex_internal(vi, vn) - return from_maybe_linked_internal(vi, vn, dist, val) -end - Base.isempty(vi::VarInfo) = isempty(vi.values) Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) diff --git a/test/linking.jl b/test/linking.jl index 2047b9d11..bfd1285b1 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -89,7 +89,7 @@ end DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) # The non-internal logjoint should be the same since it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) - @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) + @test vi_linked[@varname(m)] == LowerTriangular(vi[@varname(m)]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @test length(vi_linked[:]) == length(y) @@ -100,7 +100,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == length(vi[:]) - @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) + @test vi_invlinked[@varname(m)] ≈ LowerTriangular(vi[@varname(m)]) # The non-internal logjoint should still be the same, again since # it doesn't depend on linking. 
@test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) @@ -121,7 +121,7 @@ end model, values_original, (@varname(x),) ) @testset "$(short_varinfo_name(vi))" for vi in vis - val = vi[@varname(x), dist] + val = vi[@varname(x)] # Ensure that `reconstruct` works as intended. @test val isa Cholesky @test val.uplo == uplo From a023a7fc6c8a57d3941510afda1bff070d7a2dd0 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 14 Jan 2026 16:10:59 +0000 Subject: [PATCH 50/56] Simplify map and mapreduce for VNT --- src/varnamedtuple.jl | 57 ++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 802b8222b..3f758852e 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -1234,16 +1234,6 @@ function _map_recursive!!(func, alb::ArrayLikeBlock, vn) return ArrayLikeBlock(new_block, alb.inds) end -@generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}) where {Names} - exs = Expr[] - for name in Names - push!(exs, :(_map_recursive!!(func, vnt.data.$name, VarName{$(QuoteNode(name))}()))) - end - return quote - return VarNamedTuple(NamedTuple{Names}(($(exs...),))) - end -end - # As above but with a prefix VarName `vn`. @generated function _map_recursive!!(func, vnt::VarNamedTuple{Names}, vn::T) where {Names,T} exs = Expr[] @@ -1267,7 +1257,17 @@ Apply `func` to all key => value pairs of `vnt`, in place if possible. `func` should accept a pair of `VarName` and value, and return the new value to be set. 
""" -map_pairs!!(func, vnt::VarNamedTuple) = _map_recursive!!(func, vnt) +@generated function map_pairs!!(func, vnt::VarNamedTuple{Names}) where {Names} + exs = Expr[] + for name in Names + push!(exs, :(_map_recursive!!(func, vnt.data.$name, VarName{$(QuoteNode(name))}()))) + end + return quote + return VarNamedTuple(NamedTuple{Names}(($(exs...),))) + end +end + +Base.foreach(func, vnt::VarNamedTuple) = map_pairs!!(p -> (func(p); p), vnt) """ map_values!!(func, vnt::VarNamedTuple) @@ -1289,26 +1289,19 @@ is not optional. `f` op` should accept pairs of `varname => value`. """ -function Base.mapreduce(f, op, vnt::VarNamedTuple; init=nothing) - if init === nothing - throw( - NotImplementedError( - "mapreduce without init is not implemented for VarNamedTuple." - ), - ) +@generated function Base.mapreduce( + f, op, vnt::VarNamedTuple{Names}; init::InitType=nothing +) where {Names,InitType} + if InitType === Nothing + return quote + throw( + ArgumentError( + "mapreduce without init is not implemented for VarNamedTuple." + ), + ) + end end - return _mapreduce_recursive(f, op, vnt, init) -end - -# Our mapreduce is always left-associative. -Base.mapfoldl(f, op, vnt::VarNamedTuple; init=nothing) = mapreduce(f, op, vnt; init=init) - -_mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) -_mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) -@generated function _mapreduce_recursive( - f, op, vnt::VarNamedTuple{Names}, init -) where {Names} exs = Expr[:(result = init)] for name in Names push!( @@ -1324,6 +1317,12 @@ _mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa. return Expr(:block, exs...) end +# Our mapreduce is always left-associative. 
+Base.mapfoldl(f, op, vnt::VarNamedTuple; init=nothing) = mapreduce(f, op, vnt; init=init) + +_mapreduce_recursive(f, op, x, vn, init) = op(init, f(vn => x)) +_mapreduce_recursive(f, op, pa::ArrayLikeBlock, vn, init) = op(init, f(vn => pa.block)) + # As above but with a prefix VarName `vn`. @generated function _mapreduce_recursive( f, op, vnt::VarNamedTuple{Names}, vn, init From c369b0926bceefa859883daf4777280ec98b13b8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 10:28:16 +0000 Subject: [PATCH 51/56] Remove unused utility functions --- src/abstract_varinfo.jl | 8 +- src/accumulators.jl | 3 +- src/chains.jl | 2 +- src/contexts/init.jl | 13 +- src/logdensityfunction.jl | 10 +- src/utils.jl | 320 -------------------------------------- src/varinfo.jl | 2 +- src/varname.jl | 17 -- test/utils.jl | 23 --- test/varinfo.jl | 4 +- 10 files changed, 21 insertions(+), 381 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 67ac822cd..51341e3d4 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -727,7 +727,7 @@ See also: [`default_transformation`](@ref), [`invlink!!`](@ref). function link!!(vi::AbstractVarInfo, model::Model) return link!!(default_transformation(model, vi), vi, model) end -function link!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function link!!(vi::AbstractVarInfo, vns, model::Model) return link!!(default_transformation(model, vi), vi, vns, model) end @@ -746,7 +746,7 @@ See also: [`default_transformation`](@ref), [`invlink`](@ref). 
function link(vi::AbstractVarInfo, model::Model) return link(default_transformation(model, vi), vi, model) end -function link(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function link(vi::AbstractVarInfo, vns, model::Model) return link(default_transformation(model, vi), vi, vns, model) end function link(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) @@ -769,7 +769,7 @@ See also: [`default_transformation`](@ref), [`link!!`](@ref). function invlink!!(vi::AbstractVarInfo, model::Model) return invlink!!(default_transformation(model, vi), vi, model) end -function invlink!!(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function invlink!!(vi::AbstractVarInfo, vns, model::Model) return invlink!!(default_transformation(model, vi), vi, vns, model) end @@ -789,7 +789,7 @@ See also: [`default_transformation`](@ref), [`link`](@ref). function invlink(vi::AbstractVarInfo, model::Model) return invlink(default_transformation(model, vi), vi, model) end -function invlink(vi::AbstractVarInfo, vns::VarNameTuple, model::Model) +function invlink(vi::AbstractVarInfo, vns, model::Model) return invlink(default_transformation(model, vi), vi, vns, model) end function invlink(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) diff --git a/src/accumulators.jl b/src/accumulators.jl index 0208f19a5..ae1c26094 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -118,8 +118,7 @@ See also: [`split`](@ref) """ function combine end -# TODO(mhauru) The existence of this function makes me sad. See comment in unflatten in -# src/varinfo.jl. +# TODO(mhauru) The existence of this function makes me sad. See comment in src/model.jl. """ convert_eltype(::Type{T}, acc::AbstractAccumulator) diff --git a/src/chains.jl b/src/chains.jl index ee4312547..cfd27d87a 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -113,7 +113,7 @@ Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided `param_vector`. 
This method is intended to replace the old method of obtaining parameters and statistics -via `unflatten` plus re-evaluation. It is faster for two reasons: +via `unflatten!!` plus re-evaluation. It is faster for two reasons: 1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 65ea08ec5..d92bc35f8 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -56,12 +56,13 @@ used to determine whether the float type needs to be modified). In case that wasn't enough: in fact, even the above is not always true. Firstly, the accumulator argument is only true when evaluating with ThreadSafeVarInfo. See the comments -in `DynamicPPL.unflatten` for more details. For non-threadsafe evaluation, Julia is capable -of automatically promoting the types on its own. Secondly, the promotion only matters if you -are trying to directly assign into a `Vector{Float64}` with a `ForwardDiff.Dual` or similar -tracer type, for example using `xs[i] = MyDual`. This doesn't actually apply to -tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` under the hood, which -also does the promotion for you. For the gory details, see the following issues: +in `DynamicPPL.unflatten!!` for more details. For non-threadsafe evaluation, Julia is +capable of automatically promoting the types on its own. Secondly, the promotion only +matters if you are trying to directly assign into a `Vector{Float64}` with a +`ForwardDiff.Dual` or similar tracer type, for example using `xs[i] = MyDual`. This doesn't +actually apply to tilde-statements like `xs[i] ~ ...` because those use `Accessors.@set` +under the hood, which also does the promotion for you. 
For the gory details, see the +following issues: - https://github.com/TuringLang/DynamicPPL.jl/issues/906 for accumulator types - https://github.com/TuringLang/DynamicPPL.jl/issues/823 for type argument promotion diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 17101d0d2..6ae1dc3a1 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -87,9 +87,9 @@ from: Up until DynamicPPL v0.38, there have been two ways of evaluating a DynamicPPL model at a given set of parameters: -1. With `unflatten` + `evaluate!!` with `DefaultContext`: this stores a vector of parameters - inside a VarInfo's metadata, then reads parameter values from the VarInfo during - evaluation. +1. With `unflatten!!` + `evaluate!!` with `DefaultContext`: this stores a vector of + parameters inside a VarInfo's metadata, then reads parameter values from the VarInfo + during evaluation. 2. With `InitFromParams`: this reads parameter values from a NamedTuple or a Dict, and stores them inside a VarInfo's metadata. @@ -114,7 +114,7 @@ In particular, it is not clear: - which parts of the vector correspond to which random variables, and - whether the variables are linked or unlinked. -Traditionally, this problem has been solved by `unflatten`, because that function would +Traditionally, this problem has been solved by `unflatten!!`, because that function would place values into the VarInfo's metadata alongside the information about ranges and linking. That way, when we evaluate with `DefaultContext`, we can read this information out again. However, we want to avoid using a metadata. Thus, here, we _extract this information from @@ -131,7 +131,7 @@ the `LogDensityFunction` object. Therefore, a `LogDensityFunction` object cannot models which have variable numbers of parameters, or models which may visit random variables in different orders depending on stochastic control flow. 
**Indeed, silent errors may occur with such models.** This is a general limitation of vectorised parameters: the original -`unflatten` + `evaluate!!` approach also fails with such models. +`unflatten!!` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ # true if all variables are linked; false if all variables are unlinked; nothing if diff --git a/src/utils.jl b/src/utils.jl index 4a0eea96c..f0f46157b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,13 +2,6 @@ # defined in other files. function subset end -# singleton for indicating if no default arguments are present -struct NoDefault end -const NO_DEFAULT = NoDefault() - -# A short-hand for a type commonly used in type signatures for VarInfo methods. -VarNameTuple = NTuple{N,VarName} where {N} - """ The type for all log probability variables. @@ -545,275 +538,6 @@ tovec(t::Tuple) = mapreduce(tovec, vcat, t) tovec(nt::NamedTuple) = mapreduce(tovec, vcat, values(nt)) tovec(C::Cholesky) = tovec(Matrix(C.UL)) -""" - recombine(dist::Union{UnivariateDistribution,MultivariateDistribution}, vals::AbstractVector, n::Int) - -Recombine `vals`, representing a batch of samples from `dist`, so that it's a compatible with `dist`. - -!!! warning - This only supports `UnivariateDistribution` and `MultivariateDistribution`, which are the only two - distribution types which are allowed on the right-hand side of a `.~` statement in a model. -""" -function recombine(::UnivariateDistribution, val::AbstractVector, ::Int) - # This is just a no-op, since we're trying to convert a vector into a vector. - return copy(val) -end -function recombine(d::MultivariateDistribution, val::AbstractVector, n::Int) - # Here `val` is of the length `length(d) * n` and so we need to reshape it. - return copy(reshape(val, length(d), n)) -end - -####################### -# Convenience methods # -####################### -""" - collect_maybe(x) - -Return `x` if `x` is an array, otherwise return `collect(x)`. 
-""" -collect_maybe(x) = collect(x) -collect_maybe(x::AbstractArray) = x - -####################### -# BangBang.jl related # -####################### -function set!!(obj, optic::AbstractPPL.ALLOWED_OPTICS, value) - opticmut = BangBang.prefermutation(optic) - return Accessors.set(obj, opticmut, value) -end -function set!!(obj, vn::VarName{sym}, value) where {sym} - optic = BangBang.prefermutation( - AbstractPPL.getoptic(vn) ∘ Accessors.PropertyLens{sym}() - ) - return Accessors.set(obj, optic, value) -end - -############################# -# AbstractPPL.jl extensions # -############################# -# This is preferable to `haskey` because the order of arguments is different, and -# we're more likely to specialize on the key in these settings rather than the container. -# TODO: I'm not sure about this name. -""" - canview(optic, container) - -Return `true` if `optic` can be used to view `container`, and `false` otherwise. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: canview) -julia> canview(@o(_.a), (a = 1.0, )) -true - -julia> canview(@o(_.a), (b = 1.0, )) # property `a` does not exist -false - -julia> canview(@o(_.a[1]), (a = [1.0, 2.0], )) -true - -julia> canview(@o(_.a[3]), (a = [1.0, 2.0], )) # out of bounds -false -``` -""" -canview(optic, container) = false -canview(::typeof(identity), _) = true -function canview(optic::Accessors.PropertyLens{field}, x) where {field} - return hasproperty(x, field) -end - -# `IndexLens`: only relevant if `x` supports indexing. -canview(optic::Accessors.IndexLens, x) = false -function canview(optic::Accessors.IndexLens, x::AbstractArray) - return checkbounds(Bool, x, optic.indices...) -end - -# `ComposedOptic`: check that we can view `.inner` and `.outer`, but using -# value extracted using `.inner`. -function canview(optic::Accessors.ComposedOptic, x) - return canview(optic.inner, x) && canview(optic.outer, optic.inner(x)) -end - -""" - parent(vn::VarName) - -Return the parent `VarName`. 
- -# Examples -```julia-repl; setup=:(using DynamicPPL: parent) -julia> parent(@varname(x.a[1])) -x.a - -julia> (parent ∘ parent)(@varname(x.a[1])) -x - -julia> (parent ∘ parent ∘ parent)(@varname(x.a[1])) -x -``` -""" -function parent(vn::VarName) - p = parent(getoptic(vn)) - return p === nothing ? VarName{getsym(vn)}(identity) : VarName{getsym(vn)}(p) -end - -""" - parent(optic) - -Return the parent optic. If `optic` doesn't have a parent, -`nothing` is returned. - -See also: [`parent_and_child`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent) -julia> parent(@o(_.a[1])) -(@o _.a) - -julia> # Parent of optic without parents results in `nothing`. - (parent ∘ parent)(@o(_.a[1])) === nothing -true -``` -""" -parent(optic::AbstractPPL.ALLOWED_OPTICS) = first(parent_and_child(optic)) - -""" - parent_and_child(optic) - -Return a 2-tuple of optics `(parent, child)` where `parent` is the -parent optic of `optic` and `child` is the child optic of `optic`. - -If `optic` does not have a parent, we return `(nothing, optic)`. - -See also: [`parent`]. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: parent_and_child) -julia> parent_and_child(@o(_.a[1])) -((@o _.a), (@o _[1])) - -julia> parent_and_child(@o(_.a)) -(nothing, (@o _.a)) -``` -""" -parent_and_child(optic::AbstractPPL.ALLOWED_OPTICS) = (nothing, optic) -function parent_and_child(optic::Accessors.ComposedOptic) - p, child = parent_and_child(optic.outer) - parent = p === nothing ? optic.inner : p ∘ optic.inner - return parent, child -end - -""" - splitoptic(condition, optic) - -Return a 3-tuple `(parent, child, issuccess)` where, if `issuccess` is `true`, -`parent` is a optic such that `condition(parent)` is `true` and `child ∘ parent == optic`. - -If `issuccess` is `false`, then no such split could be found. - -# Examples -```jldoctest; setup=:(using Accessors; using DynamicPPL: splitoptic) -julia> p, c, issucesss = splitoptic(@o(_.a[1])) do parent - # Succeeds! 
- parent == @o(_.a) - end -((@o _.a), (@o _[1]), true) - -julia> c ∘ p -(@o _.a[1]) - -julia> splitoptic(@o(_.a[1])) do parent - # Fails! - parent == @o(_.b) - end -(nothing, (@o _.a[1]), false) -``` -""" -function splitoptic(condition, optic) - current_parent, current_child = parent_and_child(optic) - # We stop if either a) `condition` is satisfied, or b) we reached the root. - while !condition(current_parent) && current_parent !== nothing - current_parent, c = parent_and_child(current_parent) - current_child = current_child ∘ c - end - - return current_parent, current_child, condition(current_parent) -end - -""" - remove_parent_optic(vn_parent::VarName, vn_child::VarName) - -Remove the parent optic `vn_parent` from `vn_child`. - -# Examples -```jldoctest; setup = :(using Accessors; using DynamicPPL: remove_parent_optic) -julia> remove_parent_optic(@varname(x), @varname(x.a)) -(@o _.a) - -julia> remove_parent_optic(@varname(x), @varname(x.a[1])) -(@o _.a[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1])) -(@o _[1]) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a[1].b)) -(@o _[1].b) - -julia> remove_parent_optic(@varname(x.a), @varname(x.a)) -ERROR: Could not find x.a in x.a - -julia> remove_parent_optic(@varname(x.a[2]), @varname(x.a[1])) -ERROR: Could not find x.a[2] in x.a[1] -``` -""" -function remove_parent_optic(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} - _, child, issuccess = splitoptic(getoptic(vn_child)) do optic - o = optic === nothing ? identity : optic - o == getoptic(vn_parent) - end - - issuccess || error("Could not find $vn_parent in $vn_child") - return child -end - -# HACK(torfjelde): This makes it so it works on iterators, etc. by default. -# TODO(torfjelde): Do better. -""" - unflatten(original, x::AbstractVector) - -Return instance of `original` constructed from `x`. 
-""" -function unflatten(original, x::AbstractVector) - lengths = map(length, original) - end_indices = cumsum(lengths) - return map(zip(original, lengths, end_indices)) do (v, l, end_idx) - start_idx = end_idx - l + 1 - return unflatten(v, @view(x[start_idx:end_idx])) - end -end - -unflatten(::Real, x::Real) = x -unflatten(::Real, x::AbstractVector) = only(x) -unflatten(::AbstractVector{<:Real}, x::Real) = vcat(x) -unflatten(::AbstractVector{<:Real}, x::AbstractVector) = x -unflatten(original::AbstractArray{<:Real}, x::AbstractVector) = reshape(x, size(original)) - -function unflatten(original::Tuple, x::AbstractVector) - lengths = map(length, original) - end_indices = cumsum(lengths) - return ntuple(length(original)) do i - v = original[i] - l = lengths[i] - end_idx = end_indices[i] - start_idx = end_idx - l + 1 - return unflatten(v, @view(x[start_idx:end_idx])) - end -end -function unflatten(original::NamedTuple{names}, x::AbstractVector) where {names} - return NamedTuple{names}(unflatten(values(original), x)) -end -function unflatten(original::AbstractDict, x::AbstractVector) - D = ConstructionBase.constructorof(typeof(original)) - return D(zip(keys(original), unflatten(collect(values(original)), x))) -end - """ update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) @@ -916,50 +640,6 @@ _merge(left::AbstractDict, right::NamedTuple) = merge(left, to_varname_dict(righ _merge(::NamedTuple{()}, right::AbstractDict) = right _merge(left::NamedTuple, right::AbstractDict) = merge(to_varname_dict(left), right) -""" - unique_syms(vns::T) where {T<:NTuple{N,VarName}} - -Return the unique symbols of the variables in `vns`. - -Note that `unique_syms` is only defined for `Tuple`s of `VarName`s and, unlike -`Base.unique`, returns a `Tuple`. The point of `unique_syms` is that it supports constant -propagating the result, which is possible only when the argument and the return value are -`Tuple`s. 
-""" -@generated function unique_syms(::T) where {T<:VarNameTuple} - retval = Expr(:tuple) - syms = [first(vn.parameters) for vn in T.parameters] - for sym in unique(syms) - push!(retval.args, QuoteNode(sym)) - end - return retval -end - -""" - group_varnames_by_symbol(vns::NTuple{N,VarName}) where {N} - -Return a `NamedTuple` of the variables in `vns` grouped by symbol. - -Note that `group_varnames_by_symbol` only accepts a `Tuple` of `VarName`s. This allows it to -be type stable. - -Example: -```julia -julia> vns_tuple = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) -(x, y[1], x.a, z[15], y[2]) - -julia> vns_nt = (; x=[@varname(x), @varname(x.a)], y=[@varname(y[1]), @varname(y[2])], z=[@varname(z[15])]) -(x = VarName{:x}[x, x.a], y = VarName{:y, IndexLens{Tuple{Int64}}}[y[1], y[2]], z = VarName{:z, IndexLens{Tuple{Int64}}}[z[15]]) - -julia> group_varnames_by_symbol(vns_tuple) == vns_nt -``` -""" -function group_varnames_by_symbol(vns::VarNameTuple) - syms = unique_syms(vns) - elements = map(collect, tuple((filter(vn -> getsym(vn) == s, vns) for s in syms)...)) - return NamedTuple{syms}(elements) -end - """ basetypeof(x) diff --git a/src/varinfo.jl b/src/varinfo.jl index 3e026648b..860fb7372 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -406,7 +406,7 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) old_val = tv.val if !(old_val isa AbstractVector) error( - "Can't unflatten a VarInfo for which existing values are not vectors:" * + "Can't unflatten!! a VarInfo for which existing values are not vectors:" * " Got value of type $(typeof(old_val)).", ) end diff --git a/src/varname.jl b/src/varname.jl index 7ffe9cc08..e1492bb32 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -1,20 +1,3 @@ -""" - subsumes_string(u::String, v::String[, u_indexing]) - -Check whether stringified variable name `v` describes a sub-range of stringified variable `u`. 
- -This is a very restricted version `subumes(u::VarName, v::VarName)` only really supporting: -- Scalar: `x` subsumes `x[1, 2]`, `x[1, 2]` subsumes `x[1, 2][3]`, etc. - -## Note -- To get same matching capabilities as `AbstractPPL.subumes(u::VarName, v::VarName)` - for strings, one can always do `eval(varname(Meta.parse(u))` to get `VarName` of `u`, - and similarly to `v`. But this is slow. -""" -function subsumes_string(u::String, v::String, u_indexing=u * "[") - return u == v || startswith(v, u_indexing) -end - """ inargnames(varname::VarName, model::Model) diff --git a/test/utils.jl b/test/utils.jl index bef1c2ba8..bc01fc0ce 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -186,29 +186,6 @@ end t = (2.0, [3.0, 4.0]) @test DynamicPPL.tovec(t) == [2.0, 3.0, 4.0] end - - @testset "unique_syms" begin - vns = (@varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2])) - @inferred DynamicPPL.unique_syms(vns) - @inferred DynamicPPL.unique_syms(()) - @test DynamicPPL.unique_syms(vns) == (:x, :y, :z) - @test DynamicPPL.unique_syms(()) == () - end - - @testset "group_varnames_by_symbol" begin - vns_tuple = ( - @varname(x), @varname(y[1]), @varname(x.a), @varname(z[15]), @varname(y[2]) - ) - vns_vec = collect(vns_tuple) - vns_nt = (; - x=[@varname(x), @varname(x.a)], - y=[@varname(y[1]), @varname(y[2])], - z=[@varname(z[15])], - ) - vns_vec_single_symbol = [@varname(x.a), @varname(x.b), @varname(x[1])] - @inferred DynamicPPL.group_varnames_by_symbol(vns_tuple) - @test DynamicPPL.group_varnames_by_symbol(vns_tuple) == vns_nt - end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index 8ae0535c7..639b4f688 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -346,7 +346,7 @@ end end end - @testset "unflatten + linking" begin + @testset "unflatten!! 
+ linking" begin @testset "Model: $(model.f)" for model in [ DynamicPPL.TestUtils.demo_one_variable_multiple_constraints(), DynamicPPL.TestUtils.demo_lkjchol(), @@ -403,7 +403,7 @@ end end end - @testset "unflatten type stability" begin + @testset "unflatten!! type stability" begin @model function demo(y) x ~ Normal() y ~ Normal(x, 1) From 6128a562203936a5d6da71cf4e0b19d84cbff289 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 13:16:49 +0000 Subject: [PATCH 52/56] Use OnlyAccsVarInfo in extract_priors --- src/extract_priors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 8c7b5f7db..182e933e4 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -121,7 +121,7 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - varinfo = VarInfo() + varinfo = OnlyAccsVarInfo() varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors From 8fddfef6ce719f8819f2a32b6504bd27a9c00196 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 13:25:37 +0000 Subject: [PATCH 53/56] Make linking status a type parameter of VarInfo --- src/logdensityfunction.jl | 2 +- src/varinfo.jl | 132 +++++++++++++++++++++++++++----------- test/varinfo.jl | 6 +- 3 files changed, 99 insertions(+), 41 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 6ae1dc3a1..9337d159c 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -313,7 +313,7 @@ function get_ranges_and_linked(vi::VarInfo) val = tv.val range = offset:(offset + length(val) - 1) offset += length(val) - ral = RangeAndLinked(range, tv.linked, tv.size) + ral = RangeAndLinked(range, is_transformed(tv), 
tv.size) vnt = setindex!!(vnt, ral, vn) return vnt, offset end, diff --git a/src/varinfo.jl b/src/varinfo.jl index 860fb7372..a59837ba0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1,8 +1,12 @@ """ - VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo + VarInfo{Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo The default implementation of `AbstractVarInfo`, storing variable values and accumulators. +The `Linked` type parameter is either `true` or `false` to mark that all variables in this +`VarInfo` are linked, or `nothing` to indicate that some variables may be linked and some +not, and a runtime check is needed. + `VarInfo` is quite a thin wrapper around a `VarNamedTuple` storing the variable values, and a tuple of accumulators. The only really noteworthy thing about it is that it stores the values of variables vectorised as instances of `TransformedValue`. That is, it stores @@ -27,9 +31,15 @@ For more details on the internal storage, see documentation of [`TransformedValu $(TYPEDFIELDS) """ -struct VarInfo{T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo +struct VarInfo{Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} <: AbstractVarInfo values::T accs::Accs + + function VarInfo{Linked}( + values::T, accs::Accs + ) where {Linked,T<:VarNamedTuple,Accs<:AccumulatorTuple} + return new{Linked,T,Accs}(values, accs) + end end # TODO(mhauru) The policy of vectorising all values was set when the old VarInfo type was @@ -42,31 +52,43 @@ end # during model execution. However, this would change the interface quite a lot, so I want to # finish implementing VarInfo using VNT (mostly) respecting the old interface first. +# TODO(mhauru) We are considering removing `transform` completely, and forcing people to use +# ValuesAsInModelAcc instead. If that is done, we may want to move the Linked type parameter +# to just be a bool field. 
It's currently a type parameter to make the type of `transform` +# easier to type infer, but if `transform` no longer exists, it might start to cause +# unnecessary type inconcreteness in the elements of PartialArray. """ - TransformedValue{ValType,TransformType,SizeType} + TransformedValue{Linked,ValType,TransformType,SizeType} A struct for storing a variable's value in its internal (vectorised) form. +The type parameter `Linked` is a `Bool` indicating whether the variable is linked, i.e. +whether the transformation maps all real vectors to valid values. # Fields $(TYPEDFIELDS) """ -struct TransformedValue{ValType,TransformType,SizeType} +struct TransformedValue{Linked,ValType,TransformType,SizeType} "The internal (vectorised) value." val::ValType - """Boolean indicating whether the variable is linked, i.e. the transformation maps all - real vectors to valid values.""" - linked::Bool """The transformation from internal (vectorised) to actual value. In other words, the actual value of the variable being stored is `transform(val)`.""" transform::TransformType """The size of the actual value after transformation. 
This is needed when a `TransformedValue` is stored as a block in an array.""" size::SizeType + + function TransformedValue{Linked}( + val::ValType, transform::TransformType, size::SizeType + ) where {Linked,ValType,TransformType,SizeType} + return new{Linked,ValType,TransformType,SizeType}(val, transform, size) + end end +is_transformed(::TransformedValue{Linked}) where {Linked} = Linked + VarNamedTuples.vnt_size(tv::TransformedValue) = tv.size -VarInfo() = VarInfo(VarNamedTuple(), default_accumulators()) +VarInfo() = VarInfo{false}(VarNamedTuple(), default_accumulators()) function VarInfo(values::Union{NamedTuple,AbstractDict}) vi = VarInfo() @@ -90,11 +112,15 @@ function VarInfo( end getaccs(vi::VarInfo) = vi.accs -setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = VarInfo(vi.values, accs) +function setaccs!!(vi::VarInfo{Linked}, accs::AccumulatorTuple) where {Linked} + return VarInfo{Linked}(vi.values, accs) +end transformation(::VarInfo) = DynamicTransformation() -Base.copy(vi::VarInfo) = VarInfo(copy(vi.values), copy(getaccs(vi))) +function Base.copy(vi::VarInfo{Linked}) where {Linked} + return VarInfo{Linked}(copy(vi.values), copy(getaccs(vi))) +end Base.haskey(vi::VarInfo, vn::VarName) = haskey(vi.values, vn) Base.length(vi::VarInfo) = length(vi.values) Base.keys(vi::VarInfo) = keys(vi.values) @@ -110,8 +136,8 @@ function Base.getindex(vi::VarInfo, vns::AbstractVector{<:VarName}) end Base.isempty(vi::VarInfo) = isempty(vi.values) -Base.empty(vi::VarInfo) = VarInfo(empty(vi.values), map(reset, vi.accs)) -BangBang.empty!!(vi::VarInfo) = VarInfo(empty!!(vi.values), map(reset, vi.accs)) +Base.empty(vi::VarInfo) = VarInfo{false}(empty(vi.values), map(reset, vi.accs)) +BangBang.empty!!(vi::VarInfo) = VarInfo{false}(empty!!(vi.values), map(reset, vi.accs)) """ setindex_internal!!(vi::VarInfo, val, vn::VarName) @@ -120,11 +146,11 @@ Set the internal (vectorised) value of variable `vn` in `vi` to `val`. 
This does not change the transformation or linked status of the variable. """ -function setindex_internal!!(vi::VarInfo, val, vn::VarName) +function setindex_internal!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked} old_tv = getindex(vi.values, vn) - new_tv = TransformedValue(val, old_tv.linked, old_tv.transform, old_tv.size) + new_tv = TransformedValue{is_transformed(old_tv)}(val, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VarInfo(new_values, vi.accs) + return VarInfo{Linked}(new_values, vi.accs) end # TODO(mhauru) It shouldn't really be VarInfo's business to know about `dist`. However, @@ -147,8 +173,14 @@ used. Returns the modified `vi` together with the log absolute determinant of the Jacobian of the transformation applied. """ -function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) - link = haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) +function setindex_with_dist!!( + vi::VarInfo{Linked}, val, dist::Distribution, vn::VarName +) where {Linked} + link = if Linked === nothing + haskey(vi, vn) ? is_transformed(vi, vn) : is_transformed(vi) + else + Linked + end transform = if link from_linked_vec_transform(dist) else @@ -157,8 +189,9 @@ function setindex_with_dist!!(vi::VarInfo, val, dist::Distribution, vn::VarName) transformed_val, logjac = with_logabsdet_jacobian(inverse(transform), val) # All values for which `size` is not defined are assumed to be scalars. val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () - tv = TransformedValue(transformed_val, link, transform, val_size) - vi = VarInfo(setindex!!(vi.values, tv, vn), vi.accs) + tv = TransformedValue{link}(transformed_val, transform, val_size) + new_linked = Linked == link ? Linked : nothing + vi = VarInfo{new_linked}(setindex!!(vi.values, tv, vn), vi.accs) return vi, logjac end @@ -174,11 +207,12 @@ Set the value of `vn` in `vi` to `val`. 
The transformation for `vn` is reset to be the standard vector transformation for values of the type of `val` and linking status is set to false. """ -function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) +function BangBang.setindex!!(vi::VarInfo{Linked}, val, vn::VarName) where {Linked} + new_linked = Linked == false ? false : nothing transform = from_vec_transform(val) transformed_val = inverse(transform)(val) - tv = TransformedValue(transformed_val, false, transform, size(val)) - return VarInfo(setindex!!(vi.values, tv, vn), vi.accs) + tv = TransformedValue{false}(transformed_val, transform, size(val)) + return VarInfo{new_linked}(setindex!!(vi.values, tv, vn), vi.accs) end """ @@ -188,11 +222,14 @@ Set the linked status of variable `vn` in `vi` to `linked`. This does not change the value or transformation of the variable. """ -function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) +function set_transformed!!(vi::VarInfo{Linked}, linked::Bool, vn::VarName) where {Linked} old_tv = getindex(vi.values, vn) - new_tv = TransformedValue(old_tv.val, linked, old_tv.transform, old_tv.size) + new_tv = TransformedValue{linked}(old_tv.val, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - return VarInfo(new_values, vi.accs) + # The below check shouldn't ever pass, this should always result in `nothing`, but may + # as well play it safe, it'll be constant propagated away anyway. + new_linked = Linked == linked ? 
Linked : nothing + return VarInfo{new_linked}(new_values, vi.accs) end # VarInfo does not care whether the transformation was Static or Dynamic, it just tracks @@ -211,9 +248,9 @@ set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false function set_transformed!!(vi::VarInfo, linked::Bool) new_values = map_values!!(vi.values) do tv - TransformedValue(tv.val, linked, tv.transform, tv.size) + TransformedValue{linked}(tv.val, tv.transform, tv.size) end - return VarInfo(new_values, vi.accs) + return VarInfo{linked}(new_values, vi.accs) end """ @@ -225,7 +262,13 @@ getindex_internal(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).val # TODO(mhauru) The below should be removed together with unflatten!!. getindex_internal(vi::VarInfo, ::Colon) = values_as(vi, Vector) -is_transformed(vi::VarInfo, vn::VarName) = getindex(vi.values, vn).linked +function is_transformed(vi::VarInfo{Linked}, vn::VarName) where {Linked} + return if Linked === nothing + is_transformed(getindex(vi.values, vn)) + else + Linked + end +end function from_internal_transform(::VarInfo, ::VarName, dist::Distribution) return from_vec_transform(dist) @@ -273,7 +316,7 @@ function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where { # Not one of the target variables. return tv end - if tv.linked == link + if is_transformed(tv) == link # Already in the desired state. return tv end @@ -289,11 +332,17 @@ function _link_or_invlink!!(vi::VarInfo, vns, model::Model, ::Val{link}) where { val_new, logjac2 = with_logabsdet_jacobian( inverse(new_transform), val_untransformed ) - new_tv = TransformedValue(val_new, link, new_transform, tv.size) + # !is_transformed(tv) is the same as `link`, but might be easier for type inference. 
+ new_tv = TransformedValue{!is_transformed(tv)}(val_new, new_transform, tv.size) cumulative_logjac += logjac1 + logjac2 return new_tv end - vi = VarInfo(new_values, vi.accs) + vi_linked = if vns === nothing + link + else + nothing + end + vi = VarInfo{vi_linked}(new_values, vi.accs) if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, cumulative_logjac) end @@ -397,7 +446,7 @@ function get_next_chunk!(vci::VectorChunkIterator, len::Int) return chunk end -function unflatten!!(vi::VarInfo, vec::AbstractVector) +function unflatten!!(vi::VarInfo{Linked}, vec::AbstractVector) where {Linked} # You may wonder, why have a whole struct for this, rather than just an index variable # that the mapping function would close over. I wonder too. But for some reason type # inference fails on such an index variable, turning it into a Core.Box. @@ -412,9 +461,9 @@ function unflatten!!(vi::VarInfo, vec::AbstractVector) end len = length(old_val) new_val = get_next_chunk!(vci, len) - return TransformedValue(new_val, tv.linked, tv.transform, tv.size) + return TransformedValue{is_transformed(tv)}(new_val, tv.transform, tv.size) end - return VarInfo(new_values, vi.accs) + return VarInfo{Linked}(new_values, vi.accs) end """ @@ -424,9 +473,9 @@ Create a new `VarInfo` containing only the variables in `vns`. `vns` can be almost any collection of `VarName`s, e.g. a `Set`, `Vector`, or `Tuple`. """ -function subset(varinfo::VarInfo, vns) +function subset(varinfo::VarInfo{Linked}, vns) where {Linked} new_values = subset(varinfo.values, vns) - return VarInfo(new_values, map(copy, getaccs(varinfo))) + return VarInfo{Linked}(new_values, map(copy, getaccs(varinfo))) end """ @@ -439,8 +488,15 @@ The accumulators are taken exclusively from `varinfo_right`. If a variable exists in both `varinfo_left` and `varinfo_right`, the value from `varinfo_right` is used. 
""" -function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) +function Base.merge( + varinfo_left::VarInfo{LinkedLeft}, varinfo_right::VarInfo{LinkedRight} +) where {LinkedLeft,LinkedRight} new_values = merge(varinfo_left.values, varinfo_right.values) new_accs = map(copy, getaccs(varinfo_right)) - return VarInfo(new_values, new_accs) + new_linked = if LinkedLeft == LinkedRight + LinkedLeft + else + nothing + end + return VarInfo{new_linked}(new_values, new_accs) end diff --git a/test/varinfo.jl b/test/varinfo.jl index 639b4f688..9fb8c6d4d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -265,8 +265,10 @@ end _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) vals = values(vi) - all_transformed(vi) = mapreduce(p -> p.second.linked, &, vi.values; init=true) - any_transformed(vi) = mapreduce(p -> p.second.linked, |, vi.values; init=false) + all_transformed(vi) = + mapreduce(p -> is_transformed(p.second), &, vi.values; init=true) + any_transformed(vi) = + mapreduce(p -> is_transformed(p.second), |, vi.values; init=false) @test !any_transformed(vi) From aa3adb327fd2313f97dffd4df2e20a3f207cfcdc Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 14:47:39 +0000 Subject: [PATCH 54/56] Fix a typo Co-authored-by: Penelope Yong --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index d6c13f7a6..c3a704552 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -14,7 +14,7 @@ Almost all of the below changes are the consequence from switching over to using ### Overhaul of `VarInfo` DynamicPPL tracks variable values during model execution using one of the `AbstractVarInfo` types. -Previously, there were many versions of them: `VarInfo`, both "typed" and "untyped, and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends. 
+Previously, there were many versions of them: `VarInfo`, both "typed" and "untyped", and `SimpleVarInfo` with both `NamedTuple` and `OrderedDict` as storage backends. These have all been replaced by a rewritten implementation of `VarInfo`. While the basics of the `VarInfo` interface remain the same, this brings with it many changes: From 0c03233daa9d4b09345087ba1786d60da57114af Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 14:55:03 +0000 Subject: [PATCH 55/56] Simplify code Co-authored-by: Penelope Yong --- src/extract_priors.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/extract_priors.jl b/src/extract_priors.jl index 182e933e4..def2b7756 100644 --- a/src/extract_priors.jl +++ b/src/extract_priors.jl @@ -121,8 +121,7 @@ julia> length(extract_priors(rng, model)[@varname(x)]) extract_priors(args::Union{Model,AbstractVarInfo}...) = extract_priors(Random.default_rng(), args...) function extract_priors(rng::Random.AbstractRNG, model::Model) - varinfo = OnlyAccsVarInfo() - varinfo = setaccs!!(varinfo, (PriorDistributionAccumulator(),)) + varinfo = OnlyAccsVarInfo((PriorDistributionAccumulator(),)) varinfo = last(init!!(rng, model, varinfo)) return getacc(varinfo, Val(:PriorDistributionAccumulator)).priors end From 39df57b586e9b58d7b954bb78d758335ece312d3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 15 Jan 2026 15:01:20 +0000 Subject: [PATCH 56/56] Fix comments, remove dead line --- src/logdensityfunction.jl | 2 -- src/varinfo.jl | 6 ++++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9337d159c..7cb84cbc2 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -304,8 +304,6 @@ This function returns a VarNamedTuple mapping all VarNames to their correspondin `RangeAndLinked`. """ function get_ranges_and_linked(vi::VarInfo) - # TODO(mhauru) Check that the closure doesn't cause type instability here. 
- vnt = VarNamedTuple() vnt, _ = mapreduce( identity, function ((vnt, offset), pair) diff --git a/src/varinfo.jl b/src/varinfo.jl index a59837ba0..191537ad8 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -226,8 +226,6 @@ function set_transformed!!(vi::VarInfo{Linked}, linked::Bool, vn::VarName) where old_tv = getindex(vi.values, vn) new_tv = TransformedValue{linked}(old_tv.val, old_tv.transform, old_tv.size) new_values = setindex!!(vi.values, new_tv, vn) - # The below check shouldn't ever pass, this should always result in `nothing`, but may - # as well play it safe, it'll be constant propagated away anyway. new_linked = Linked == linked ? Linked : nothing return VarInfo{new_linked}(new_values, vi.accs) end @@ -496,6 +494,10 @@ function Base.merge( new_linked = if LinkedLeft == LinkedRight LinkedLeft else + # TODO(mhauru) Consider doing something more clever here, e.g. checking whether + # either varinfo_left or varinfo_right is empty, or actually iterating over all the + # values to check their linked status. Needs to balance keeping the type parameter + # alive vs runtime costs. nothing end return VarInfo{new_linked}(new_values, new_accs)