1- # assume
2- function tilde_assume!! (context:: AbstractContext , right:: Distribution , vn, vi)
1+ """
2+ DynamicPPL.tilde_assume!!(
3+ context::AbstractContext,
4+ right::Distribution,
5+ vn::VarName,
6+ vi::AbstractVarInfo
7+ )
8+
9+ Handle assumed variables, i.e. anything which is not observed (see
10+ [`tilde_observe!!`](@ref)). Accumulate the associated log probability, and return the
11+ sampled value and updated `vi`.
12+
13+ `vn` is the VarName on the left-hand side of the tilde statement.
14+ """
15+ function tilde_assume!! (
16+ context:: AbstractContext , right:: Distribution , vn:: VarName , vi:: AbstractVarInfo
17+ )
318 return tilde_assume!! (childcontext (context), right, vn, vi)
419end
5- function tilde_assume!! (:: DefaultContext , right:: Distribution , vn, vi)
20+ function tilde_assume!! (
21+ :: DefaultContext , right:: Distribution , vn:: VarName , vi:: AbstractVarInfo
22+ )
623 y = getindex_internal (vi, vn)
724 f = from_maybe_linked_internal_transform (vi, vn, right)
825 x, inv_logjac = with_logabsdet_jacobian (f, y)
926 vi = accumulate_assume!! (vi, x, - inv_logjac, vn, right)
1027 return x, vi
1128end
12- function tilde_assume!! (context:: PrefixContext , right:: Distribution , vn, vi)
29+ function tilde_assume!! (
30+ context:: PrefixContext , right:: Distribution , vn:: VarName , vi:: AbstractVarInfo
31+ )
1332 # Note that we can't use something like this here:
1433 # new_vn = prefix(context, vn)
1534 # return tilde_assume!!(childcontext(context), right, new_vn, vi)
@@ -22,24 +41,62 @@ function tilde_assume!!(context::PrefixContext, right::Distribution, vn, vi)
2241 new_vn, new_context = prefix_and_strip_contexts (context, vn)
2342 return tilde_assume!! (new_context, right, new_vn, vi)
2443end
25-
2644"""
27- tilde_assume!!(context, right, vn, vi)
45+ DynamicPPL.tilde_assume!!(
46+ context::AbstractContext,
47+ right::DynamicPPL.Submodel,
48+ vn::VarName,
49+ vi::AbstractVarInfo
50+ )
2851
29- Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
30- accumulate the log probability, and return the sampled value and updated `vi`.
52+ Evaluate the submodel with the given context.
3153"""
32- function tilde_assume!! (context, right:: DynamicPPL.Submodel , vn, vi)
54+ function tilde_assume!! (
55+ context:: AbstractContext , right:: DynamicPPL.Submodel , vn:: VarName , vi:: AbstractVarInfo
56+ )
3357 return _evaluate!! (right, vi, context, vn)
3458end
3559
36- # observe
37- function tilde_observe!! (context:: AbstractContext , right, left, vn, vi)
60+ """
61+ tilde_observe!!(
62+ context::AbstractContext,
63+ right::Distribution,
64+ left,
65+ vn::Union{VarName, Nothing},
66+ vi::AbstractVarInfo
67+ )
68+
69+ This function handles observed variables, which may be:
70+
71+ - literals on the left-hand side, e.g., `3.0 ~ Normal()`
72+ - a model input, e.g. `x ~ Normal()` in a model `@model f(x) ... end`
73+ - a conditioned or fixed variable, e.g. `x ~ Normal()` in a model `model | (; x = 3.0)`.
74+
75+ The relevant log-probability associated with the observation is computed and accumulated in
76+ the VarInfo object `vi` (except for fixed variables, which do not contribute to the
77+ log-probability).
78+
79+ `left` is the actual value that the left-hand side evaluates to. `vn` is the VarName on the
80+ left-hand side, or `nothing` if the left-hand side is a literal value.
81+
82+ Observations of submodels are not yet supported in DynamicPPL.
83+ """
84+ function tilde_observe!! (
85+ context:: AbstractContext ,
86+ right:: Distribution ,
87+ left,
88+ vn:: Union{VarName,Nothing} ,
89+ vi:: AbstractVarInfo ,
90+ )
3891 return tilde_observe!! (childcontext (context), right, left, vn, vi)
3992end
40-
41- # `PrefixContext`
42- function tilde_observe!! (context:: PrefixContext , right, left, vn, vi)
93+ function tilde_observe!! (
94+ context:: PrefixContext ,
95+ right:: Distribution ,
96+ left,
97+ vn:: Union{VarName,Nothing} ,
98+ vi:: AbstractVarInfo ,
99+ )
43100 # In the observe case, unlike assume, `vn` may be `nothing` if the LHS is a literal
44101 # value. For the need for prefix_and_strip_contexts rather than just prefix, see the
45102 # comment in `tilde_assume!!`.
@@ -50,21 +107,22 @@ function tilde_observe!!(context::PrefixContext, right, left, vn, vi)
50107 end
51108 return tilde_observe!! (new_context, right, left, new_vn, vi)
52109end
53-
54- """
55- tilde_observe!!(context, right, left, vn, vi)
56-
57- Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs),
58- accumulate the log probability, and return the observed value and updated `vi`.
59-
60- Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name
61- and indices; if needed, these can be accessed through this function, though.
62- """
63- function tilde_observe!! (:: DefaultContext , right:: Distribution , left, vn, vi)
110+ function tilde_observe!! (
111+ :: DefaultContext ,
112+ right:: Distribution ,
113+ left,
114+ vn:: Union{VarName,Nothing} ,
115+ vi:: AbstractVarInfo ,
116+ )
64117 vi = accumulate_observe!! (vi, right, left, vn)
65118 return left, vi
66119end
67-
68- function tilde_observe!! (:: DefaultContext , :: DynamicPPL.Submodel , left, vn, vi)
120+ function tilde_observe!! (
121+ :: AbstractContext ,
122+ :: DynamicPPL.Submodel ,
123+ left,
124+ vn:: Union{VarName,Nothing} ,
125+ :: AbstractVarInfo ,
126+ )
69127 throw (ArgumentError (" `x ~ to_submodel(...)` is not supported when `x` is observed" ))
70128end
0 commit comments