Skip to content

Commit 77af4eb

Browse files
committed
Merge branch 'main' into breaking
2 parents 262d732 + 2020741 commit 77af4eb

File tree

8 files changed

+70
-33
lines changed

8 files changed

+70
-33
lines changed

.github/workflows/IntegrationTest.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ jobs:
4343
4444
- name: Load this and run the downstream tests
4545
shell: julia --color=yes --project=downstream {0}
46-
# Don't test Turing.jl on 1.12.0, it's broken because of Libtask
47-
if: ${{ !(steps.julia-version.outputs.julia == '1.12.0' && matrix.package.repo == 'Turing.jl') }}
4846
run: |
4947
using Pkg
5048
try

HISTORY.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,21 @@
22

33
## 0.39.0
44

5+
Removed the method `returned(::Model, values, keys)`; please use `returned(::Model, ::AbstractDict{<:VarName})` instead.
6+
7+
## 0.38.3
8+
9+
Add an implementation of `returned(::Model, ::AbstractDict{<:VarName})`.
10+
Please note we generally recommend using Dict, as NamedTuples cannot correctly represent variables with indices / fields on the left-hand side of tildes, like `x[1]` or `x.a`.
11+
12+
The generic method `returned(::Model, values, keys)` is deprecated and will be removed in the next minor version.
13+
14+
## 0.38.2
15+
16+
Added a compatibility entry for JET@0.11.
17+
18+
> > > > > > > main
19+
520
## 0.38.1
621

722
Added `from_linked_vec_transform` and `from_vec_transform` methods for `ProductNamedTupleDistribution`.

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ DocStringExtensions = "0.9"
6262
EnzymeCore = "0.6 - 0.8"
6363
ForwardDiff = "0.10.12, 1"
6464
InteractiveUtils = "1"
65-
JET = "0.9, 0.10"
65+
JET = "0.9, 0.10, 0.11"
6666
KernelAbstractions = "0.9.33"
6767
LinearAlgebra = "1.6"
6868
LogDensityProblems = "2"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ DocumenterMermaid = "0.1, 0.2"
2222
DynamicPPL = "0.38"
2323
FillArrays = "0.13, 1"
2424
ForwardDiff = "0.10, 1"
25-
JET = "0.9, 0.10"
25+
JET = "0.9, 0.10, 0.11"
2626
LogDensityProblems = "2"
2727
MarginalLogDensities = "0.4"
2828
MCMCChains = "5, 6, 7"

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,11 @@ It is possible to manually increase (or decrease) the accumulated log likelihood
176176
@addlogprob!
177177
```
178178

179-
Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples) or a single sample represented as a `NamedTuple`.
179+
Return values of the model function can be obtained with [`returned(model, sample)`](@ref), where `sample` is either a `MCMCChains.Chains` object (which represents a collection of samples), or a single sample represented as a `NamedTuple` or a dictionary of VarNames.
180180

181181
```@docs
182182
returned(::DynamicPPL.Model, ::MCMCChains.Chains)
183-
returned(::DynamicPPL.Model, ::NamedTuple)
183+
returned(::DynamicPPL.Model, ::Union{NamedTuple,AbstractDict{<:VarName}})
184184
```
185185

186186
For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using

src/model.jl

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,44 +1103,36 @@ function predict end
11031103

11041104
"""
11051105
returned(model::Model, parameters::NamedTuple)
1106-
returned(model::Model, values, keys)
1107-
returned(model::Model, values, keys)
1106+
returned(model::Model, parameters::AbstractDict{<:VarName})
11081107
11091108
Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.
11101109
1111-
If a `NamedTuple` is given, `keys=keys(parameters)` and `values=values(parameters)`.
1112-
11131110
# Example
11141111
```jldoctest
11151112
julia> using DynamicPPL, Distributions
11161113
1117-
julia> @model function demo(xs)
1118-
s ~ InverseGamma(2, 3)
1119-
m_shifted ~ Normal(10, √s)
1120-
m = m_shifted - 10
1121-
for i in eachindex(xs)
1122-
xs[i] ~ Normal(m, √s)
1123-
end
1124-
return (m, )
1114+
julia> @model function demo()
1115+
m ~ Normal()
1116+
return (mp1 = m + 1,)
11251117
end
11261118
demo (generic function with 2 methods)
11271119
1128-
julia> model = demo(randn(10));
1129-
1130-
julia> parameters = (; s = 1.0, m_shifted=10.0);
1120+
julia> model = demo();
11311121
1132-
julia> returned(model, parameters)
1133-
(0.0,)
1122+
julia> returned(model, (; m = 1.0))
1123+
(mp1 = 2.0,)
11341124
1135-
julia> returned(model, values(parameters), keys(parameters))
1136-
(0.0,)
1125+
julia> returned(model, Dict{VarName,Float64}(@varname(m) => 2.0))
1126+
(mp1 = 3.0,)
11371127
```
11381128
"""
1139-
function returned(model::Model, parameters::NamedTuple)
1140-
fixed_model = fix(model, parameters)
1141-
return fixed_model()
1142-
end
1143-
1144-
function returned(model::Model, values, keys)
1145-
return returned(model, NamedTuple{keys}(values))
1129+
function returned(model::Model, parameters::Union{NamedTuple,AbstractDict{<:VarName}})
1130+
vi = DynamicPPL.setaccs!!(VarInfo(), ())
1131+
# Note: we can't use `fix(model, parameters)` because
1132+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1097
1133+
# Use `nothing` as the fallback to ensure that any missing parameters cause an error
1134+
ctx = InitContext(Random.default_rng(), InitFromParams(parameters, nothing))
1135+
new_model = setleafcontext(model, ctx)
1136+
# We can't use new_model() because that overwrites it with an InitContext of its own.
1137+
return first(evaluate!!(new_model, vi))
11461138
end

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Distributions = "0.25"
4141
DistributionsAD = "0.6.3"
4242
Documenter = "1"
4343
ForwardDiff = "0.10.12, 1"
44-
JET = "0.9, 0.10"
44+
JET = "0.9, 0.10, 0.11"
4545
LogDensityProblems = "2"
4646
MCMCChains = "7.2.1"
4747
MacroTools = "0.5.6"

test/model.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,38 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
321321
end
322322
end
323323

324+
@testset "returned() on NamedTuple / Dict" begin
325+
@model function demo_returned()
326+
a ~ Normal()
327+
b ~ Normal()
328+
return (asq=a^2, bsq=b^2)
329+
end
330+
model = demo_returned()
331+
332+
@testset "NamedTuple" begin
333+
params = (a=1.0, b=2.0)
334+
results = returned(model, params)
335+
@test results.asq == params.a^2
336+
@test results.bsq == params.b^2
337+
# `returned` should error when not all parameters are provided
338+
@test_throws ErrorException returned(model, (; a=1.0))
339+
@test_throws ErrorException returned(model, (a=1.0, b=missing))
340+
end
341+
@testset "Dict" begin
342+
params = Dict{VarName,Float64}(@varname(a) => 1.0, @varname(b) => 2.0)
343+
results = returned(model, params)
344+
@test results.asq == params[@varname(a)]^2
345+
@test results.bsq == params[@varname(b)]^2
346+
# `returned` should error when not all parameters are provided
347+
@test_throws ErrorException returned(
348+
model, Dict{VarName,Float64}(@varname(a) => 1.0)
349+
)
350+
@test_throws ErrorException returned(
351+
model, Dict{VarName,Any}(@varname(a) => 1.0, @varname(b) => missing)
352+
)
353+
end
354+
end
355+
324356
@testset "returned() on `LKJCholesky`" begin
325357
n = 10
326358
d = 2

0 commit comments

Comments
 (0)