From 073ec92f1a5866b206771452a03e42dc2e414864 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Oct 2025 15:23:35 +0000 Subject: [PATCH 01/11] Use tighter element types in VNV --- src/varnamedvector.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 2c66e1245..ae10834e8 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1020,6 +1020,7 @@ function insert_internal!!( end vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform)) insert_internal!(vnv, val, vn, transform) + vnv = tighten_types!!(vnv) return vnv end @@ -1029,6 +1030,7 @@ function update_internal!!( transform_resolved = transform === nothing ? gettransform(vnv, vn) : transform vnv = loosen_types!!(vnv, typeof(vn), eltype(val), typeof(transform_resolved)) update_internal!(vnv, val, vn, transform) + vnv = tighten_types!!(vnv) return vnv end From 52edf9ab147b4fefa5dbe7a177beae0d242a31ba Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Oct 2025 16:20:15 +0000 Subject: [PATCH 02/11] Add type tightness tests for VNV --- test/Project.toml | 2 + test/runtests.jl | 1 + test/varnamedvector.jl | 86 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index de2160f4f..c96087d66 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" @@ -34,6 +35,7 @@ AbstractMCMC = "5" AbstractPPL = "0.13" Accessors = "0.1" Aqua = "0.8" +BangBang = "0.4" Bijectors = "0.15.1" Combinatorics = "1" DifferentiationInterface = "0.6.41, 0.7" diff --git a/test/runtests.jl b/test/runtests.jl index b6a3f7bf6..7a9c12525 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using ADTypes using DynamicPPL using AbstractMCMC using AbstractPPL +using BangBang: delete!!, setindex!! using Bijectors using DifferentiationInterface using Distributions diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index b764d517b..e7340aa01 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -579,6 +579,92 @@ end @test is_transformed(vnv, @varname(t[1])) @test subset(vnv, vns) == vnv end + + @testset "loosen and tighten types" begin + + """ + test_tightenability(vnv::VarNamedVector) + + Test that tighten_types!! is a no-op on `vnv`. + """ + function test_tightenability(vnv::DynamicPPL.VarNamedVector) + @test vnv == DynamicPPL.tighten_types!!(deepcopy(vnv)) + # TODO(mhauru) We would like to check something more stringent here, namely that + # the operation is compiled to a direct no-op, with no instructions at all. I + # don't know how to do that though, so for now we just check that it doesn't + # allocate. + @allocations(DynamicPPL.tighten_types!!(vnv)) == 0 + return nothing + end + + vn = @varname(a[1]) + # Test that tighten_types!! is a no-op on an empty VarNamedVector. + vnv = DynamicPPL.VarNamedVector() + @test DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + # Also check that it literally returns the same object, and both tighten and loosen + # are type stable. + @test vnv === DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + # Likewise for a VarNamedVector with something pushed into it. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + @test DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + @test vnv === DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + # Likewise for a VarNamedVector with abstract element-types, when that is needed for + # the current contents because mixed types have been pushed into it. However, this + # time, since the types are only as tight as they can be, but not actually concrete, + # tighten_types!! can't be type stable. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + vnv = setindex!!(vnv, 2, @varname(b)) + @test ~DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + # Likewise when first mixed types are pushed, but then deleted. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + vnv = setindex!!(vnv, 2, @varname(b)) + @test ~DynamicPPL.is_tightly_typed(vnv) + vnv = delete!!(vnv, vn) + @test DynamicPPL.is_tightly_typed(vnv) + test_tightenability(vnv) + @test vnv === DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.tighten_types!!(vnv) + @inferred DynamicPPL.loosen_types!!(vnv, VarName, Any, Any) + + # Test that loosen_types!! does really loosen them and that tighten_types!! reverts + # that. + vnv = DynamicPPL.VarNamedVector() + vnv = setindex!!(vnv, 1.0, vn) + @test DynamicPPL.is_tightly_typed(vnv) + k = eltype(vnv.varnames) + e = eltype(vnv.vals) + t = eltype(vnv.transforms) + # Loosen key type. + vnv = @inferred DynamicPPL.loosen_types!!(vnv, VarName, e, t) + @test ~DynamicPPL.is_tightly_typed(vnv) + vnv = DynamicPPL.tighten_types!!(vnv) + @test DynamicPPL.is_tightly_typed(vnv) + # Loosen element type + vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, Real, t) + @test ~DynamicPPL.is_tightly_typed(vnv) + vnv = DynamicPPL.tighten_types!!(vnv) + @test DynamicPPL.is_tightly_typed(vnv) + # Loosen transformation type + vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, Function) + @test ~DynamicPPL.is_tightly_typed(vnv) + vnv = DynamicPPL.tighten_types!!(vnv) + @test DynamicPPL.is_tightly_typed(vnv) + # Loosening to the same types as currently should do nothing. + vnv = @inferred DynamicPPL.loosen_types!!(vnv, k, e, t) + @test DynamicPPL.is_tightly_typed(vnv) + @allocations(DynamicPPL.loosen_types!!(vnv, k, e, t)) == 0 + end end @testset "VarInfo + VarNamedVector" begin From 839d4ce3eb6132db8bb2341b70de9c31911079d8 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Oct 2025 17:00:20 +0000 Subject: [PATCH 03/11] Fix some uses of OrderedDict in VNV tests --- test/varnamedvector.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index e7340aa01..20dc11145 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -79,7 +79,7 @@ function relax_container_types(vnv::DynamicPPL.VarNamedVector, vn::VarName, val) end function relax_container_types(vnv::DynamicPPL.VarNamedVector, vns, vals) if need_varnames_relaxation(vnv, vns, vals) - varname_to_index_new = convert(OrderedDict{VarName,Int}, vnv.varname_to_index) + varname_to_index_new = convert(Dict{VarName,Int}, vnv.varname_to_index) varnames_new = convert(Vector{VarName}, vnv.varnames) else varname_to_index_new = vnv.varname_to_index @@ -517,7 +517,7 @@ end @testset "deterministic" begin n = 5 vn = @varname(x) - vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true])) + vnv = DynamicPPL.VarNamedVector(Dict(vn => [true])) @test !DynamicPPL.has_inactive(vnv) # Growing should not create inactive ranges. for i in 1:n @@ -543,7 +543,7 @@ end @testset "random" begin n = 5 vn = @varname(x) - vnv = DynamicPPL.VarNamedVector(OrderedDict(vn => [true])) + vnv = DynamicPPL.VarNamedVector(Dict(vn => [true])) @test !DynamicPPL.has_inactive(vnv) # Insert a bunch of random-length vectors. From b1108003a0f002b2b8f264fb9faa42fdb0c3e081 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Oct 2025 17:01:09 +0000 Subject: [PATCH 04/11] Improvements to VNV loosen/tighten types --- src/varnamedvector.jl | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index ae10834e8..3f0be180b 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -341,10 +341,13 @@ function ==(vnv_left::VarNamedVector, vnv_right::VarNamedVector) vnv_left.num_inactive == vnv_right.num_inactive end -function is_concretely_typed(vnv::VarNamedVector) - return isconcretetype(eltype(vnv.varnames)) && - isconcretetype(eltype(vnv.vals)) && - isconcretetype(eltype(vnv.transforms)) +function is_tightly_typed(vnv::VarNamedVector) + k = eltype(vnv.varnames) + v = eltype(vnv.vals) + t = eltype(vnv.transforms) + return (isconcretetype(k) || k === Union{}) && + (isconcretetype(v) || v === Union{}) && + (isconcretetype(t) || t === Union{}) end getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] @@ -880,7 +883,16 @@ function loosen_types!!( return if vn_type == K && val_type == V && transform_type == T vnv elseif isempty(vnv) - VarNamedVector(vn_type[], val_type[], transform_type[]) + VarNamedVector( + Dict{vn_type,Int}(), + Vector{vn_type}(), + UnitRange{Int}[], + Vector{val_type}(), + Vector{transform_type}(), + BitVector(), + Dict{Int,Int}(); + check_consistency=false, + ) else # TODO(mhauru) We allow a `vnv` to have any AbstractVector type as its vals, but # then here always revert to Vector. @@ -944,7 +956,7 @@ julia> vnv_tight.transforms ``` """ function tighten_types!!(vnv::VarNamedVector) - return if is_concretely_typed(vnv) + return if is_tightly_typed(vnv) # There can not be anything to tighten, so short-circuit. vnv elseif isempty(vnv) From 69727138a9c1ebdc517b3f50aa0f805de46815ea Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Thu, 30 Oct 2025 17:23:08 +0000 Subject: [PATCH 05/11] Run formatter --- src/varnamedvector.jl | 4 ++-- test/varnamedvector.jl | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 3f0be180b..e70e4fa4d 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -346,8 +346,8 @@ function is_tightly_typed(vnv::VarNamedVector) v = eltype(vnv.vals) t = eltype(vnv.transforms) return (isconcretetype(k) || k === Union{}) && - (isconcretetype(v) || v === Union{}) && - (isconcretetype(t) || t === Union{}) + (isconcretetype(v) || v === Union{}) && + (isconcretetype(t) || t === Union{}) end getidx(vnv::VarNamedVector, vn::VarName) = vnv.varname_to_index[vn] diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index 20dc11145..3b3259768 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -581,7 +581,6 @@ end end @testset "loosen and tighten types" begin - """ test_tightenability(vnv::VarNamedVector) From 0b527c3067071c900a02846571d9061b08498f49 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 11:23:53 +0000 Subject: [PATCH 06/11] Don't recontiguify VNVs unnecessarily --- src/varnamedvector.jl | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index e70e4fa4d..17b851d1d 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1118,6 +1118,9 @@ care about them. This is in a sense the reverse operation of `vnv[:]`. +The return value may share memory with the input `vnv`, and thus one can not be mutated +safely without affecting the other. + Unflatten recontiguifies the internal storage, getting rid of any inactive entries. # Examples @@ -1139,15 +1142,20 @@ function unflatten(vnv::VarNamedVector, vals::AbstractVector) ), ) end - new_ranges = deepcopy(vnv.ranges) - recontiguify_ranges!(new_ranges) + new_ranges = vnv.ranges + num_inactive = vnv.num_inactive + if has_inactive(vnv) + new_ranges = recontiguify_ranges!(new_ranges) + num_inactive = Dict{Int,Int}() + end return VarNamedVector( vnv.varname_to_index, vnv.varnames, new_ranges, vals, vnv.transforms, - vnv.is_unconstrained; + vnv.is_unconstrained, + num_inactive; check_consistency=false, ) end @@ -1442,6 +1450,9 @@ julia> vnv[@varname(x)] # All the values are still there. ``` """ function contiguify!(vnv::VarNamedVector) + if !has_inactive(vnv) + return vnv + end # Extract the re-contiguified values. # NOTE: We need to do this before we update the ranges. old_vals = copy(vnv.vals) From 48e93ef8a163dfcc343595f5beda3dd0c425885e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 16:33:47 +0000 Subject: [PATCH 07/11] contiguify VNV after (inv)linking --- src/varinfo.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index a90b81488..486d24191 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1297,6 +1297,10 @@ function _link_metadata!!( metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) set_transformed!(metadata, true, vn) end + # Linking can often change the sizes of variables, causing inactive elements. We don't + # want to keep them around, since typically linking is done once and then the VarInfo + # is evaluated multiple times. Hence we contiguify here. + metadata = contiguify!(metadata) return metadata, cumulative_logjac end @@ -1465,6 +1469,10 @@ function _invlink_metadata!!( metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) set_transformed!(metadata, false, vn) end + # Linking can often change the sizes of variables, causing inactive elements. We don't + # want to keep them around, since typically linking is done once and then the VarInfo + # is evaluated multiple times. Hence we contiguify here. + metadata = contiguify!(metadata) return metadata, cumulative_inv_logjac end From e2c8a7254241f2a886fbdb227f8d096ed05534f6 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 27 Oct 2025 13:58:23 +0000 Subject: [PATCH 08/11] Make VarNamedVector the default VarInfo backend. Introduce (un)typed_legacy_varinfo --- src/abstract_varinfo.jl | 2 +- src/varinfo.jl | 34 ++++++++++++++++++++++++++-------- test/contexts.jl | 9 ++++++--- test/debug_utils.jl | 15 ++++++++++++--- test/varinfo.jl | 14 +++++++++----- 5 files changed, 54 insertions(+), 20 deletions(-) diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index ec5e1ea10..9f9a3720a 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -595,7 +595,7 @@ OrderedDict{VarName{sym, typeof(identity)} where sym, Float64} with 2 entries: m => 2.0 julia> values_as(vi, Vector) -2-element Vector{Real}: +2-element Vector{Float64}: 1.0 2.0 ``` diff --git a/src/varinfo.jl b/src/varinfo.jl index 486d24191..b2e116f3b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -107,7 +107,7 @@ struct VarInfo{Tmeta,Accs<:AccumulatorTuple} <: AbstractVarInfo metadata::Tmeta accs::Accs end -function VarInfo(meta=Metadata()) +function VarInfo(meta=VarNamedVector()) return VarInfo(meta, default_accumulators()) end @@ -194,8 +194,20 @@ end # VarInfo constructors # ######################## +function untyped_varinfo( + rng::Random.AbstractRNG, + model::Model, + init_strategy::AbstractInitStrategy=InitFromPrior(), +) + return untyped_vector_varinfo(rng, model, init_strategy) +end + +function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) + return untyped_vector_varinfo(Random.default_rng(), model, init_strategy) +end + """ - untyped_varinfo([rng, ]model[, init_strategy]) + untyped_legacy_varinfo([rng, ]model[, init_strategy]) Construct a VarInfo object for the given `model`, which has just a single `Metadata` as its metadata field. @@ -205,19 +217,21 @@ Construct a VarInfo object for the given `model`, which has just a single - `model::Model`: The model for which to create the varinfo object - `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ -function untyped_varinfo( +function untyped_legacy_varinfo( rng::Random.AbstractRNG, model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) return last(init!!(rng, model, VarInfo(Metadata()), init_strategy)) end -function untyped_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) - return untyped_varinfo(Random.default_rng(), model, init_strategy) +function untyped_legacy_varinfo( + model::Model, init_strategy::AbstractInitStrategy=InitFromPrior() +) + return untyped_legacy_varinfo(Random.default_rng(), model, init_strategy) end """ - typed_varinfo(vi::UntypedVarInfo) + typed_legacy_varinfo(vi::UntypedVarInfo) This function finds all the unique `sym`s from the instances of `VarName{sym}` found in `vi.metadata.vns`. It then extracts the metadata associated with each symbol from the @@ -225,7 +239,7 @@ global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `meta a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each symbol. """ -function typed_varinfo(vi::UntypedVarInfo) +function typed_legacy_varinfo(vi::UntypedVarInfo) meta = vi.metadata new_metas = Metadata[] # Symbols of all instances of `VarName{sym}` in `vi.vns` @@ -289,12 +303,16 @@ function typed_varinfo( model::Model, init_strategy::AbstractInitStrategy=InitFromPrior(), ) - return typed_varinfo(untyped_varinfo(rng, model, init_strategy)) + return typed_vector_varinfo(rng, model, init_strategy) end function typed_varinfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior()) return typed_varinfo(Random.default_rng(), model, init_strategy) end +function typed_varinfo(vi::UntypedVectorVarInfo) + return typed_vector_varinfo(vi) +end + """ untyped_vector_varinfo([rng, ]model[, init_strategy]) diff --git a/test/contexts.jl b/test/contexts.jl index 972d833a5..70e2fec86 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -417,12 +417,15 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() @testset "InitContext" begin empty_varinfos = [ - ("untyped+metadata", VarInfo()), - ("typed+metadata", DynamicPPL.typed_varinfo(VarInfo())), + ("untyped+metadata", VarInfo(DynamicPPL.Metadata())), + ( + "typed+metadata", + DynamicPPL.typed_legacy_varinfo(VarInfo(DynamicPPL.Metadata())), + ), ("untyped+VNV", VarInfo(DynamicPPL.VarNamedVector())), ( "typed+VNV", - DynamicPPL.typed_vector_varinfo(DynamicPPL.typed_varinfo(VarInfo())), + DynamicPPL.typed_vector_varinfo(VarInfo(DynamicPPL.VarNamedVector())), ), ("SVI+NamedTuple", SimpleVarInfo()), ("Svi+Dict", SimpleVarInfo(Dict{VarName,Any}())), diff --git a/test/debug_utils.jl b/test/debug_utils.jl index f950f6b45..45c7415d6 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -76,8 +76,11 @@ return nothing end buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) + @test_throws "should not subsume each other" DynamicPPL.untyped_varinfo( + buggy_model + ) + varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model) @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) issuccess = check_model(buggy_model, varinfo) @test !issuccess @@ -94,8 +97,11 @@ return nothing end buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) + @test_throws "should not subsume each other" DynamicPPL.untyped_varinfo( + buggy_model + ) + varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model) @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) issuccess = check_model(buggy_model, varinfo) @test !issuccess @@ -112,8 +118,11 @@ return nothing end buggy_model = buggy_subsumes_demo_model() - varinfo = VarInfo(buggy_model) + @test_throws "should not subsume each other" DynamicPPL.untyped_varinfo( + buggy_model + ) + varinfo = DynamicPPL.untyped_legacy_varinfo(buggy_model) @test_logs (:warn,) (:warn,) check_model(buggy_model, varinfo) issuccess = check_model(buggy_model, varinfo) @test !issuccess diff --git a/test/varinfo.jl b/test/varinfo.jl index 6b31fbe91..49e5bbac7 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -37,8 +37,10 @@ end end model = gdemo(1.0, 2.0) - _, vi = DynamicPPL.init!!(model, VarInfo(), InitFromUniform()) - tvi = DynamicPPL.typed_varinfo(vi) + # TODO(mhauru) Make this test more generic. It currently explicitly relies on + # Metadata. + _, vi = DynamicPPL.init!!(model, VarInfo(DynamicPPL.Metadata()), InitFromUniform()) + tvi = DynamicPPL.typed_legacy_varinfo(vi) meta = vi.metadata for f in fieldnames(typeof(tvi.metadata)) @@ -290,7 +292,7 @@ end dist = Normal(0, 1) r = rand(dist) - push!!(vi, vn_x, r, dist) + vi = push!!(vi, vn_x, r, dist) # is_transformed is set by default @test !is_transformed(vi, vn_x) @@ -353,7 +355,9 @@ end # worth specifically checking that it can do this without having to # change the VarInfo object. # TODO(penelopeysm): Move this to InitFromUniform tests rather than here. - vi = VarInfo() + # TODO(mhauru) Make this test more generic. It currently explicitly relies on + # Metadata. + vi = VarInfo(DynamicPPL.Metadata()) meta = vi.metadata _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) @test all(x -> !is_transformed(vi, x), meta.vns) @@ -367,7 +371,7 @@ end @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values - vi = DynamicPPL.typed_varinfo(vi) + vi = DynamicPPL.typed_legacy_varinfo(vi) meta = vi.metadata v_s = copy(meta.s.vals) v_m = copy(meta.m.vals) From d1ac32c70fa897c1a448f0583732ad45e615e66b Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 18:33:56 +0000 Subject: [PATCH 09/11] Fix setindex!! for VarInfo --- src/varinfo.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index b2e116f3b..523cbe9ec 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1585,9 +1585,15 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) + function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) - setindex!(vi, val, vn) - return vi + md = setindex!!(getmetadata(vi, vn), val, vn) + return VarInfo(md, vi.accs) +end + +function BangBang.setindex!!(vi::NTVarInfo, val, vn::VarName) + submd = setindex!!(getmetadata(vi, vn), val, vn) + return Accessors.@set vi.metadata[getsym(vn)] = submd end @inline function findvns(vi, f_vns) From 10035d74dd3a6205da72052927714f257848c076 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 18:38:44 +0000 Subject: [PATCH 10/11] Renamed UntypedVarInfo to UntypedLegacyVarInfo, and make UntypedVarInfo be UntypedVectorVarInfo --- src/varinfo.jl | 33 +++++++++++++++++---------------- test/test_util.jl | 2 +- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 523cbe9ec..3a08b8896 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -90,7 +90,7 @@ the left-hand side of tilde statements. For example, `x[1]` and `x[2]` both have the same symbol `x`. Several type aliases are provided for these forms of VarInfos: -- `VarInfo{<:Metadata}` is `UntypedVarInfo` +- `VarInfo{<:Metadata}` is `UntypedLegacyVarInfo` - `VarInfo{<:VarNamedVector}` is `UntypedVectorVarInfo` - `VarInfo{<:NamedTuple}` is `NTVarInfo` @@ -143,7 +143,7 @@ function VarInfo(model::Model, init_strategy::AbstractInitStrategy=InitFromPrior end const UntypedVectorVarInfo = VarInfo{<:VarNamedVector} -const UntypedVarInfo = VarInfo{<:Metadata} +const UntypedLegacyVarInfo = VarInfo{<:Metadata} # TODO: NTVarInfo carries no information about the type of the actual metadata # i.e. the elements of the NamedTuple. It could be Metadata or it could be # VarNamedVector. @@ -154,6 +154,7 @@ const NTVarInfo = VarInfo{<:NamedTuple} const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ VarInfo{Tmeta},ThreadSafeVarInfo{<:VarInfo{Tmeta}} } +const UntypedVarInfo = UntypedVectorVarInfo function Base.:(==)(vi1::VarInfo, vi2::VarInfo) return (vi1.metadata == vi2.metadata && vi1.accs == vi2.accs) @@ -231,7 +232,7 @@ function untyped_legacy_varinfo( end """ - typed_legacy_varinfo(vi::UntypedVarInfo) + typed_legacy_varinfo(vi::UntypedLegacyVarInfo) This function finds all the unique `sym`s from the instances of `VarName{sym}` found in `vi.metadata.vns`. It then extracts the metadata associated with each symbol from the @@ -239,7 +240,7 @@ global `vi.metadata` field. Finally, a new `VarInfo` is created with a new `meta a `NamedTuple` mapping from symbols to type-stable `Metadata` instances, one for each symbol. """ -function typed_legacy_varinfo(vi::UntypedVarInfo) +function typed_legacy_varinfo(vi::UntypedLegacyVarInfo) meta = vi.metadata new_metas = Metadata[] # Symbols of all instances of `VarName{sym}` in `vi.vns` @@ -324,7 +325,7 @@ Return a VarInfo object for the given `model`, which has just a single - `model::Model`: The model for which to create the varinfo object - `init_strategy::AbstractInitStrategy`: How the values are to be initialised. Defaults to `InitFromPrior()`. """ -function untyped_vector_varinfo(vi::UntypedVarInfo) +function untyped_vector_varinfo(vi::UntypedLegacyVarInfo) md = metadata_to_varnamedvector(vi.metadata) return VarInfo(md, copy(vi.accs)) end @@ -644,11 +645,11 @@ end const VarView = Union{Int,UnitRange,Vector{Int}} """ - setval!(vi::UntypedVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) + setval!(vi::UntypedLegacyVarInfo, val, vview::Union{Int, UnitRange, Vector{Int}}) Set the value of `vi.vals[vview]` to `val`. """ -setval!(vi::UntypedVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val +setval!(vi::UntypedLegacyVarInfo, val, vview::VarView) = vi.metadata.vals[vview] = val """ getmetadata(vi::VarInfo, vn::VarName) @@ -843,10 +844,10 @@ set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, Returns a tuple of the unique symbols of random variables in `vi`. """ -syms(vi::UntypedVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols +syms(vi::UntypedLegacyVarInfo) = Tuple(unique!(map(getsym, vi.metadata.vns))) # get all symbols syms(vi::NTVarInfo) = keys(vi.metadata) -_getidcs(vi::UntypedVarInfo) = 1:length(vi.metadata.idcs) +_getidcs(vi::UntypedLegacyVarInfo) = 1:length(vi.metadata.idcs) _getidcs(vi::NTVarInfo) = _getidcs(vi.metadata) @generated function _getidcs(metadata::NamedTuple{names}) where {names} @@ -967,7 +968,7 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, vns, model) end -function _link!!(vi::UntypedVarInfo, vns) +function _link!!(vi::UntypedLegacyVarInfo, vns) # TODO: Change to a lazy iterator over `vns` if ~is_transformed(vi, vns[1]) for vn in vns @@ -1081,7 +1082,7 @@ function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) return maybe_invlink_before_eval!!(t, vi, model) end -function _invlink!!(vi::UntypedVarInfo, vns) +function _invlink!!(vi::UntypedLegacyVarInfo, vns) if is_transformed(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) @@ -1495,7 +1496,7 @@ function _invlink_metadata!!( end # TODO(mhauru) The treatment of the case when some variables are transformed and others are -# not should be revised. It used to be the case that for UntypedVarInfo `is_transformed` +# not should be revised. It used to be the case that for UntypedLegacyVarInfo `is_transformed` # returned whether the first variable was linked. For NTVarInfo we did an OR over the first # variables under each symbol. We now more consistently use OR, but I'm not convinced this # is really the right thing to do. @@ -1618,7 +1619,7 @@ function Base.haskey(vi::NTVarInfo, vn::VarName) return any(md_haskey) end -function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) +function Base.show(io::IO, ::MIME"text/plain", vi::UntypedLegacyVarInfo) lines = Tuple{String,Any}[ ("VarNames", vi.metadata.vns), ("Range", vi.metadata.ranges), @@ -1673,7 +1674,7 @@ function _show_varnames(io::IO, vi) end end -function Base.show(io::IO, vi::UntypedVarInfo) +function Base.show(io::IO, vi::UntypedLegacyVarInfo) print(io, "VarInfo (") _show_varnames(io, vi) print(io, "; accumulators: ") @@ -1845,11 +1846,11 @@ end values_as(vi::VarInfo) = vi.metadata values_as(vi::VarInfo, ::Type{Vector}) = copy(getindex_internal(vi, Colon())) -function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) +function values_as(vi::UntypedLegacyVarInfo, ::Type{NamedTuple}) iter = values_from_metadata(vi.metadata) return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) end -function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict} +function values_as(vi::UntypedLegacyVarInfo, ::Type{D}) where {D<:AbstractDict} return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata)) end diff --git a/test/test_util.jl b/test/test_util.jl index 164751c7b..3a7ea0028 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -23,7 +23,7 @@ function short_varinfo_name(vi::DynamicPPL.NTVarInfo) "TypedVarInfo" end end -short_varinfo_name(::DynamicPPL.UntypedVarInfo) = "UntypedVarInfo" +short_varinfo_name(::DynamicPPL.UntypedLegacyVarInfo) = "UntypedLegacyVarInfo" short_varinfo_name(::DynamicPPL.UntypedVectorVarInfo) = "UntypedVectorVarInfo" function short_varinfo_name(::SimpleVarInfo{<:NamedTuple,<:Ref}) return "SimpleVarInfo{<:NamedTuple,<:Ref}" From 852c971d4a69f5305a3081d53ecead84b5f9d9ca Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 31 Oct 2025 18:39:31 +0000 Subject: [PATCH 11/11] Fix a JET test --- test/ext/DynamicPPLJETExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl index 8ed29e0c7..245367047 100644 --- a/test/ext/DynamicPPLJETExt.jl +++ b/test/ext/DynamicPPLJETExt.jl @@ -40,7 +40,7 @@ end end @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa - DynamicPPL.UntypedVarInfo + DynamicPPL.NTVarInfo # In this model, the type error occurs in the user code rather than in DynamicPPL. @model function demo5()