@@ -358,7 +358,7 @@ function typed_vector_varinfo(
358358end
359359
360360function make_leaf_metadata ((r, dist), optic)
361- md = Metadata (Float64)
361+ md = Metadata (Float64, VarName{ :_ } )
362362 vn = VarName {:_} (optic)
363363 push! (md, vn, r, dist)
364364 return md
@@ -439,13 +439,13 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
439439
440440Construct an empty type unstable instance of `Metadata`.
441441"""
442- function Metadata (eltype= Real)
442+ function Metadata (eltype= Real, vntype = VarName )
443443 vals = Vector {eltype} ()
444444 is_transformed = BitVector ()
445445
446446 return Metadata (
447- Dict {VarName ,Int} (),
448- Vector {VarName } (),
447+ Dict {vntype ,Int} (),
448+ Vector {vntype } (),
449449 Vector {UnitRange{Int}} (),
450450 vals,
451451 Vector {Distribution} (),
@@ -814,7 +814,7 @@ The values may or may not be transformed to Euclidean space.
814814setval! (vi:: VarInfo , val, vn:: VarName ) = setval! (getmetadata (vi, vn), val, vn)
815815function setval! (vi:: TupleVarInfo , val, vn:: VarName )
816816 main_vn, optic = split_trailing_index (vn)
817- return setval! (getindex (vi. metadata, main_vn), VarName {:_} (optic))
817+ return setval! (getindex (vi. metadata, main_vn), val, VarName {:_} (optic))
818818end
819819function setval! (md:: Metadata , val:: AbstractVector , vn:: VarName )
820820 return md. vals[getrange (md, vn)] = val
@@ -1914,3 +1914,80 @@ end
19141914function from_linked_internal_transform (:: VarNamedVector , :: VarName , dist)
19151915 return from_linked_vec_transform (dist)
19161916end
1917+
1918+ function link (vi:: TupleVarInfo , model:: Model )
1919+ metadata = map (value -> link (value, model), vi. metadata)
1920+ return VarInfo (metadata, vi. accs)
1921+ end
1922+
1923+ function link (metadata:: Metadata , model:: Model )
1924+ vns = metadata. vns
1925+ cumulative_logjac = zero (LogProbType)
1926+
1927+ # Construct the new transformed values, and keep track of their lengths.
1928+ vals_new = map (vns) do vn
1929+ # Return early if we're already in unconstrained space.
1930+ # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
1931+ if is_transformed (metadata, vn)
1932+ return metadata. vals[getrange (metadata, vn)]
1933+ end
1934+
1935+ # Transform to constrained space.
1936+ x = getindex_internal (metadata, vn)
1937+ dist = getdist (metadata, vn)
1938+ f_from_internal = from_internal_transform (metadata, vn, dist)
1939+ f_to_linked_internal = inverse (from_linked_internal_transform (metadata, vn, dist))
1940+ f = f_to_linked_internal ∘ f_from_internal
1941+ y, logjac = with_logabsdet_jacobian (f, x)
1942+ # Vectorize value.
1943+ yvec = tovec (y)
1944+ # Accumulate the log-abs-det jacobian correction.
1945+ cumulative_logjac += logjac
1946+ # Return the vectorized transformed value.
1947+ return yvec
1948+ end
1949+
1950+ # Determine new ranges.
1951+ ranges_new = similar (metadata. ranges)
1952+ offset = 0
1953+ for (i, v) in enumerate (vals_new)
1954+ r_start, r_end = offset + 1 , length (v) + offset
1955+ offset = r_end
1956+ ranges_new[i] = r_start: r_end
1957+ end
1958+
1959+ # Now we just create a new metadata with the new `vals` and `ranges`.
1960+ return Metadata (
1961+ metadata. idcs,
1962+ metadata. vns,
1963+ ranges_new,
1964+ reduce (vcat, vals_new),
1965+ metadata. dists,
1966+ BitVector (fill (true , length (metadata. vns))),
1967+ )
1968+ end
1969+
1970+ function Base. haskey (vi:: TupleVarInfo , vn:: VarName )
1971+ # TODO (mhauru) Fix this to account for the index.
1972+ main_vn, optic = split_trailing_index (vn)
1973+ haskey (vi. metadata, main_vn) || return false
1974+ value = getindex (vi. metadata, main_vn)
1975+ if value isa Metadata
1976+ return haskey (value, VarName {:_} (optic))
1977+ else
1978+ error (" TODO(mhauru) Implement me" )
1979+ end
1980+ end
1981+
1982+ function BangBang. setindex!! (metadata:: Metadata , val, optic)
1983+ return setindex!! (metadata, val, VarName {:_} (optic))
1984+ end
1985+
1986+ function BangBang. setindex!! (metadata:: Metadata , (r, dist), vn:: VarName )
1987+ if haskey (metadata, vn)
1988+ setval! (metadata, r, vn)
1989+ else
1990+ push! (metadata, vn, r, dist)
1991+ end
1992+ return metadata
1993+ end
0 commit comments