
VNT Part 5: VarNamedTuple as VarInfo#1183

Merged
mhauru merged 58 commits into mhauru/vnt-for-fastldf from mhauru/vnt-for-varinfo
Jan 15, 2026

Conversation

@mhauru (Member) commented Dec 18, 2025

I put together a quick sketch of what it would look like to use VarNamedTuple as a VarInfo directly. By that I mean a VarInfo type that is nothing but accumulators plus a VarNamedTuple that maps each VarName to a tuple (or actually a tiny struct, but anyway) of three values: the stored value for this variable, whether it's linked, and what transform should be applied to convert the stored value back to "model space". I'm calling this new VarInfo type VNTVarInfo (name to be changed later).
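As a sketch, the layout described above might look something like the following. All names here are illustrative, not the actual DynamicPPL implementation:

```julia
# Hypothetical sketch of the layout described above; names are illustrative.
struct TransformedValue{T,F}
    value::T     # stored value for this variable
    linked::Bool # whether the stored value is in linked (unconstrained) space
    transform::F # transform to get the stored value back to "model space"
end

# A VarInfo that is nothing but accumulators plus a VarNamedTuple mapping
# each VarName to a TransformedValue.
struct VNTVarInfoSketch{Accs,VNT}
    accs::Accs
    values::VNT
end

# For example, an unlinked value stored as-is:
tv = TransformedValue(0.37, false, identity)
```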

This isn't finished yet, but the majority of tests pass. There are a lot of failures around edge cases like Cholesky and weird VarNames and such, but for most simple models you can do

vi = VNTVarInfo(model)
vi = link!!(vi, model)
evaluate!!(model, vi)

and it'll give you the correct result. unflatten and vi[:] also work.

I'll keep working on this, but at this point I wanted to pause to do some benchmarks, see how viable this is. Benchmark code, very similar to #1182, running evaluate!! on our benchmarking models:

Details
module VIBench

using DynamicPPL, Distributions, Chairmarks
using StableRNGs: StableRNG
include("benchmarks/src/Models.jl")
using .Models: Models

function run()
    rng = StableRNG(23)

    smorgasbord_instance = Models.smorgasbord(randn(rng, 100), randn(rng, 100))

    loop_univariate1k, multivariate1k = begin
        data_1k = randn(rng, 1_000)
        loop = Models.loop_univariate(length(data_1k)) | (; o=data_1k)
        multi = Models.multivariate(length(data_1k)) | (; o=data_1k)
        loop, multi
    end

    loop_univariate10k, multivariate10k = begin
        data_10k = randn(rng, 10_000)
        loop = Models.loop_univariate(length(data_10k)) | (; o=data_10k)
        multi = Models.multivariate(length(data_10k)) | (; o=data_10k)
        loop, multi
    end

    # lda_instance = begin
    #     w = [1, 2, 3, 2, 1, 1]
    #     d = [1, 1, 1, 2, 2, 2]
    #     Models.lda(2, d, w)
    # end

    models = [
        ("simple_assume_observe", Models.simple_assume_observe(randn(rng))),
        ("smorgasbord", smorgasbord_instance),
        ("loop_univariate1k", loop_univariate1k),
        ("multivariate1k", multivariate1k),
        ("loop_univariate10k", loop_univariate10k),
        ("multivariate10k", multivariate10k),
        ("dynamic", Models.dynamic()),
        ("parent", Models.parent(randn(rng))),
        # ("lda", lda_instance),
    ]

    function print_diff(r, ref)
        diff = r.time - ref.time
        # Pick units by the magnitude of the difference, so that negative
        # differences (speedups) are scaled sensibly too.
        units, scale = if abs(diff) < 1e-6
            ("ns", 1e-9)
        elseif abs(diff) < 1e-3
            ("µs", 1e-6)
        else
            ("ms", 1e-3)
        end
        diff = round(diff / scale; digits=1)
        sign = diff < 0 ? "" : "+"
        return println(" ($(sign)$(diff) $units)")
    end

    new = isdefined(DynamicPPL, :(VNTVarInfo))
    prefix = new ? "New" : "Old"

    for (name, m) in models
        println()
        println(name)
        vi = VarInfo(StableRNG(23), m)
        vi_linked = link!!(deepcopy(vi), m)
        # logp = getlogjoint(last(DynamicPPL.evaluate!!(m, vi)))
        # logp_linked = getlogjoint(last(DynamicPPL.evaluate!!(m, vi_linked)))
        # @show logp
        # @show logp_linked
        res = @b DynamicPPL.evaluate!!($m, $vi)
        print("$prefix unlinked: ")
        display(res)
        res = @b DynamicPPL.evaluate!!($m, $vi_linked)
        print("$prefix linked:   ")
        display(res)

        if !isdefined(DynamicPPL, :(VNTVarInfo))
            svi_nt = SimpleVarInfo(vi, NamedTuple)
            try
                res = @b DynamicPPL.evaluate!!($m, $svi_nt)
            catch e
                res = missing
            end
            print("SVI NT:       ")
            display(res)
            svi_od = SimpleVarInfo(vi, OrderedDict)
            res = @b DynamicPPL.evaluate!!($m, $svi_od)
            print("SVI OD:       ")
            display(res)
        end
    end
end

run()

end

Results comparing the new VarInfo against the old VarInfo, SimpleVarInfo{NamedTuple}, and SimpleVarInfo{OrderedDict}. Some SVI NT results are missing because it couldn't handle the IndexLenses:

simple_assume_observe
New unlinked: 2.778 ns
New linked:   12.201 ns
Old unlinked: 91.414 ns (4 allocs: 128 bytes)
Old linked:   80.752 ns (4 allocs: 128 bytes)
SVI NT:       2.468 ns
SVI OD:       4.941 ns

smorgasbord
New unlinked: 5.375 μs (12 allocs: 6.156 KiB)
New linked:   6.146 μs (18 allocs: 8.750 KiB)
Old unlinked: 16.375 μs (420 allocs: 33.375 KiB)
Old linked:   13.354 μs (325 allocs: 18.609 KiB)
SVI NT:       missing
SVI OD:       357.333 μs (3514 allocs: 98.891 KiB)

loop_univariate1k
New unlinked: 10.625 μs (6 allocs: 16.125 KiB)
New linked:   12.250 μs (6 allocs: 16.125 KiB)
Old unlinked: 64.542 μs (2009 allocs: 86.688 KiB)
Old linked:   58.625 μs (2009 allocs: 86.688 KiB)
SVI NT:       missing
SVI OD:       7.444 μs (6 allocs: 16.125 KiB)

multivariate1k
New unlinked: 11.125 μs (24 allocs: 80.500 KiB)
New linked:   11.250 μs (24 allocs: 80.500 KiB)
Old unlinked: 11.209 μs (29 allocs: 88.625 KiB)
Old linked:   11.208 μs (29 allocs: 88.625 KiB)
SVI NT:       10.708 μs (24 allocs: 80.500 KiB)
SVI OD:       10.833 μs (24 allocs: 80.500 KiB)

loop_univariate10k
New unlinked: 104.750 μs (6 allocs: 192.125 KiB)
New linked:   142.583 μs (6 allocs: 192.125 KiB)
Old unlinked: 752.542 μs (20009 allocs: 913.188 KiB)
Old linked:   614.750 μs (20009 allocs: 913.188 KiB)
SVI NT:       missing
SVI OD:       155.625 μs (6 allocs: 192.125 KiB)

multivariate10k
New unlinked: 107.500 μs (24 allocs: 896.500 KiB)
New linked:   106.459 μs (24 allocs: 896.500 KiB)
Old unlinked: 112.833 μs (29 allocs: 992.625 KiB)
Old linked:   110.500 μs (29 allocs: 992.625 KiB)
SVI NT:       106.000 μs (24 allocs: 896.500 KiB)
SVI OD:       110.292 μs (24 allocs: 896.500 KiB)

dynamic
New unlinked: 1.109 μs (12 allocs: 672 bytes)
New linked:   2.149 μs (43 allocs: 2.406 KiB)
Old unlinked: 1.854 μs (27 allocs: 1.891 KiB)
Old linked:   3.023 μs (53 allocs: 2.922 KiB)
SVI NT:       1.035 μs (12 allocs: 672 bytes)
SVI OD:       6.927 μs (75 allocs: 2.953 KiB)

parent
New unlinked: 2.777 ns
New linked:   10.967 ns
Old unlinked: 113.683 ns (6 allocs: 192 bytes)
Old linked:   106.579 ns (6 allocs: 192 bytes)
SVI NT:       missing
SVI OD:       4.948 ns

I think a fair TL;DR is that for both small models and models with IndexLenses, this is many times faster than the old VarInfo, and not far off from SimpleVarInfo when SimpleVarInfo is at its fastest (NamedTuple for small models, OrderedDict for IndexLenses). I would still like to close that gap a bit. I don't know why linking causes such a large slowdown in some cases; I suspect it's because the transform system is geared towards assuming we want to vectorise things, and I've hacked this together quickly just to get it working.

For large models, performance is essentially equal, as it should be, because these changes are about overheads. To improve on that I would need to look into using views in some clever way, but that's for later.

I think this is a promising start towards being able to replace all of VarInfo, SimpleVarInfo, and VarNamedVector with a direct use of VarNamedTuple (as opposed to e.g. VNT wrapping VarNamedVector). It would be pretty close to a best-of-all-worlds solution, in that it's almost as fast as SVI and has full support for all models.

Note that the new VNTVarInfo has no notion of typed and untyped VarInfos. They are all as typed as they can be, which should also help simplify code.

I'll keep working on this tomorrow.

github-actions bot (Contributor) commented Dec 18, 2025

Benchmark Report

  • this PR's head: 39df57b586e9b58d7b954bb78d758335ece312d3
  • base branch: a848290ec616aad5caed88c0a6add5a2aee62ad0

Computer Information

Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬────────┬───────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │        │       t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │        │ ─────────┬──────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │ Linked │     base │  this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │   true │   367.64 │   350.82 │    1.05 │  10.34 │   11.07 │    0.93 │   3801.86 │   3882.46 │    0.98 │
│                   LDA │    12 │ reversediff │   true │  2686.68 │  2667.55 │    1.01 │   5.04 │    6.66 │    0.76 │  13544.06 │  17762.32 │    0.76 │
│   Loop univariate 10k │ 10000 │    mooncake │   true │ 58615.00 │ 53481.66 │    1.10 │   5.79 │    5.96 │    0.97 │ 339124.34 │ 318731.74 │    1.06 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │   true │  5888.02 │  5339.46 │    1.10 │   5.73 │    5.97 │    0.96 │  33726.24 │  31901.82 │    1.06 │
│      Multivariate 10k │ 10000 │    mooncake │   true │ 32220.60 │ 31029.52 │    1.04 │  10.27 │    9.89 │    1.04 │ 330958.30 │ 306921.54 │    1.08 │
│       Multivariate 1k │  1000 │    mooncake │   true │  3599.15 │  3601.79 │    1.00 │   9.34 │    8.70 │    1.07 │  33626.47 │  31325.62 │    1.07 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │  false │     2.66 │     2.41 │    1.10 │   3.82 │    3.93 │    0.97 │     10.15 │      9.47 │    1.07 │
│           Smorgasbord │   201 │ forwarddiff │  false │  1089.45 │  1006.01 │    1.08 │ 135.14 │   68.91 │    1.96 │ 147224.74 │  69325.36 │    2.12 │
│           Smorgasbord │   201 │      enzyme │   true │  1519.11 │  1373.42 │    1.11 │   6.66 │    6.20 │    1.07 │  10114.57 │   8517.55 │    1.19 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │   true │  1501.77 │  1378.27 │    1.09 │  61.76 │   67.97 │    0.91 │  92754.63 │  93686.74 │    0.99 │
│           Smorgasbord │   201 │    mooncake │   true │  1528.65 │  1380.21 │    1.11 │   5.68 │    5.94 │    0.96 │   8679.36 │   8193.78 │    1.06 │
│           Smorgasbord │   201 │ reversediff │   true │  1529.23 │  1380.75 │    1.11 │ 100.66 │  103.82 │    0.97 │ 153929.50 │ 143348.40 │    1.07 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│              Submodel │     1 │    mooncake │   true │     3.35 │     3.08 │    1.09 │  11.13 │   10.84 │    1.03 │     37.23 │     33.40 │    1.11 │
└───────────────────────┴───────┴─────────────┴────────┴──────────┴──────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@penelopeysm (Member) commented Dec 19, 2025

Darn, that is really good.

tuple (or actually a tiny struct, but anyway) of three values: Stored value for this variable, whether it's linked, and what transform should be applied to convert the stored value back to "model space"

Am I right in saying the latter two are only really needed for DefaultContext?

Edit: Actually, that's a silly question, if not for DefaultContext we don't even need the metadata field at all.

@penelopeysm (Member)

Also, I'm just eyeing this PR and thinking that it's a prime opportunity to clean up the varinfo interface, especially with the functions that return internal values when they probably shouldn't.

@mhauru (Member, Author) commented Dec 19, 2025

Also, I'm just eyeing this PR and thinking that it's a prime opportunity to clean up the varinfo interface, especially with the functions that return internal values when they probably shouldn't.

Yes. I'm first trying to make this work without making huge interface changes, just to make sure it can do everything that it needs to do, but I think interface changes should follow close behind, maybe in the same PR or the same release. They'll be much easier to make once there are only two VarInfo types that need to respect them, namely the new one and the thread-safe one.

@mhauru mhauru changed the title VarNamedTuple as VarInfo VNT Part 5: VarNamedTuple as VarInfo Dec 19, 2025
@yebai (Member) commented Dec 19, 2025

Looks exciting! Two quick questions: would this be suitable to

  1. implement simulation-based inference algorithms, e.g. particle MCMC, where model dimensionality or parameter support could change?
  2. do a first model run to bootstrap / infer a VarInfo?

@mhauru (Member, Author) commented Dec 19, 2025

  1. Yep. The only thing I foresee being a problem is if some variable turns from, e.g., a Vector into a Matrix and you do IndexLens indexing into it, so that first you have x[1] and then x[1, 1]. That would be a problem. Other than that, it should be fine.
  2. Yes. You can use the same type, VNTVarInfo, for both the first run when collecting variables, and for later runs when evaluating with known variables. No need for the typed/untyped distinction.

One thing I haven't benchmarked, and maybe should, is type-unstable models. There is a possibility that type-unstable models will be slower with the new approach, because VNTVarInfo is pretty aggressive in trying to make element types concrete, and if it keeps trying and failing again and again, that could cost a lot of time. Or it might be a negligible contribution to the performance disaster that is a type-unstable model. Needs benchmarking.
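For reference, a minimal type-unstable model of the kind worth benchmarking might look like this (illustrative only; any Turing-style model whose variable types depend on runtime values would do):

```julia
using DynamicPPL, Distributions, LinearAlgebra

# The type (and dimension) of `x` depends on the sampled value of `flip`,
# so no concrete element type can be inferred for the stored values.
@model function unstable()
    flip ~ Bernoulli(0.5)
    if flip
        x ~ Normal()               # x is a Float64
    else
        x ~ MvNormal(zeros(2), I)  # x is a Vector{Float64}
    end
end
```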

Base automatically changed from mhauru/vnt-for-vaimacc to mhauru/vnt-for-fastldf January 7, 2026 16:52
@yebai (Member) commented Jan 8, 2026

One comment here is to carefully consider the requirements of particle Gibbs during review, so that we have sufficient design provision for

  1. bootstrapping a VarInfo of the right type by running the model once;
  2. upstreaming AdvancedPS to Turing using VarNamedTuple and Accumulators.

cc @sunxd3 @penelopeysm

@penelopeysm (Member)

I think PG should be fine. I think we aren't really removing things so much as shuffling things around and putting them in the right boxes -- so previously where Libtask had to carry a full varinfo, we would now just make it carry a VNT + accumulator tuple.

Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
end

function unflatten!!(vi::VarInfo, vec::AbstractVector)
function unflatten!!(vi::VarInfo{Linked}, vec::AbstractVector) where {Linked}
Member:

Re. comment below this: Did you look into the closure boxing thing, or should I just leave it for another time?

Member Author:

I'm testing it out, but so far it hasn't yielded much. I get the feeling the time cost is somewhere else, but I don't yet understand where.

Comment on lines +496 to +500
new_linked = if LinkedLeft == LinkedRight
LinkedLeft
else
nothing
end
Member:

Pathological, but what if one of them is empty but has the opposite link status? Then we should just take the link status from the non-empty one. Alternatively, only determine the new link status by iterating through the values again after merging.

Member Author:

Fair points. I've added a comment, because I don't think this is high priority to fix right now.

Member:

Yeah, that's fair, it's just a performance hit rather than correctness.

@mhauru (Member, Author) commented Jan 15, 2026

BTW: VNT seems to use @inbounds liberally. I just found out about 30 seconds ago that apparently this is not good: JuliaStats/Distributions.jl#2005

This is concerning. I should try removing them. I would be surprised if Julia was able to optimise these things itself, because a lot of the @inbounds marks rely on mask and data being the same shape, which depends on them having been constructed that way. However, if it hurts type inference, it could be really bad.
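For illustration, one pattern that avoids @inbounds while still letting bounds checks be elided is to iterate eachindex over both arrays, which also verifies up front that the two arrays have matching shapes. This is a sketch, not the actual VNT code:

```julia
# Sketch: summing the entries of `data` selected by `mask` without @inbounds.
# `eachindex(data, mask)` throws if the shapes disagree, after which the
# compiler can elide the bounds checks inside the loop on its own.
function masked_sum(data::AbstractArray, mask::AbstractArray{Bool})
    s = zero(eltype(data))
    for i in eachindex(data, mask)
        if mask[i]
            s += data[i]
        end
    end
    return s
end
```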

mhauru and others added 2 commits January 15, 2026 14:55
Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
@mhauru mhauru requested a review from penelopeysm January 15, 2026 15:51
@mhauru mhauru mentioned this pull request Jan 15, 2026
@mhauru (Member, Author) commented Jan 15, 2026

I think this is good enough to merge. I've listed the residual issues to check and polish here: #1201. @penelopeysm, are you happy?

@sunxd3 (Member) commented Jan 16, 2026

🎉

@penelopeysm penelopeysm mentioned this pull request Jan 23, 2026
penelopeysm added a commit that referenced this pull request Jan 28, 2026
This PR implements something like what I described here:
#836 (comment)

# TODO

- [x] Code
- [x] Passes existing tests
- [x] Docs
- [x] New tests

# Description

This PR:

- modifies the signature of `DynamicPPL.init`, to return an
`AbstractTransformedValue`;

- passes the `AbstractTransformedValue` through to the accumulators;

- uses an accumulator, `TransformedValueAccumulator`, to potentially
modify and store a new `AbstractTransformedValue`.

The general concept is explained in these new docs pages:

https://turinglang.org/DynamicPPL.jl/previews/PR1212/flow/
https://turinglang.org/DynamicPPL.jl/previews/PR1212/values/

About a third of the changes are docs. There are a lot of source code
changes, but most of them are just adapting existing functions to use
the new interface. The interesting code changes are, in recommended
order of reading:

- `src/transformed_values.jl` the definition of
`AbstractTransformedValue`
- `src/accs/vnt.jl`, a generic accumulator that stores things in a VNT.
Right now we only use it for one thing; I thought I'd leave the
generalisation for another PR.
- `src/accs/transformed_values.jl`, the accumulator that specifically
stores transformed values in a VNT
- `src/varinfo.jl`, the new implementation of linking that uses said
accumulator

# Benchmarks

- `main` is main (of course)
- `py/orig-vnt` is where Markus got to (#1183)
- `breaking` is the same as `py/orig-vnt`, plus shadow arrays plus some perf improvements (#1204)
- `py/link` is this PR

```
Smorgasbord
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>      4.694 ms     10.979 µs      8.014 µs      7.931 µs
                                keys  =>    411.067 ns      5.433 µs    579.875 ns    576.375 ns
                              subset  =>    496.209 µs    191.666 µs     66.167 µs     70.959 µs
                               merge  =>     15.584 µs    661.765 ns    260.593 ns    263.435 ns
            evaluate!! InitFromPrior  =>     12.708 µs      3.208 µs      2.590 µs      2.528 µs
           evaluate!! DefaultContext  =>     11.094 µs      1.037 µs    987.231 ns      1.250 µs
                         unflatten!!  =>    180.970 ns     12.125 µs    871.048 ns    922.794 ns
                              link!!  =>    151.937 µs    206.041 µs     52.459 µs      5.608 µs
    evaluate!! InitFromPrior, linked  =>     29.375 µs      6.510 µs      4.583 µs      4.583 µs
   evaluate!! DefaultContext, linked  =>      9.208 µs      2.550 µs      2.450 µs      2.392 µs
                 unflatten!!, linked  =>    188.691 ns     12.146 µs    820.850 ns    847.238 ns

Loop univariate 1k
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>    793.780 ms    172.708 µs    102.709 µs    102.666 µs
                                keys  =>    649.529 ns     55.000 µs      7.511 µs      7.104 µs
                              subset  =>    234.584 µs     17.421 ms      6.318 ms      6.079 ms
                               merge  =>    311.084 µs      4.198 µs      3.028 µs      2.777 µs
            evaluate!! InitFromPrior  =>     58.709 µs     22.500 µs     22.833 µs     21.416 µs
           evaluate!! DefaultContext  =>     58.083 µs      7.208 µs      6.688 µs      6.740 µs
                         unflatten!!  =>    821.429 ns    105.001 µs      7.542 µs      6.146 µs
                              link!!  =>    229.250 µs      1.947 ms    409.459 µs     25.875 µs
    evaluate!! InitFromPrior, linked  =>    246.459 µs     22.125 µs     21.875 µs     21.417 µs
   evaluate!! DefaultContext, linked  =>     70.458 µs     11.459 µs     11.416 µs     11.416 µs
                 unflatten!!, linked  =>    887.235 ns    102.375 µs      5.792 µs      6.146 µs

Multivariate 1k
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>     42.667 µs     23.291 µs     23.292 µs     22.896 µs
                                keys  =>     31.569 ns     41.479 ns     22.328 ns     22.065 ns
                              subset  =>      2.153 µs    407.143 ns    366.663 ns    366.137 ns
                               merge  =>      1.976 µs      2.174 ns      2.173 ns      2.169 ns
            evaluate!! InitFromPrior  =>     13.792 µs     12.375 µs     11.875 µs     11.583 µs
           evaluate!! DefaultContext  =>      8.334 µs      6.708 µs      6.073 µs      5.792 µs
                         unflatten!!  =>    827.381 ns      2.482 ns      2.480 ns      2.483 ns
                              link!!  =>     48.542 µs     10.229 µs      7.667 µs      6.042 µs
    evaluate!! InitFromPrior, linked  =>     14.250 µs     12.083 µs     12.791 µs     13.375 µs
   evaluate!! DefaultContext, linked  =>      7.979 µs      5.959 µs      6.625 µs      7.083 µs
                 unflatten!!, linked  =>    791.679 ns      2.480 ns      2.484 ns      2.480 ns

Dynamic
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>     35.292 µs      2.838 µs      3.051 µs      2.829 µs
                                keys  =>     48.659 ns     44.793 ns     43.266 ns     25.074 ns
                              subset  =>     11.709 µs      2.689 µs      2.725 µs      2.511 µs
                               merge  =>      1.566 µs      3.844 ns      4.012 ns      3.886 ns
            evaluate!! InitFromPrior  =>      3.375 µs      1.853 µs      1.964 µs      1.889 µs
           evaluate!! DefaultContext  =>      1.234 µs    683.325 ns    663.550 ns    732.300 ns
                         unflatten!!  =>    114.380 ns      5.448 ns      5.437 ns      5.425 ns
                              link!!  =>    149.000 µs     12.542 µs      5.142 µs      2.403 µs
    evaluate!! InitFromPrior, linked  =>      6.448 µs      4.066 µs      4.208 µs      3.589 µs
   evaluate!! DefaultContext, linked  =>      2.608 µs      1.815 µs      1.852 µs      1.529 µs
                 unflatten!!, linked  =>    116.129 ns      5.085 ns      4.978 ns      4.960 ns

Parent
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>     12.021 µs    372.831 ns    375.000 ns    304.557 ns
                                keys  =>     31.518 ns     46.558 ns     39.942 ns     22.237 ns
                              subset  =>    715.625 ns     35.526 ns     31.970 ns     15.133 ns
                               merge  =>    484.717 ns      2.174 ns      2.171 ns      2.184 ns
            evaluate!! InitFromPrior  =>     95.930 ns     29.826 ns     29.819 ns     14.450 ns
           evaluate!! DefaultContext  =>    100.410 ns      3.132 ns      3.134 ns      3.134 ns
                         unflatten!!  =>     41.185 ns      2.486 ns      2.484 ns      2.482 ns
                              link!!  =>     49.208 µs      3.720 µs      1.059 µs     18.761 ns
    evaluate!! InitFromPrior, linked  =>    297.680 ns     32.290 ns     31.272 ns     16.954 ns
   evaluate!! DefaultContext, linked  =>    119.073 ns     10.701 ns     10.806 ns     11.065 ns
                 unflatten!!, linked  =>     41.075 ns      2.481 ns      2.483 ns      2.554 ns

LDA
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>    122.959 µs     18.709 µs     19.333 µs     13.792 µs
                                keys  =>    117.287 ns    736.842 ns    128.553 ns     99.382 ns
                              subset  =>     57.083 µs      3.167 µs      2.887 µs      1.897 µs
                               merge  =>      2.504 µs    186.968 ns    159.192 ns    115.927 ns
            evaluate!! InitFromPrior  =>      9.403 µs      7.570 µs      7.313 µs      4.354 µs
           evaluate!! DefaultContext  =>      8.292 µs      6.646 µs      6.334 µs      3.307 µs
                         unflatten!!  =>    127.765 ns      4.681 µs    462.968 ns    324.728 ns
                              link!!  =>    150.562 µs     37.750 µs     23.084 µs      6.656 µs
    evaluate!! InitFromPrior, linked  =>     11.375 µs      7.833 µs      7.861 µs      5.100 µs
   evaluate!! DefaultContext, linked  =>      7.986 µs      7.011 µs      6.667 µs      4.190 µs
                 unflatten!!, linked  =>    123.391 ns      4.764 µs    461.651 ns    461.594 ns
```

# Generating a linked VarInfo directly

This PR also adds new methods for `VarInfo(rng, model, link, init)`, for
example, `VarInfo(model, LinkAll())` will immediately generate a linked
VarInfo.

I thought that this would definitely be faster than the roundabout
method, but it seems to depend on the model in question, and I'm not
entirely sure why.

```julia
using DynamicPPL, Distributions, Chairmarks

direct(m) = VarInfo(m, DynamicPPL.LinkAll())
indirect(m) = link!!(VarInfo(m), m)

@model f() = x ~ Beta(2, 2); m = f()
@b direct($m)    # 431.222 ns (15 allocs: 528 bytes)
@b indirect($m)  # 158.775 ns (7 allocs: 208 bytes)

@model function f2()
    x = Vector{Float64}(undef, 100)
    x .~ Beta(2, 2)
end
m2 = f2()
@b direct($m2)   # 11.146 μs (419 allocs: 17.531 KiB)
@b indirect($m2) # 14.125 μs (815 allocs: 33.922 KiB)
```

---------

Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>