From 2f247e353842a0cf7290ebf556704f17f8e53d35 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 20 Feb 2026 18:40:01 +0000 Subject: [PATCH] Remove `vnt_size` --- docs/src/accs/values.md | 2 +- docs/src/api.md | 1 - docs/src/vnt/arraylikeblocks.md | 26 +------ src/accumulators/pointwise_logdensities.jl | 2 +- src/accumulators/vector_values.jl | 15 ++-- src/contexts.jl | 1 + src/contexts/init.jl | 25 +++---- src/logdensityfunction.jl | 2 +- src/transformed_values.jl | 45 ++++-------- src/varinfo.jl | 11 ++- src/varnamedtuple.jl | 1 - src/varnamedtuple/getset.jl | 35 ++------- src/varnamedtuple/map.jl | 15 ---- src/varnamedtuple/partial_array.jl | 82 +++++++++++----------- 14 files changed, 85 insertions(+), 178 deletions(-) diff --git a/docs/src/accs/values.md b/docs/src/accs/values.md index 4afb563b9..c2b1ec6db 100644 --- a/docs/src/accs/values.md +++ b/docs/src/accs/values.md @@ -46,7 +46,7 @@ vi = VarInfo(dirichlet_model) vi ``` -In `VarInfo`, it is mandatory to store `LinkedVectorValue`s or `VectorValue`s as `ArrayLikeBlock`s (see the [Array-like blocks](@ref) documentation for information on this). +In `VarInfo`, it is mandatory to store `LinkedVectorValue`s or `VectorValue`s as `ArrayLikeBlock`s (see the [Array-like blocks](@ref array-like-blocks) documentation for information on this). The reason is because, if the value is linked, it may have a different size than the number of indices in the `VarName`. This means that when retrieving the keys, we obtain each block as a single key: diff --git a/docs/src/api.md b/docs/src/api.md index 40fc62585..f9c452738 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -349,7 +349,6 @@ For more details on `VarNamedTuple`, see the Internals section of our documentat ```@docs DynamicPPL.VarNamedTuples.VarNamedTuple DynamicPPL.VarNamedTuples.@vnt -DynamicPPL.VarNamedTuples.vnt_size DynamicPPL.VarNamedTuples.apply!! DynamicPPL.VarNamedTuples.densify!! DynamicPPL.VarNamedTuples.map_pairs!! diff --git a/docs/src/vnt/arraylikeblocks.md b/docs/src/vnt/arraylikeblocks.md index 6e2675358..db252240b 100644 --- a/docs/src/vnt/arraylikeblocks.md +++ b/docs/src/vnt/arraylikeblocks.md @@ -1,4 +1,4 @@ -# Array-like blocks +# [Array-like blocks](@id array-like-blocks) In a number of VNT use cases, it is necessary to associate multiple indices in a `VarNamedTuple` with an object that is not necessarily the same number of elements. @@ -87,30 +87,6 @@ Furthermore, if you set a value into any of the indices covered by the block, th vnt = DynamicPPL.templated_setindex!!(vnt, Normal(), @varname(x[2]), x) ``` -## Size checks - -Currently, when setting any object `val` as an `ArrayLikeBlock`, there is a size check: we make sure that the range of indices being set to has the same size as `DynamicPPL.VarNamedTuples.vnt_size(val)`. -By default, `vnt_size(x)` returns `Base.size(x)`. - -```@example 1 -DynamicPPL.VarNamedTuples.vnt_size(Dirichlet(ones(3))) -``` - -This is what allows us to set a `Dirichlet` distribution to three indices. -However, trying to set the same distribution to two indices will fail: - -```@repl 1 -vnt = DynamicPPL.templated_setindex!!( - VarNamedTuple(), Dirichlet(ones(3)), @varname(x[1:2]), zeros(5) -) -``` - -!!! note - - In principle, these checks can be removed since if `Dirichlet(ones(3))` is set as the prior of `x[1:2]`, then model evaluation will error anyway. - Furthermore, if at any point we need to know the size of the block, we can always retrieve it via `size(view(parent_array, alb.ix...; alb.kw...))`. - However, the checks are still here for now. - ## Which parts of DynamicPPL use array-like blocks? Simply put, anywhere where we don't store raw values. diff --git a/src/accumulators/pointwise_logdensities.jl b/src/accumulators/pointwise_logdensities.jl index 335abb88e..63e5aa4b9 100644 --- a/src/accumulators/pointwise_logdensities.jl +++ b/src/accumulators/pointwise_logdensities.jl @@ -22,7 +22,7 @@ function (plp::PointwiseLogProb{Prior,Likelihood})( return DoNotAccumulate() end end -const POINTWISE_ACCNAME = :PointwiseLogProbAccumulator +const POINTWISE_ACCNAME = :PointwiseLogProb # Not exported function get_pointwise_logprobs(varinfo::AbstractVarInfo) diff --git a/src/accumulators/vector_values.jl b/src/accumulators/vector_values.jl index b3f094860..07271c0fb 100644 --- a/src/accumulators/vector_values.jl +++ b/src/accumulators/vector_values.jl @@ -1,11 +1,10 @@ const VECTORVAL_ACCNAME = :VectorValue _get_vector_tval(val, tval::Union{VectorValue,LinkedVectorValue}, logjac, vn, dist) = tval function _get_vector_tval(val, ::UntransformedValue, logjac, vn, dist) - original_val_size = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () f = to_vec_transform(dist) new_val, logjac = with_logabsdet_jacobian(f, val) @assert iszero(logjac) # otherwise we're in trouble... - return VectorValue(new_val, inverse(f), original_val_size) + return VectorValue(new_val, inverse(f)) end # This is equivalent to `varinfo.values` where `varinfo isa VarInfo` @@ -34,12 +33,12 @@ julia> using DynamicPPL julia> # In a real setting the other fields would be filled in with meaningful values. vnt = @vnt begin - x := VectorValue([1.0, 2.0], nothing, nothing) - y := LinkedVectorValue([3.0], nothing, nothing) + x := VectorValue([1.0, 2.0], nothing) + y := LinkedVectorValue([3.0], nothing) end VarNamedTuple -├─ x => VectorValue{Vector{Float64}, Nothing, Nothing}([1.0, 2.0], nothing, nothing) -└─ y => LinkedVectorValue{Vector{Float64}, Nothing, Nothing}([3.0], nothing, nothing) +├─ x => VectorValue{Vector{Float64}, Nothing}([1.0, 2.0], nothing) +└─ y => LinkedVectorValue{Vector{Float64}, Nothing}([3.0], nothing) julia> internal_values_as_vector(vnt) 3-element Vector{Float64}: @@ -70,8 +69,8 @@ julia> # note InitFromParams provides parameters in untransformed space julia> # but because we specified LinkAll(), the vectorised values are transformed vector_vals = get_vector_values(accs) VarNamedTuple -├─ x => LinkedVectorValue{Vector{Float64}, ComposedFunction{typeof(identity), typeof(identity)}, Tuple{Int64}}([1.0, 2.0], identity ∘ identity, (2,)) -└─ y => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}, Tuple{}}([0.0], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0)) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ())), ()) +├─ x => LinkedVectorValue{Vector{Float64}, ComposedFunction{typeof(identity), typeof(identity)}}([1.0, 2.0], identity ∘ identity) +└─ y => LinkedVectorValue{Vector{Float64}, ComposedFunction{DynamicPPL.UnwrapSingletonTransform{Tuple{}}, ComposedFunction{Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}, DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}}}}([0.0], DynamicPPL.UnwrapSingletonTransform{Tuple{}}(()) ∘ (Bijectors.Inverse{Bijectors.Logit{Float64, Float64}}(Bijectors.Logit{Float64, Float64}(0.0, 1.0)) ∘ DynamicPPL.ReshapeTransform{Tuple{Int64}, Tuple{}}((1,), ()))) julia> # we can extract the internal values as a single vector internal_values_as_vector(vector_vals) diff --git a/src/contexts.jl b/src/contexts.jl index 7a14b6970..a4447651e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -168,6 +168,7 @@ function tilde_observe!!( ::Distribution, ::Any, ::Union{VarName,Nothing}, + ::Any, ::AbstractVarInfo, ) return error("tilde_observe!! not implemented for context of type $(typeof(context))") diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 894d8e6e5..05faa084e 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -112,7 +112,7 @@ function init(rng::Random.AbstractRNG, ::VarName, dist::Distribution, u::InitFro # sample. real_sz = prod(sz) y = u.lower .+ ((u.upper - u.lower) .* rand(rng, real_sz)) - return LinkedVectorValue(y, from_linked_vec_transform(dist), get_size_for_vnt(dist)) + return LinkedVectorValue(y, from_linked_vec_transform(dist)) end """ @@ -188,10 +188,10 @@ function init( # In this case, we can't trust the transform stored in x because the _value_ # in x may have been changed via unflatten!! without the transform being # updated. Therefore, we always recompute the transform here. - VectorValue(x.val, from_vec_transform(dist), x.size) + VectorValue(x.val, from_vec_transform(dist)) elseif x isa LinkedVectorValue # Same as above. - LinkedVectorValue(x.val, from_linked_vec_transform(dist), x.size) + LinkedVectorValue(x.val, from_linked_vec_transform(dist)) elseif x isa UntransformedValue x else @@ -230,10 +230,10 @@ function init( # In this case, we can't trust the transform stored in x because the _value_ # in x may have been changed via unflatten!! without the transform being # updated. Therefore, we always recompute the transform here. - VectorValue(x.val, from_vec_transform(dist), x.size) + VectorValue(x.val, from_vec_transform(dist)) elseif x isa LinkedVectorValue # Same as above. - LinkedVectorValue(x.val, from_linked_vec_transform(dist), x.size) + LinkedVectorValue(x.val, from_linked_vec_transform(dist)) elseif x isa UntransformedValue x else @@ -270,17 +270,13 @@ an unlinked value. $(TYPEDFIELDS) """ -struct RangeAndLinked{T<:Tuple} +struct RangeAndLinked # indices that the variable corresponds to in the vectorised parameter range::UnitRange{Int} # whether the variable is linked or unlinked is_linked::Bool - # original size of the variable before vectorisation - original_size::T end -VarNamedTuples.vnt_size(ral::RangeAndLinked) = ral.original_size - """ InitFromVector( vect::AbstractVector{<:Real}, @@ -328,17 +324,16 @@ end function init(::Random.AbstractRNG, vn::VarName, dist::Distribution, ifv::InitFromVector) range_and_linked = _get_range_and_linked(ifv, vn) vect = view(ifv.vect, range_and_linked.range) - sz = range_and_linked.original_size # This block here is why we store transform_strategy inside the InitFromVector, as it # allows for type stability. return if ifv.transform_strategy isa LinkAll - LinkedVectorValue(vect, from_linked_vec_transform(dist), sz) + LinkedVectorValue(vect, from_linked_vec_transform(dist)) elseif ifv.transform_strategy isa UnlinkAll - VectorValue(vect, from_vec_transform(dist), sz) + VectorValue(vect, from_vec_transform(dist)) elseif range_and_linked.is_linked - LinkedVectorValue(vect, from_linked_vec_transform(dist), sz) + LinkedVectorValue(vect, from_linked_vec_transform(dist)) else - VectorValue(vect, from_vec_transform(dist), sz) + VectorValue(vect, from_vec_transform(dist)) end end function get_param_eltype(strategy::InitFromVector) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index ec66b4d90..b13794574 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -497,7 +497,7 @@ function get_ranges_and_linked(vnt::VarNamedTuple) val = tv.val range = offset:(offset + length(val) - 1) offset += length(val) - ral = RangeAndLinked(range, tv isa LinkedVectorValue, tv.size) + ral = RangeAndLinked(range, tv isa LinkedVectorValue) template = vnt.data[AbstractPPL.getsym(vn)] ranges_vnt = templated_setindex!!(ranges_vnt, ral, vn, template) return ranges_vnt, offset diff --git a/src/transformed_values.jl b/src/transformed_values.jl index 7ad7f5ee1..607a7cd02 100644 --- a/src/transformed_values.jl +++ b/src/transformed_values.jl @@ -64,9 +64,6 @@ Subtypes of this should implement the following functions: - `DynamicPPL.set_internal_value(tv::AbstractTransformedValue, new_val)`: Create a new `AbstractTransformedValue` with the same transformation as `tv`, but with internal value `new_val`. - -- `DynamicPPL.VarNamedTuples.vnt_size(tv::AbstractTransformedValue)`: Get the size of the - original value before transformation. """ abstract type AbstractTransformedValue end @@ -102,14 +99,14 @@ internal value `new_val`. function set_internal_value end """ - VectorValue{V<:AbstractVector,T,S} + VectorValue{V<:AbstractVector,T} A transformed value that stores its internal value as a vectorised form. This is what VarInfo sees as an "unlinked value". These values can be generated when using `InitFromParams` with a VarInfo's internal values. """ -struct VectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue +struct VectorValue{V<:AbstractVector,T} <: AbstractTransformedValue "The internal (vectorised) value." val::V """The unvectorisation transform required to convert `val` back to the original space. @@ -120,23 +117,20 @@ struct VectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue evaluation occurs, the correct transform is determined from the distribution associated with the variable.""" transform::T - """The size of the original value before transformation. This is needed when a - `TransformedValue` is stored as a block in an array.""" - size::S - function VectorValue(val::V, tfm::T, size::S) where {V<:AbstractVector,T,S} - return new{V,T,S}(val, tfm, size) + function VectorValue(val::V, tfm::T) where {V<:AbstractVector,T} + return new{V,T}(val, tfm) end end """ - LinkedVectorValue{V<:AbstractVector,T,S} + LinkedVectorValue{V<:AbstractVector,T} -A transformed value that stores its internal value as a linked andvectorised form. This is +A transformed value that stores its internal value as a linked and vectorised form. This is what VarInfo sees as a "linked value". These values can be generated when using `InitFromParams` with a VarInfo's internal values. """ -struct LinkedVectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue +struct LinkedVectorValue{V<:AbstractVector,T} <: AbstractTransformedValue "The internal (linked + vectorised) value." val::V """The unlinking transform required to convert `val` back to the original space. @@ -147,33 +141,25 @@ struct LinkedVectorValue{V<:AbstractVector,T,S} <: AbstractTransformedValue evaluation occurs, the correct transform is determined from the distribution associated with the variable.""" transform::T - """The size of the original value before transformation. This is needed when a - `TransformedValue` is stored as a block in an array.""" - size::S - function LinkedVectorValue(val::V, tfm::T, size::S) where {V<:AbstractVector,T,S} - return new{V,T,S}(val, tfm, size) + function LinkedVectorValue(val::V, tfm::T) where {V<:AbstractVector,T} + return new{V,T}(val, tfm) end end for T in (:VectorValue, :LinkedVectorValue) @eval begin function Base.:(==)(tv1::$T, tv2::$T) - return (tv1.val == tv2.val) & - (tv1.transform == tv2.transform) & - (tv1.size == tv2.size) + return (tv1.val == tv2.val) & (tv1.transform == tv2.transform) end function Base.isequal(tv1::$T, tv2::$T) - return isequal(tv1.val, tv2.val) && - isequal(tv1.transform, tv2.transform) && - isequal(tv1.size, tv2.size) + return isequal(tv1.val, tv2.val) && isequal(tv1.transform, tv2.transform) end - VarNamedTuples.vnt_size(tv::$T) = tv.size get_transform(tv::$T) = tv.transform get_internal_value(tv::$T) = tv.val function set_internal_value(tv::$T, new_val) - return $T(new_val, tv.transform, tv.size) + return $T(new_val, tv.transform) end end end @@ -191,15 +177,12 @@ struct UntransformedValue{V} <: AbstractTransformedValue val::V UntransformedValue(val::V) where {V} = new{V}(val) end -VarNamedTuples.vnt_size(tv::UntransformedValue) = vnt_size(tv.val) Base.:(==)(tv1::UntransformedValue, tv2::UntransformedValue) = tv1.val == tv2.val Base.isequal(tv1::UntransformedValue, tv2::UntransformedValue) = isequal(tv1.val, tv2.val) get_transform(::UntransformedValue) = typed_identity get_internal_value(tv::UntransformedValue) = tv.val set_internal_value(::UntransformedValue, new_val) = UntransformedValue(new_val) -get_size_for_vnt(val) = hasmethod(size, Tuple{typeof(val)}) ? size(val) : () - """ abstract type AbstractTransform end @@ -396,7 +379,7 @@ function apply_transform_strategy( flink = DynamicPPL.to_linked_vec_transform(dist) linked_value, logjac = with_logabsdet_jacobian(flink, raw_value) finvlink = DynamicPPL.from_linked_vec_transform(dist) - linked_tv = LinkedVectorValue(linked_value, finvlink, get_size_for_vnt(raw_value)) + linked_tv = LinkedVectorValue(linked_value, finvlink) (raw_value, linked_tv, logjac) elseif target isa Unlink # No need to transform further @@ -419,7 +402,7 @@ function apply_transform_strategy( flink = DynamicPPL.to_linked_vec_transform(dist) linked_value, logjac = with_logabsdet_jacobian(flink, raw_value) finvlink = DynamicPPL.from_linked_vec_transform(dist) - linked_tv = LinkedVectorValue(linked_value, finvlink, get_size_for_vnt(raw_value)) + linked_tv = LinkedVectorValue(linked_value, finvlink) (raw_value, linked_tv, logjac) elseif target isa Unlink # No need to transform further diff --git a/src/varinfo.jl b/src/varinfo.jl index 0119ef47c..b5e170163 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -314,8 +314,7 @@ function setindex_with_dist!!( vi::VarInfo, tval::UntransformedValue, dist::Distribution, vn::VarName, template ) raw_value = DynamicPPL.get_internal_value(tval) - sz = hasmethod(size, (typeof(raw_value),)) ? size(raw_value) : () - tval = VectorValue(to_vec_transform(dist)(raw_value), from_vec_transform(dist), sz) + tval = VectorValue(to_vec_transform(dist)(raw_value), from_vec_transform(dist)) return setindex_with_dist!!(vi, tval, dist, vn, template) end @@ -330,9 +329,9 @@ transformation of the variable! function set_transformed!!(vi::VarInfo, linked::Bool, vn::VarName) old_tv = getindex(vi.values, vn) new_tv = if linked - LinkedVectorValue(old_tv.val, old_tv.transform, old_tv.size) + LinkedVectorValue(old_tv.val, old_tv.transform) else - VectorValue(old_tv.val, old_tv.transform, old_tv.size) + VectorValue(old_tv.val, old_tv.transform) end new_values = setindex!!(vi.values, new_tv, vn) new_transform_strategy = update_transform_strategy( @@ -358,7 +357,7 @@ set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false function set_transformed!!(vi::VarInfo, linked::Bool) ctor = linked ? LinkedVectorValue : VectorValue new_values = map_values!!(vi.values) do tv - ctor(tv.val, tv.transform, tv.size) + ctor(tv.val, tv.transform) end new_transform_strategy = linked ? LinkAll() : UnlinkAll() return VarInfo(new_transform_strategy, new_values, vi.accs) @@ -528,7 +527,7 @@ for T in (:VectorValue, :LinkedVectorValue) len = length(old_val) new_val = @view vci.vec[(vci.index):(vci.index + len - 1)] vci.index += len - return $T(new_val, tv.transform, tv.size) + return $T(new_val, tv.transform) end end end diff --git a/src/varnamedtuple.jl b/src/varnamedtuple.jl index 455e2b732..12e0c6870 100644 --- a/src/varnamedtuple.jl +++ b/src/varnamedtuple.jl @@ -8,7 +8,6 @@ using BangBang using DynamicPPL: DynamicPPL export VarNamedTuple, - vnt_size, map_pairs!!, map_values!!, apply!!, diff --git a/src/varnamedtuple/getset.jl b/src/varnamedtuple/getset.jl index 449bd5ac5..939bfd7c5 100644 --- a/src/varnamedtuple/getset.jl +++ b/src/varnamedtuple/getset.jl @@ -162,12 +162,7 @@ function _setindex_optic!!( # doesn't yet have enough indices for that slice. Expand it if so. pa = grow_to_indices!!(pa, coptic.ix...; coptic.kw...) - is_multiindex = if template isa AbstractArray || template isa PartialArray - _is_multiindex(template, coptic.ix...; coptic.kw...) - else - isempty(coptic.kw) || throw_kw_error() - _is_multiindex_static(coptic.ix) - end + is_multiindex = _is_multiindex(template, coptic.ix...; coptic.kw...) if permissions isa MustNotOverwrite if any(view(pa.mask, coptic.ix...; coptic.kw...)) @@ -285,15 +280,6 @@ function _setindex_optic!!( return VarNamedTuple(merge(vnt.data, NamedTuple{(S,)}((sub_value,)))) end -@generated function _is_multiindex_static(::T) where {T<:Tuple} - for x in T.parameters - if x <: AbstractVector{<:Int} || x <: Colon - return :(return true) - end - end - return :(return false) -end - """ make_leaf(value, optic, template) @@ -332,16 +318,7 @@ function make_leaf(value, optic::AbstractPPL.Index, template) # them. This also helpfully catches errors if there is a dynamic index and a suitable # template is not provided (e.g., if someone tries to set `x[end]` without a template). coptic = AbstractPPL.concretize_top_level(optic, template) - is_multiindex = if template isa AbstractArray || template isa PartialArray - _is_multiindex(template, coptic.ix...; coptic.kw...) - else - # This handles the case where no template is provided, or a nonsense template is - # provided. - isempty(coptic.kw) || throw_kw_error() - # This will error if there are things like colons. - _is_multiindex_static(coptic.ix) - end - return if is_multiindex + return if _is_multiindex(template, coptic.ix...; coptic.kw...) make_leaf_multiindex(value, coptic, template) else make_leaf_singleindex(value, coptic, template) @@ -423,12 +400,8 @@ function make_leaf_multiindex(value, coptic::AbstractPPL.Index, template) pa_eltype = if sub_value isa AbstractArray || sub_value isa PartialArray eltype(sub_value) else - ArrayLikeBlock{ - typeof(sub_value), - typeof(coptic.ix), - typeof(coptic.kw), - typeof(vnt_size(value)), - } + idx_size_type = Dims{_ndims(template, coptic.ix...; coptic.kw...)} + ArrayLikeBlock{typeof(sub_value),typeof(coptic.ix),typeof(coptic.kw),idx_size_type} end # The rest is the same as the single-index case. diff --git a/src/varnamedtuple/map.jl b/src/varnamedtuple/map.jl index d9e0868e3..833df8805 100644 --- a/src/varnamedtuple/map.jl +++ b/src/varnamedtuple/map.jl @@ -263,27 +263,12 @@ function _map_values_recursive_pa_noalb!!(func, pa::PartialArray) end end -function _check_size(new_block, old_block) - sz_new = vnt_size(new_block) - sz_old = vnt_size(old_block) - if sz_new != sz_old - throw( - DimensionMismatch( - "map_pairs!! can't change the size of a block. Tried to change " * - "from $(sz_old) to $(sz_new).", - ), - ) - end -end function _map_pairs_recursive!!(pairfunc, alb::ArrayLikeBlock, vn) - # new_block = _map_pairs_recursive!!(pairfunc, alb.block, vn) new_block = pairfunc(vn => alb.block) - _check_size(new_block, alb.block) return ArrayLikeBlock(new_block, alb.ix, alb.kw, alb.index_size) end function _map_values_recursive!!(func, alb::ArrayLikeBlock) new_block = _map_values_recursive!!(func, alb.block) - _check_size(new_block, alb.block) return ArrayLikeBlock(new_block, alb.ix, alb.kw, alb.index_size) end diff --git a/src/varnamedtuple/partial_array.jl b/src/varnamedtuple/partial_array.jl index ad60524a2..686e15270 100644 --- a/src/varnamedtuple/partial_array.jl +++ b/src/varnamedtuple/partial_array.jl @@ -27,17 +27,6 @@ merge its nested values as well. """ _merge(_, x2, _) = x2 -""" - vnt_size(x) - -Get the size of an object `x` for use in `VarNamedTuple` and `PartialArray`. - -By default, this falls back onto `Base.size`, but can be overloaded for custom types. -This notion of type is used to determine whether a value can be set into a `PartialArray` -as a block, see the docstring of `PartialArray` and `ArrayLikeBlock` for details. -""" -vnt_size(x) = size(x) - """ ArrayLikeBlock{T,I,N,S} @@ -162,13 +151,14 @@ end An array-like like structure that may only have some of its elements defined. One can set values in a `PartialArray` either element-by-element, or with ranges like -`arr[1:3,2] = [5,10,15]`. When setting values over a range of indices, the value being set -must either be an `AbstractArray` or otherwise something for which `vnt_size(value)` or -`Base.size(value)` (which `vnt_size` falls back onto) is defined, and the size matches the -range. If the value is an `AbstractArray`, the elements are copied individually, but if it -is not, the value is stored as a block, that takes up the whole range, e.g. `[1:3,2]`, but -is only a single object. Getting such a block-value must be done with the exact same range -of indices, otherwise an error is thrown. +`arr[1:3,2] = [5,10,15]`. + +When setting values over a range of indices, the value being set can be an `AbstractArray` +whose size matches the range (in which case the values are set elementwise). If the value is +some other object, it can still be stored as an `ArrayLikeBlock`. Retrieving such a +block-value must be done with the exact same range of indices, otherwise an error is thrown. +Please see [the DynamicPPL documentation](@ref array-like-blocks) for more information on +this. If the element type of a `PartialArray` is not concrete, any call to `setindex!!` will check if, after the new value has been set, the element type can be made more concrete. If so, @@ -214,14 +204,6 @@ Base.eltype(::PartialArray{ElType}) where {ElType} = ElType Base.size(pa::PartialArray) = size(pa.data) Base.isassigned(pa::PartialArray, ix...; kw...) = isassigned(pa.data, ix...; kw...) -# Even though a PartialArray may have its own size, we still allow it to be used as an -# ArrayLikeBlock. This enables setting values for keys like @varname(x[1:3][1]), which will -# be stored as a PartialArray wrapped in an ArrayLikeBlock, stored in another PartialArray. -# Note that this bypasses _any_ size checks, so that e.g. @varname(x[1:3][1,15]) is also a -# valid key. -# TODO(penelopeysm) check if this is still needed. -vnt_size(pa::PartialArray) = size(pa) - function Base.copy(pa::PartialArray) # Make a shallow copy of pa, except for any VarNamedTuple elements, which we recursively # copy. @@ -485,11 +467,39 @@ function _remove_partial_blocks!!( return pa_data, pa_mask end -function _is_multiindex(f::AbstractArray, ix...; kw...) - return ndims(view(f, ix...; kw...)) > 0 +@inline function _ndims(f::AbstractArray, ix...; kw...) + return ndims(view(f, ix...; kw...)) +end +@inline function _ndims(f::PartialArray, ix...; kw...) + return ndims(view(f.data, ix...; kw...)) end -function _is_multiindex(f::PartialArray, ix...; kw...) - return ndims(view(f.data, ix...; kw...)) > 0 +@inline function _ndims(::Any, ix...; kw...) + isempty(kw) || throw_kw_error() + return _ndims_static(ix) +end +@generated function _ndims_static(::T) where {T<:Tuple} + i = 0 + for x in T.parameters + if x <: AbstractVector{<:Int} || x <: Colon + i += 1 + end + end + return :(return $i) +end +@inline function _is_multiindex(f::Union{AbstractArray,PartialArray}, ix...; kw...) + return _ndims(f, ix...; kw...) > 0 +end +@inline function _is_multiindex(::Any, ix...; kw...) + isempty(kw) || throw_kw_error() + return _is_multiindex_static(ix) +end +@generated function _is_multiindex_static(::T) where {T<:Tuple} + for x in T.parameters + if x <: AbstractVector{<:Int} || x <: Colon + return :(return true) + end + end + return :(return false) end """ @@ -503,7 +513,6 @@ The value only depends on the types of the arguments, and should be constant pro function _needs_arraylikeblock(pa_data::AbstractArray, value, inds::Vararg{Any}; kw...) return !isa(value, AbstractArray) && !isa(value, PartialArray) && - hasmethod(vnt_size, Tuple{typeof(value)}) && _is_multiindex(pa_data, inds...; kw...) end @@ -569,18 +578,7 @@ function BangBang.setindex!!(pa::PartialArray, value, inds::Vararg{Any}; kw...) new_data, new_mask = _remove_partial_blocks!!(new_data, new_mask, inds...; kw...) if _needs_arraylikeblock(new_data, value, inds...; kw...) - # Check that we're trying to set a block that has the right size. idx_sz = size(@view new_data[inds..., kw...]) - - # vnt_sz = vnt_size(value) - # if vnt_sz != idx_sz - # throw( - # DimensionMismatch( - # "Assigned value has size $(vnt_sz), which does not match " * - # "the size implied by the indices $(idx_sz).", - # ), - # ) - # end alb = ArrayLikeBlock(value, inds, NamedTuple(kw), idx_sz) new_data = setindex!!(new_data, fill(alb, idx_sz...), inds...; kw...) fill!(view(new_mask, inds...; kw...), true)