Skip to content

Commit cf2d323

Browse files
committed
Various fixes
1 parent 2aa9b5e commit cf2d323

File tree

3 files changed

+92
-14
lines changed

3 files changed

+92
-14
lines changed

src/contexts/transformation.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ function tilde_assume!!(
2121
# vi[vn, right] always provides the value in unlinked space.
2222
x = vi[vn, right]
2323

24-
if is_transformed(vi, vn)
25-
isinverse || @warn "Trying to link an already transformed variable ($vn)"
26-
else
27-
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
28-
end
24+
# TODO(mhauru) Warnings disabled for benchmarking purposes
25+
# if is_transformed(vi, vn)
26+
# isinverse || @warn "Trying to link an already transformed variable ($vn)"
27+
# else
28+
# isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
29+
# end
2930

3031
transform = isinverse ? identity : link_transform(right)
3132
y, logjac = with_logabsdet_jacobian(transform, x)

src/varinfo.jl

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ function typed_vector_varinfo(
358358
end
359359

360360
function make_leaf_metadata((r, dist), optic)
361-
md = Metadata(Float64)
361+
md = Metadata(Float64, VarName{:_})
362362
vn = VarName{:_}(optic)
363363
push!(md, vn, r, dist)
364364
return md
@@ -439,13 +439,13 @@ unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x)
439439
440440
Construct an empty type unstable instance of `Metadata`.
441441
"""
442-
function Metadata(eltype=Real)
442+
function Metadata(eltype=Real, vntype=VarName)
443443
vals = Vector{eltype}()
444444
is_transformed = BitVector()
445445

446446
return Metadata(
447-
Dict{VarName,Int}(),
448-
Vector{VarName}(),
447+
Dict{vntype,Int}(),
448+
Vector{vntype}(),
449449
Vector{UnitRange{Int}}(),
450450
vals,
451451
Vector{Distribution}(),
@@ -814,7 +814,7 @@ The values may or may not be transformed to Euclidean space.
814814
setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn)
815815
function setval!(vi::TupleVarInfo, val, vn::VarName)
816816
main_vn, optic = split_trailing_index(vn)
817-
return setval!(getindex(vi.metadata, main_vn), VarName{:_}(optic))
817+
return setval!(getindex(vi.metadata, main_vn), val, VarName{:_}(optic))
818818
end
819819
function setval!(md::Metadata, val::AbstractVector, vn::VarName)
820820
return md.vals[getrange(md, vn)] = val
@@ -1914,3 +1914,80 @@ end
19141914
function from_linked_internal_transform(::VarNamedVector, ::VarName, dist)
19151915
return from_linked_vec_transform(dist)
19161916
end
1917+
1918+
function link(vi::TupleVarInfo, model::Model)
1919+
metadata = map(value -> link(value, model), vi.metadata)
1920+
return VarInfo(metadata, vi.accs)
1921+
end
1922+
1923+
function link(metadata::Metadata, model::Model)
1924+
vns = metadata.vns
1925+
cumulative_logjac = zero(LogProbType)
1926+
1927+
# Construct the new transformed values, and keep track of their lengths.
1928+
vals_new = map(vns) do vn
1929+
# Return early if we're already in unconstrained space.
1930+
# HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check.
1931+
if is_transformed(metadata, vn)
1932+
return metadata.vals[getrange(metadata, vn)]
1933+
end
1934+
1935+
# Transform to constrained space.
1936+
x = getindex_internal(metadata, vn)
1937+
dist = getdist(metadata, vn)
1938+
f_from_internal = from_internal_transform(metadata, vn, dist)
1939+
f_to_linked_internal = inverse(from_linked_internal_transform(metadata, vn, dist))
1940+
f = f_to_linked_internal f_from_internal
1941+
y, logjac = with_logabsdet_jacobian(f, x)
1942+
# Vectorize value.
1943+
yvec = tovec(y)
1944+
# Accumulate the log-abs-det jacobian correction.
1945+
cumulative_logjac += logjac
1946+
# Return the vectorized transformed value.
1947+
return yvec
1948+
end
1949+
1950+
# Determine new ranges.
1951+
ranges_new = similar(metadata.ranges)
1952+
offset = 0
1953+
for (i, v) in enumerate(vals_new)
1954+
r_start, r_end = offset + 1, length(v) + offset
1955+
offset = r_end
1956+
ranges_new[i] = r_start:r_end
1957+
end
1958+
1959+
# Now we just create a new metadata with the new `vals` and `ranges`.
1960+
return Metadata(
1961+
metadata.idcs,
1962+
metadata.vns,
1963+
ranges_new,
1964+
reduce(vcat, vals_new),
1965+
metadata.dists,
1966+
BitVector(fill(true, length(metadata.vns))),
1967+
)
1968+
end
1969+
1970+
function Base.haskey(vi::TupleVarInfo, vn::VarName)
1971+
# TODO(mhauru) Fix this to account for the index.
1972+
main_vn, optic = split_trailing_index(vn)
1973+
haskey(vi.metadata, main_vn) || return false
1974+
value = getindex(vi.metadata, main_vn)
1975+
if value isa Metadata
1976+
return haskey(value, VarName{:_}(optic))
1977+
else
1978+
error("TODO(mhauru) Implement me")
1979+
end
1980+
end
1981+
1982+
function BangBang.setindex!!(metadata::Metadata, val, optic)
1983+
return setindex!!(metadata, val, VarName{:_}(optic))
1984+
end
1985+
1986+
function BangBang.setindex!!(metadata::Metadata, (r, dist), vn::VarName)
1987+
if haskey(metadata, vn)
1988+
setval!(metadata, r, vn)
1989+
else
1990+
push!(metadata, vn, r, dist)
1991+
end
1992+
return metadata
1993+
end

src/varname.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
function remove_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
4646
return if Optic === typeof(identity)
4747
vn
48-
elseif Optic isa IndexLens
48+
elseif Optic <: Accessors.IndexLens
4949
VarName{sym}()
5050
else
5151
prefix(remove_trailing_index(unprefix(vn, sym)), sym)
@@ -55,10 +55,10 @@ end
5555
function split_trailing_index(vn::VarName{sym,Optic}) where {sym,Optic}
5656
return if Optic === typeof(identity)
5757
(vn, identity)
58-
elseif Optic isa IndexLens
59-
(VarName{sym}(), Optic.index)
58+
elseif Optic <: Accessors.IndexLens
59+
(VarName{sym}(), getoptic(vn))
6060
else
61-
(prefix, index) = split_trailing_index(unprefix(vn, sym))
61+
(prefix, index) = split_trailing_index(unprefix(vn, VarName{sym}()))
6262
(prefix(prefix, sym), index)
6363
end
6464
end

0 commit comments

Comments
 (0)