Skip to content

Accumulators for linked VarInfo#1212

Merged
penelopeysm merged 5 commits intobreakingfrom
py/link
Jan 28, 2026
Merged

Accumulators for linked VarInfo#1212
penelopeysm merged 5 commits intobreakingfrom
py/link

Conversation

@penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Jan 22, 2026

This PR implements something like what I described here: #836 (comment)

TODO

  • Code
  • Passes existing tests
  • Docs
  • New tests

Description

This PR:

  • modifies the signature of DynamicPPL.init, to return an AbstractTransformedValue;

  • passes the AbstractTransformedValue through to the accumulators;

  • uses an accumulator, TransformedValueAccumulator, to potentially modify and store a new AbstractTransformedValue.

The general concept is explained in these new docs pages:

https://turinglang.org/DynamicPPL.jl/previews/PR1212/flow/
https://turinglang.org/DynamicPPL.jl/previews/PR1212/values/

About a third of the changes are docs. There are a lot of source code changes, but most of them are just adapting existing functions to use the new interface. The interesting code changes are, in recommended order of reading:

  • src/transformed_values.jl the definition of AbstractTransformedValue
  • src/accs/vnt.jl, a generic accumulator that stores things in a VNT. Right now we only use it for one thing; I thought I'd leave the generalisation for another PR.
  • src/accs/transformed_values.jl, the accumulator that specifically stores transformed values in a VNT
  • src/varinfo.jl, the new implementation of linking that uses said accumulator

Benchmarks

main is main (of course)
py/orig-vnt is where Markus got to (#1183)
breaking is the same as py/orig-vnt, plus shadow arrays plus some perf improvements (#1204)
py/link is this PR

Smorgasbord
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>      4.694 ms     10.979 µs      8.014 µs      7.931 µs
                                keys  =>    411.067 ns      5.433 µs    579.875 ns    576.375 ns
                              subset  =>    496.209 µs    191.666 µs     66.167 µs     70.959 µs
                               merge  =>     15.584 µs    661.765 ns    260.593 ns    263.435 ns
            evaluate!! InitFromPrior  =>     12.708 µs      3.208 µs      2.590 µs      2.528 µs
           evaluate!! DefaultContext  =>     11.094 µs      1.037 µs    987.231 ns      1.250 µs
                         unflatten!!  =>    180.970 ns     12.125 µs    871.048 ns    922.794 ns
                              link!!  =>    151.937 µs    206.041 µs     52.459 µs      5.608 µs
    evaluate!! InitFromPrior, linked  =>     29.375 µs      6.510 µs      4.583 µs      4.583 µs
   evaluate!! DefaultContext, linked  =>      9.208 µs      2.550 µs      2.450 µs      2.392 µs
                 unflatten!!, linked  =>    188.691 ns     12.146 µs    820.850 ns    847.238 ns

Loop univariate 1k
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>    793.780 ms    172.708 µs    102.709 µs    102.666 µs
                                keys  =>    649.529 ns     55.000 µs      7.511 µs      7.104 µs
                              subset  =>    234.584 µs     17.421 ms      6.318 ms      6.079 ms
                               merge  =>    311.084 µs      4.198 µs      3.028 µs      2.777 µs
            evaluate!! InitFromPrior  =>     58.709 µs     22.500 µs     22.833 µs     21.416 µs
           evaluate!! DefaultContext  =>     58.083 µs      7.208 µs      6.688 µs      6.740 µs
                         unflatten!!  =>    821.429 ns    105.001 µs      7.542 µs      6.146 µs
                              link!!  =>    229.250 µs      1.947 ms    409.459 µs     25.875 µs
    evaluate!! InitFromPrior, linked  =>    246.459 µs     22.125 µs     21.875 µs     21.417 µs
   evaluate!! DefaultContext, linked  =>     70.458 µs     11.459 µs     11.416 µs     11.416 µs
                 unflatten!!, linked  =>    887.235 ns    102.375 µs      5.792 µs      6.146 µs

Multivariate 1k
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>     42.667 µs     23.291 µs     23.292 µs     22.896 µs
                                keys  =>     31.569 ns     41.479 ns     22.328 ns     22.065 ns
                              subset  =>      2.153 µs    407.143 ns    366.663 ns    366.137 ns
                               merge  =>      1.976 µs      2.174 ns      2.173 ns      2.169 ns
            evaluate!! InitFromPrior  =>     13.792 µs     12.375 µs     11.875 µs     11.583 µs
           evaluate!! DefaultContext  =>      8.334 µs      6.708 µs      6.073 µs      5.792 µs
                         unflatten!!  =>    827.381 ns      2.482 ns      2.480 ns      2.483 ns
                              link!!  =>     48.542 µs     10.229 µs      7.667 µs      6.042 µs
    evaluate!! InitFromPrior, linked  =>     14.250 µs     12.083 µs     12.791 µs     13.375 µs
   evaluate!! DefaultContext, linked  =>      7.979 µs      5.959 µs      6.625 µs      7.083 µs
                 unflatten!!, linked  =>    791.679 ns      2.480 ns      2.484 ns      2.480 ns

Dynamic
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>     35.292 µs      2.838 µs      3.051 µs      2.829 µs
                                keys  =>     48.659 ns     44.793 ns     43.266 ns     25.074 ns
                              subset  =>     11.709 µs      2.689 µs      2.725 µs      2.511 µs
                               merge  =>      1.566 µs      3.844 ns      4.012 ns      3.886 ns
            evaluate!! InitFromPrior  =>      3.375 µs      1.853 µs      1.964 µs      1.889 µs
           evaluate!! DefaultContext  =>      1.234 µs    683.325 ns    663.550 ns    732.300 ns
                         unflatten!!  =>    114.380 ns      5.448 ns      5.437 ns      5.425 ns
                              link!!  =>    149.000 µs     12.542 µs      5.142 µs      2.403 µs
    evaluate!! InitFromPrior, linked  =>      6.448 µs      4.066 µs      4.208 µs      3.589 µs
   evaluate!! DefaultContext, linked  =>      2.608 µs      1.815 µs      1.852 µs      1.529 µs
                 unflatten!!, linked  =>    116.129 ns      5.085 ns      4.978 ns      4.960 ns

Parent
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>     12.021 µs    372.831 ns    375.000 ns    304.557 ns
                                keys  =>     31.518 ns     46.558 ns     39.942 ns     22.237 ns
                              subset  =>    715.625 ns     35.526 ns     31.970 ns     15.133 ns
                               merge  =>    484.717 ns      2.174 ns      2.171 ns      2.184 ns
            evaluate!! InitFromPrior  =>     95.930 ns     29.826 ns     29.819 ns     14.450 ns
           evaluate!! DefaultContext  =>    100.410 ns      3.132 ns      3.134 ns      3.134 ns
                         unflatten!!  =>     41.185 ns      2.486 ns      2.484 ns      2.482 ns
                              link!!  =>     49.208 µs      3.720 µs      1.059 µs     18.761 ns
    evaluate!! InitFromPrior, linked  =>    297.680 ns     32.290 ns     31.272 ns     16.954 ns
   evaluate!! DefaultContext, linked  =>    119.073 ns     10.701 ns     10.806 ns     11.065 ns
                 unflatten!!, linked  =>     41.075 ns      2.481 ns      2.483 ns      2.554 ns

LDA
                                                  main   py/orig-vnt      breaking       py/link
                         Constructor  =>    122.959 µs     18.709 µs     19.333 µs     13.792 µs
                                keys  =>    117.287 ns    736.842 ns    128.553 ns     99.382 ns
                              subset  =>     57.083 µs      3.167 µs      2.887 µs      1.897 µs
                               merge  =>      2.504 µs    186.968 ns    159.192 ns    115.927 ns
            evaluate!! InitFromPrior  =>      9.403 µs      7.570 µs      7.313 µs      4.354 µs
           evaluate!! DefaultContext  =>      8.292 µs      6.646 µs      6.334 µs      3.307 µs
                         unflatten!!  =>    127.765 ns      4.681 µs    462.968 ns    324.728 ns
                              link!!  =>    150.562 µs     37.750 µs     23.084 µs      6.656 µs
    evaluate!! InitFromPrior, linked  =>     11.375 µs      7.833 µs      7.861 µs      5.100 µs
   evaluate!! DefaultContext, linked  =>      7.986 µs      7.011 µs      6.667 µs      4.190 µs
                 unflatten!!, linked  =>    123.391 ns      4.764 µs    461.651 ns    461.594 ns

Generating a linked VarInfo directly

This PR also adds new methods for VarInfo(rng, model, link, init), for example, VarInfo(model, LinkAll()) will immediately generate a linked VarInfo.

I thought that this would definitely be faster than the roundabout method, but it seems to depend on the model in question, and I'm not entirely sure why.

using DynamicPPL, Distributions, Chairmarks

direct(m) = VarInfo(m, DynamicPPL.LinkAll())
indirect(m) = link!!(VarInfo(m), m)

@model f() = x ~ Beta(2, 2); m = f()
@b direct($m)    # 431.222 ns (15 allocs: 528 bytes)
@b indirect($m)  # 158.775 ns (7 allocs: 208 bytes)

@model function f2()
    x = Vector{Float64}(undef, 100)
    x .~ Beta(2, 2)
end
m2 = f2()
@b direct($m2)   # 11.146 μs (419 allocs: 17.531 KiB)
@b indirect($m2) # 14.125 μs (815 allocs: 33.922 KiB)

@github-actions
Copy link
Contributor

github-actions bot commented Jan 22, 2026

Benchmark Report

  • this PR's head: 6cce479b6e762b6258a9f079e39df3fcaf0c8406
  • base branch: b03d531939fb35c9089b94b8732c7eceb3fc4038

Computer Information

Julia Version 1.11.8
Commit cf1da5e20e3 (2025-11-06 17:49 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬────────┬───────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │        │       t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │        │ ─────────┬──────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │ Linked │     base │  this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │   true │   356.72 │   398.39 │    0.90 │  12.77 │   11.93 │    1.07 │   4553.88 │   4752.95 │    0.96 │
│                   LDA │    12 │ reversediff │   true │  2465.55 │  2731.61 │    0.90 │   5.09 │    5.08 │    1.00 │  12537.92 │  13877.75 │    0.90 │
│   Loop univariate 10k │ 10000 │    mooncake │   true │ 61230.62 │ 65585.40 │    0.93 │   5.23 │    5.42 │    0.96 │ 320254.11 │ 355486.65 │    0.90 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │   true │  6116.12 │  7384.87 │    0.83 │   5.18 │    4.77 │    1.08 │  31651.54 │  35233.67 │    0.90 │
│      Multivariate 10k │ 10000 │    mooncake │   true │ 37280.52 │ 33517.58 │    1.11 │  11.47 │    9.87 │    1.16 │ 427500.91 │ 330942.86 │    1.29 │
│       Multivariate 1k │  1000 │    mooncake │   true │  3264.70 │  3558.79 │    0.92 │   9.49 │    9.44 │    1.00 │  30973.70 │  33598.25 │    0.92 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │  false │     2.40 │     2.61 │    0.92 │   3.92 │    3.91 │    1.00 │      9.43 │     10.21 │    0.92 │
│           Smorgasbord │   201 │ forwarddiff │  false │  1009.23 │  1485.70 │    0.68 │  67.54 │   51.84 │    1.30 │  68161.76 │  77024.15 │    0.88 │
│           Smorgasbord │   201 │      enzyme │   true │  1411.21 │  1530.53 │    0.92 │   6.50 │    5.86 │    1.11 │   9171.37 │   8975.21 │    1.02 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │   true │  1400.88 │  1572.19 │    0.89 │  63.48 │   60.35 │    1.05 │  88927.79 │  94879.06 │    0.94 │
│           Smorgasbord │   201 │    mooncake │   true │  1405.25 │  1532.64 │    0.92 │   5.69 │    5.59 │    1.02 │   7992.12 │   8566.55 │    0.93 │
│           Smorgasbord │   201 │ reversediff │   true │  1391.79 │  1565.37 │    0.89 │ 101.22 │   98.71 │    1.03 │ 140878.70 │ 154511.82 │    0.91 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│              Submodel │     1 │    mooncake │   true │     3.03 │     3.28 │    0.92 │ 117.94 │  117.71 │    1.00 │    357.34 │    385.94 │    0.93 │
└───────────────────────┴───────┴─────────────┴────────┴──────────┴──────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@github-actions
Copy link
Contributor

DynamicPPL.jl documentation for PR #1212 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1212/

@codecov
Copy link

codecov bot commented Jan 22, 2026

Codecov Report

❌ Patch coverage is 85.09615% with 31 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.22%. Comparing base (6a5675b) to head (6cce479).
⚠️ Report is 15 commits behind head on breaking.

Files with missing lines Patch % Lines
src/transformed_values.jl 60.00% 8 Missing ⚠️
src/accs/vnt.jl 63.15% 7 Missing ⚠️
src/contexts/init.jl 75.00% 7 Missing ⚠️
src/varinfo.jl 90.00% 7 Missing ⚠️
src/accs/transformed_values.jl 95.12% 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1212      +/-   ##
============================================
- Coverage     78.36%   77.22%   -1.15%     
============================================
  Files            41       45       +4     
  Lines          3296     3574     +278     
============================================
+ Hits           2583     2760     +177     
- Misses          713      814     +101     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@penelopeysm penelopeysm force-pushed the py/link branch 2 times, most recently from faea67a to 144f4d5 Compare January 27, 2026 13:58
Comment on lines +43 to +45
if acc1.f != acc2.f
throw(ArgumentError("Cannot combine VNTAccumulators with different functions"))
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really know if this is too strict. I'm worried about things like anonymous functions, which don't compare equal to one another. But at the same time if we don't check, we might silently get unexpected results! So I've just gone with a general principle of, it doesn't break CI, let's be strict, and if we find we need to not be strict we can change it later.

Comment on lines +38 to +39
# Note that we can't rely on the stored transform being correct (e.g. if new values
# were placed in `vi` via `unflatten!!`, so we regenerate the transforms.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the same as the behaviour of from_maybe_linked_internal_transform.

Comment on lines +368 to +373
x, init_logjac = with_logabsdet_jacobian(get_transform(tval), get_internal_value(tval))
# TODO(penelopeysm): This could be inefficient if `tval` is already linked and
# `setindex_with_dist!!` tells it to create a new linked value again. In particular,
# this is inefficient if we use `InitFromParams` that provides linked values. The answer
# to this is to stop using setindex_with_dist!! and just use the TransformedValue
# accumulator.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This scenario actually can't happen right now (there's nothing that triggers this code path). I don't want to fix it in this PR though because it'd get too big.

# easier to type infer, but if `transform` no longer exists, it might start to cause
# unnecessary type inconcreteness in the elements of PartialArray.
"""
TransformedValue{Linked,ValType,TransformType,SizeType}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Essentially, TransformedValue{false} is the same as VectorValue now, and TransformedValue{true} is the same as LinkedVectorValue. I generalised this so that we could accommodate the third subtype UntransformedValue as well.

@penelopeysm penelopeysm marked this pull request as ready for review January 27, 2026 15:19
@penelopeysm penelopeysm requested a review from sunxd3 January 27, 2026 15:19
@penelopeysm
Copy link
Member Author

penelopeysm commented Jan 27, 2026

I think this one should be easier to review :) I wrote a bit in the toplevel comment about what this PR does & where the code changes are. Again no rush, I have plenty of other things to work on 😅

Copy link
Member

@sunxd3 sunxd3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Penny!

LGTM in general, couple of nits

get_transform(tv::$T) = tv.transform
get_internal_value(tv::$T) = tv.val

function update_value(tv::$T, new_val)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need update_value for UntransformedValue?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's probably a good idea, it's not used anywhere, but just for the sake of having a consistent API.

Also it should probably be update_internal_value or maybe even set_internal_value

@penelopeysm penelopeysm requested a review from sunxd3 January 28, 2026 12:59
@penelopeysm
Copy link
Member Author

Thanks @sunxd3!

@penelopeysm penelopeysm merged commit 27a0b3b into breaking Jan 28, 2026
20 of 21 checks passed
@penelopeysm penelopeysm deleted the py/link branch January 28, 2026 14:18
penelopeysm referenced this pull request Feb 24, 2026
Release 0.40

---------

Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants