|
25 | 25 | @test length(chains) == N |
26 | 26 |
|
27 | 27 | # `m` is Gaussian, i.e. no transformation is used, so it |
28 | | - # should have a mean equal to its prior, i.e. 2. |
29 | | - @test mean(vi[@varname(m)] for vi in chains) ≈ 2 atol = 0.1 |
| 28 | + # will be drawn from U[-2, 2] and its mean should be 0. |
| 29 | + @test mean(vi[@varname(m)] for vi in chains) ≈ 0.0 atol = 0.1 |
30 | 30 |
|
31 | 31 | # Expected value of ``exp(X)`` where ``X ~ U[-2, 2]`` is ≈ 1.8. |
32 | 32 | @test mean(vi[@varname(s)] for vi in chains) ≈ 1.8 atol = 0.1 |
|
81 | 81 | model = coinflip() |
82 | 82 | sampler = Sampler(alg) |
83 | 83 | lptrue = logpdf(Binomial(25, 0.2), 10) |
84 | | - let inits = (; p=0.2) |
85 | | - chain = sample( |
86 | | - model, sampler, 1; initial_params=ParamsInit(inits), progress=false |
87 | | - ) |
| 84 | + let inits = ParamsInit((; p=0.2)) |
| 85 | + chain = sample(model, sampler, 1; initial_params=inits, progress=false) |
88 | 86 | @test chain[1].metadata.p.vals == [0.2] |
89 | 87 | @test getlogjoint(chain[1]) == lptrue |
90 | 88 |
|
|
111 | 109 | end |
112 | 110 | model = twovars() |
113 | 111 | lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) |
114 | | - for inits in ([4, -1], (; s=4, m=-1)) |
115 | | - chain = sample( |
116 | | - model, sampler, 1; initial_params=ParamsInit(inits), progress=false |
117 | | - ) |
| 112 | + let inits = ParamsInit((; s=4, m=-1)) |
| 113 | + chain = sample(model, sampler, 1; initial_params=inits, progress=false) |
118 | 114 | @test chain[1].metadata.s.vals == [4] |
119 | 115 | @test chain[1].metadata.m.vals == [-1] |
120 | 116 | @test getlogjoint(chain[1]) == lptrue |
|
126 | 122 | MCMCThreads(), |
127 | 123 | 1, |
128 | 124 | 10; |
129 | | - initial_params=fill(ParamsInit(inits), 10), |
| 125 | + initial_params=fill(inits, 10), |
130 | 126 | progress=false, |
131 | 127 | ) |
132 | 128 | for c in chains |
|
137 | 133 | end |
138 | 134 |
|
139 | 135 | # set only m = -1 |
140 | | - for inits in ((; s=missing, m=-1), (; m=-1)) |
141 | | - chain = sample( |
142 | | - model, sampler, 1; initial_params=ParamsInit(inits), progress=false |
143 | | - ) |
| 136 | + for inits in (ParamsInit((; s=missing, m=-1)), ParamsInit((; m=-1))) |
| 137 | + chain = sample(model, sampler, 1; initial_params=inits, progress=false) |
144 | 138 | @test !ismissing(chain[1].metadata.s.vals[1]) |
145 | 139 | @test chain[1].metadata.m.vals == [-1] |
146 | 140 |
|
|
0 commit comments