VNT Part 5: VarNamedTuple as VarInfo#1183
Conversation
Benchmark Report
Computer Information / Benchmark Results (collapsed details)
Darn, that is really good.
Am I right in saying the latter two are only really needed for DefaultContext? Edit: Actually, that's a silly question: if not for DefaultContext, we wouldn't even need the metadata field at all.
Also, I'm just eyeing this PR and thinking that it's a prime opportunity to clean up the varinfo interface, especially with the functions that return internal values when they probably shouldn't.
Yes. I'm first trying to make this work without making huge interface changes, just to make sure this can do everything that is needed, but I think interface changes should follow close behind, maybe in the same PR or the same release. They'll be much easier to make once there are only two VarInfo types that need to respect them, namely the new one and Threadsafe.
Looks exciting! Two quick questions: would this be suitable to
One thing I haven't benchmarked, and maybe should, is type-unstable models. There is a possibility that type-unstable models will be slower with the new approach, because
One comment here is to carefully consider the requirements of particle Gibbs during review, so that we have sufficient design provision for
I think PG should be fine. I think we aren't really removing things so much as shuffling things around and putting them in the right boxes -- so previously where Libtask had to carry a full varinfo, we would now just make it carry a VNT + accumulator tuple.
…uru/vnt-for-varinfo
Co-authored-by: Penelope Yong <penelopeysm@gmail.com>
    -function unflatten!!(vi::VarInfo, vec::AbstractVector)
    +function unflatten!!(vi::VarInfo{Linked}, vec::AbstractVector) where {Linked}
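For context, a minimal illustration of why lifting the link status into a type parameter (as in the `VarInfo{Linked}` signature above) can help. The types and functions here are toys I made up for illustration, not DynamicPPL's actual ones: dispatching on the parameter lets the compiler specialise each case separately, instead of branching on a runtime `Bool` field.

```julia
# Toy illustration (not DynamicPPL's actual types): encoding a flag as a
# type parameter so that dispatch replaces a runtime branch.
struct ToyVarInfo{Linked}
    vals::Vector{Float64}
end

# Linked case: pretend the stored vector is in log space and must be
# inverse-linked on unflatten.
toy_unflatten(::ToyVarInfo{true}, vec::AbstractVector) = ToyVarInfo{true}(exp.(vec))
# Unlinked case: just take the values as-is.
toy_unflatten(::ToyVarInfo{false}, vec::AbstractVector) = ToyVarInfo{false}(collect(vec))

vi = ToyVarInfo{false}([1.0])
toy_unflatten(vi, [3.0, 4.0]).vals  # [3.0, 4.0]
```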
Re. comment below this: Did you look into the closure boxing thing, or should I just leave it for another time?
I'm testing it out, but so far it hasn't yielded much. I get the feeling the time cost is somewhere else, but I don't yet understand where.
    new_linked = if LinkedLeft == LinkedRight
        LinkedLeft
    else
        nothing
    end
Pathological, but what if one of them is empty but has the opposite link status? Then we should just take the link status from the non-empty one. Alternatively, only determine the new link status by iterating through the values again after merging.
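A sketch of the suggestion above, as a hypothetical helper (not code from this PR): let an empty side defer to the other's link status, and fall back to `nothing` only when two non-empty sides genuinely disagree.

```julia
# Hypothetical helper (not from this PR) sketching the suggested merge rule.
# `nothing` marks an undetermined link status.
function merged_link_status(left_linked, left_empty::Bool, right_linked, right_empty::Bool)
    left_empty && return right_linked   # an empty side should not veto the other
    right_empty && return left_linked
    return left_linked == right_linked ? left_linked : nothing
end

merged_link_status(true, true, false, false)   # false: left is empty, defer to right
merged_link_status(true, false, false, false)  # nothing: genuinely mixed statuses
```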
Fair points. I've added a comment, because I don't think this is high priority to fix right now.
Yeah, that's fair; it's just a performance hit rather than a correctness issue.
This is concerning. I should try removing them. I would be surprised if Julia were able to optimise these things itself, because a lot of the
I think this is good enough to merge. I've listed the residual issues to check and polish here: #1201. @penelopeysm, are you happy?
🎉
This PR implements something like what I described here: #836 (comment)

# TODO

- [x] Code
- [x] Passes existing tests
- [x] Docs
- [x] New tests

# Description

This PR:

- modifies the signature of `DynamicPPL.init`, to return an `AbstractTransformedValue`;
- passes the `AbstractTransformedValue` through to the accumulators;
- uses an accumulator, `TransformedValueAccumulator`, to potentially modify and store a new `AbstractTransformedValue`.

The general concept is explained in these new docs pages:

https://turinglang.org/DynamicPPL.jl/previews/PR1212/flow/
https://turinglang.org/DynamicPPL.jl/previews/PR1212/values/

About a third of the changes are docs. There are a lot of source code changes, but most of them are just adapting existing functions to use the new interface. The interesting code changes are, in recommended order of reading:

- `src/transformed_values.jl`, the definition of `AbstractTransformedValue`
- `src/accs/vnt.jl`, a generic accumulator that stores things in a VNT. Right now we only use it for one thing; I thought I'd leave the generalisation for another PR.
- `src/accs/transformed_values.jl`, the accumulator that specifically stores transformed values in a VNT
- `src/varinfo.jl`, the new implementation of linking that uses said accumulator

# Benchmarks

main is main (of course)
py/orig-vnt is where Markus got to (#1183)
breaking is the same as py/orig-vnt, plus shadow arrays plus some perf improvements (#1204)
py/link is this PR

```
Smorgasbord                          main        py/orig-vnt  breaking    py/link
Constructor                       => 4.694 ms    10.979 µs    8.014 µs    7.931 µs
keys                              => 411.067 ns  5.433 µs     579.875 ns  576.375 ns
subset                            => 496.209 µs  191.666 µs   66.167 µs   70.959 µs
merge                             => 15.584 µs   661.765 ns   260.593 ns  263.435 ns
evaluate!! InitFromPrior          => 12.708 µs   3.208 µs     2.590 µs    2.528 µs
evaluate!! DefaultContext         => 11.094 µs   1.037 µs     987.231 ns  1.250 µs
unflatten!!                       => 180.970 ns  12.125 µs    871.048 ns  922.794 ns
link!!                            => 151.937 µs  206.041 µs   52.459 µs   5.608 µs
evaluate!! InitFromPrior, linked  => 29.375 µs   6.510 µs     4.583 µs    4.583 µs
evaluate!! DefaultContext, linked => 9.208 µs    2.550 µs     2.450 µs    2.392 µs
unflatten!!, linked               => 188.691 ns  12.146 µs    820.850 ns  847.238 ns

Loop univariate 1k                   main        py/orig-vnt  breaking    py/link
Constructor                       => 793.780 ms  172.708 µs   102.709 µs  102.666 µs
keys                              => 649.529 ns  55.000 µs    7.511 µs    7.104 µs
subset                            => 234.584 µs  17.421 ms    6.318 ms    6.079 ms
merge                             => 311.084 µs  4.198 µs     3.028 µs    2.777 µs
evaluate!! InitFromPrior          => 58.709 µs   22.500 µs    22.833 µs   21.416 µs
evaluate!! DefaultContext         => 58.083 µs   7.208 µs     6.688 µs    6.740 µs
unflatten!!                       => 821.429 ns  105.001 µs   7.542 µs    6.146 µs
link!!                            => 229.250 µs  1.947 ms     409.459 µs  25.875 µs
evaluate!! InitFromPrior, linked  => 246.459 µs  22.125 µs    21.875 µs   21.417 µs
evaluate!! DefaultContext, linked => 70.458 µs   11.459 µs    11.416 µs   11.416 µs
unflatten!!, linked               => 887.235 ns  102.375 µs   5.792 µs    6.146 µs

Multivariate 1k                      main        py/orig-vnt  breaking    py/link
Constructor                       => 42.667 µs   23.291 µs    23.292 µs   22.896 µs
keys                              => 31.569 ns   41.479 ns    22.328 ns   22.065 ns
subset                            => 2.153 µs    407.143 ns   366.663 ns  366.137 ns
merge                             => 1.976 µs    2.174 ns     2.173 ns    2.169 ns
evaluate!! InitFromPrior          => 13.792 µs   12.375 µs    11.875 µs   11.583 µs
evaluate!! DefaultContext         => 8.334 µs    6.708 µs     6.073 µs    5.792 µs
unflatten!!                       => 827.381 ns  2.482 ns     2.480 ns    2.483 ns
link!!                            => 48.542 µs   10.229 µs    7.667 µs    6.042 µs
evaluate!! InitFromPrior, linked  => 14.250 µs   12.083 µs    12.791 µs   13.375 µs
evaluate!! DefaultContext, linked => 7.979 µs    5.959 µs     6.625 µs    7.083 µs
unflatten!!, linked               => 791.679 ns  2.480 ns     2.484 ns    2.480 ns

Dynamic                              main        py/orig-vnt  breaking    py/link
Constructor                       => 35.292 µs   2.838 µs     3.051 µs    2.829 µs
keys                              => 48.659 ns   44.793 ns    43.266 ns   25.074 ns
subset                            => 11.709 µs   2.689 µs     2.725 µs    2.511 µs
merge                             => 1.566 µs    3.844 ns     4.012 ns    3.886 ns
evaluate!! InitFromPrior          => 3.375 µs    1.853 µs     1.964 µs    1.889 µs
evaluate!! DefaultContext         => 1.234 µs    683.325 ns   663.550 ns  732.300 ns
unflatten!!                       => 114.380 ns  5.448 ns     5.437 ns    5.425 ns
link!!                            => 149.000 µs  12.542 µs    5.142 µs    2.403 µs
evaluate!! InitFromPrior, linked  => 6.448 µs    4.066 µs     4.208 µs    3.589 µs
evaluate!! DefaultContext, linked => 2.608 µs    1.815 µs     1.852 µs    1.529 µs
unflatten!!, linked               => 116.129 ns  5.085 ns     4.978 ns    4.960 ns

Parent                               main        py/orig-vnt  breaking    py/link
Constructor                       => 12.021 µs   372.831 ns   375.000 ns  304.557 ns
keys                              => 31.518 ns   46.558 ns    39.942 ns   22.237 ns
subset                            => 715.625 ns  35.526 ns    31.970 ns   15.133 ns
merge                             => 484.717 ns  2.174 ns     2.171 ns    2.184 ns
evaluate!! InitFromPrior          => 95.930 ns   29.826 ns    29.819 ns   14.450 ns
evaluate!! DefaultContext         => 100.410 ns  3.132 ns     3.134 ns    3.134 ns
unflatten!!                       => 41.185 ns   2.486 ns     2.484 ns    2.482 ns
link!!                            => 49.208 µs   3.720 µs     1.059 µs    18.761 ns
evaluate!! InitFromPrior, linked  => 297.680 ns  32.290 ns    31.272 ns   16.954 ns
evaluate!! DefaultContext, linked => 119.073 ns  10.701 ns    10.806 ns   11.065 ns
unflatten!!, linked               => 41.075 ns   2.481 ns     2.483 ns    2.554 ns

LDA                                  main        py/orig-vnt  breaking    py/link
Constructor                       => 122.959 µs  18.709 µs    19.333 µs   13.792 µs
keys                              => 117.287 ns  736.842 ns   128.553 ns  99.382 ns
subset                            => 57.083 µs   3.167 µs     2.887 µs    1.897 µs
merge                             => 2.504 µs    186.968 ns   159.192 ns  115.927 ns
evaluate!! InitFromPrior          => 9.403 µs    7.570 µs     7.313 µs    4.354 µs
evaluate!! DefaultContext         => 8.292 µs    6.646 µs     6.334 µs    3.307 µs
unflatten!!                       => 127.765 ns  4.681 µs     462.968 ns  324.728 ns
link!!                            => 150.562 µs  37.750 µs    23.084 µs   6.656 µs
evaluate!! InitFromPrior, linked  => 11.375 µs   7.833 µs     7.861 µs    5.100 µs
evaluate!! DefaultContext, linked => 7.986 µs    7.011 µs     6.667 µs    4.190 µs
unflatten!!, linked               => 123.391 ns  4.764 µs     461.651 ns  461.594 ns
```

# Generating a linked VarInfo directly

This PR also adds new methods for `VarInfo(rng, model, link, init)`; for example, `VarInfo(model, LinkAll())` will immediately generate a linked VarInfo. I thought that this would definitely be faster than the roundabout method, but it seems to depend on the model in question, and I'm not entirely sure why.

```julia
using DynamicPPL, Distributions, Chairmarks

direct(m) = VarInfo(m, DynamicPPL.LinkAll())
indirect(m) = link!!(VarInfo(m), m)

@model f() = x ~ Beta(2, 2)
m = f()
@b direct($m)    # 431.222 ns (15 allocs: 528 bytes)
@b indirect($m)  # 158.775 ns (7 allocs: 208 bytes)

@model function f2()
    x = Vector{Float64}(undef, 100)
    x .~ Beta(2, 2)
end
m2 = f2()
@b direct($m2)   # 11.146 μs (419 allocs: 17.531 KiB)
@b indirect($m2) # 14.125 μs (815 allocs: 33.922 KiB)
```

---------

Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com>
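The flow described in that PR (init produces a value plus transform information, which then passes through an accumulator that decides how to store it) can be caricatured as follows. All names here are illustrative toys; the real `AbstractTransformedValue` and `TransformedValueAccumulator` live in the source files listed above.

```julia
# Toy caricature (not DynamicPPL's actual API): an init step produces a
# value together with its transform status, and an accumulator collects
# those records by name.
struct ToyValue{T}
    value::T
    linked::Bool
end

struct ToyAccumulator
    store::Dict{Symbol,ToyValue}
end
ToyAccumulator() = ToyAccumulator(Dict{Symbol,ToyValue}())

# "init" returns the value plus its transform status...
toy_init(x; linked=false) = ToyValue(x, linked)

# ...and the accumulator decides how to store it.
function toy_accumulate!(acc::ToyAccumulator, name::Symbol, tv::ToyValue)
    acc.store[name] = tv
    return acc
end

acc = ToyAccumulator()
toy_accumulate!(acc, :x, toy_init(0.7))
acc.store[:x].value  # 0.7
```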
I put together a quick sketch of what it would look like to use `VarNamedTuple` as a `VarInfo` directly. By that I mean having a `VarInfo` type that is nothing but accumulators plus a `VarNamedTuple` that maps each `VarName` to a tuple (or actually a tiny struct, but anyway) of three values: the stored value for this variable, whether it's linked, and what transform should be applied to convert the stored value back to "model space". I'm calling this new VarInfo type `VNTVarInfo` (name to be changed later).

This isn't finished yet, but the majority of tests pass. There are a lot of failures around edge cases like Cholesky and weird VarNames and such, but for most simple models you can do

and it'll give you the correct result. `unflatten` and `vi[:]` also work.

I'll keep working on this, but at this point I wanted to pause to do some benchmarks, to see how viable this is. Benchmark code, very similar to #1182, running `evaluate!!` on our benchmarking models:

Details (collapsed)

Results contrasting the new `VarInfo` with both the old `VarInfo` and with `SimpleVarInfo{NamedTuple}` and `SimpleVarInfo{OrderedDict}`. Some SVI NT results are missing because it couldn't handle the `IndexLens`es:

I think a fair TL;DR is that for both small models and models with `IndexLens`es this is many times faster than the old `VarInfo`, and not far off from `SimpleVarInfo` when `SimpleVarInfo` is at its fastest (NamedTuples for small models, OrderedDicts for `IndexLens`es). I would still like to close that gap a bit. I don't know why linking causes such a large slowdown in some cases; I suspect it's because the transform system is geared towards assuming we want to vectorise things, and I've hacked this together quickly to just get it to work.

For large models performance is essentially equal, as it should be, because this is about overheads. To fix that, I need to look into using views in some clever way, but that's for later.

I think this is a promising start towards being able to say that all of `VarInfo`, `SimpleVarInfo`, and `VarNamedVector` could be replaced with a direct use of `VarNamedTuple` (as opposed to e.g. VNT wrapping `VarNamedVector`), and it would be pretty close to being a best-of-all-worlds solution, in that it's almost as fast as SVI and has full support for all models.

Note that the new VNTVarInfo has no notion of typed and untyped VarInfos. They are all as typed as they can be, which should also help simplify code.
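The three-field record described above could be sketched roughly like this. The names here are illustrative toys, not DynamicPPL's actual types:

```julia
# Toy sketch of the per-variable record: the stored value, whether it is
# linked (in unconstrained space), and the transform that maps the stored
# value back to model space. Illustrative only, not DynamicPPL's types.
struct ToyEntry{T,F}
    stored::T       # value as kept in the varinfo
    linked::Bool    # is `stored` in linked (unconstrained) space?
    to_model::F     # applied to `stored` to recover the model-space value
end

model_value(e::ToyEntry) = e.linked ? e.to_model(e.stored) : e.stored

# A positive-constrained variable stored in log space:
e = ToyEntry(log(2.5), true, exp)
model_value(e)  # ≈ 2.5

# An unlinked variable is returned as-is:
model_value(ToyEntry([1.0, 2.0], false, identity))  # [1.0, 2.0]
```

In the real thing the VNT would map each `VarName` to such a record, and because the record's field types are part of the VNT's type, the whole structure stays "as typed as it can be" with no separate typed/untyped distinction.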
I'll keep working on this tomorrow.