Commit f8a91a1

Move external sampler interface to AbstractMCMC (#2704)
Following on from:

- TuringLang/AbstractMCMC.jl#182, adding the interface function `AbstractMCMC.getstats`
- TuringLang/AdvancedHMC.jl#471 and TuringLang/AdvancedMH.jl#119, implementing it for AdvancedHMC and AdvancedMH

this PR changes Turing's external sampler interface to exclusively use AbstractMCMC functions.

**With this PR, anyone who defines an external sampler will only need to depend on AbstractMCMC (which they presumably already do, because it is a sampler) and LogDensityProblems (which is already a dep of AbstractMCMC).** No need for a Turing extension.

To be precise, it makes the following changes:

- Previously one had to define `Turing.Inference.getparams`; now one has to define `AbstractMCMC.getparams`.
- Previously there was no way to include sampler stats in the resulting chain; now one can define `AbstractMCMC.getstats`.
- The default for `Turing.Inference.isgibbscomponent` is changed to `true`, so that external sampler packages don't need to override it (unless absolutely necessary).

As an example implementation, this PR contains a test mock (note how it doesn't require a Turing dep): https://github.com/TuringLang/Turing.jl/blob/06752c41a97cd0dc37bb8700ab7a2d06f50f4f76/test/mcmc/external_sampler.jl#L20-L74
1 parent be007f3 commit f8a91a1

File tree

7 files changed (+106, -97 lines)


HISTORY.md

Lines changed: 18 additions & 0 deletions
```diff
@@ -1,5 +1,23 @@
 # 0.42.0
 
+## External sampler interface
+
+The interface for defining an external sampler has been reworked.
+In general, implementations of external samplers should now no longer need to depend on Turing.
+This is because the interface functions required have been shifted upstream to AbstractMCMC.jl.
+
+In particular, you now only need to define the following functions:
+
+- `AbstractMCMC.step(rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, ::MySampler; kwargs...)` (and also a method with `state`, and the corresponding `step_warmup` methods if needed)
+- `AbstractMCMC.getparams(::MySamplerState) -> Vector{<:Real}`
+- `AbstractMCMC.getstats(::MySamplerState) -> NamedTuple`
+- `AbstractMCMC.requires_unconstrained_space(::MySampler) -> Bool` (default `true`)
+
+This means that you only need to depend on AbstractMCMC.jl.
+As long as the above functions are defined correctly, Turing will be able to use your external sampler.
+
+The `Turing.Inference.isgibbscomponent(::MySampler)` interface function still exists, but in this version the default has been changed to `true`, so you should not need to overload this.
+
 # 0.41.0
 
 ## DynamicPPL 0.38
```
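The four functions listed above can be sketched as follows. This mirrors the test mock linked in the commit message; `MySampler` and `MyState` are placeholder names, and this "sampler" trivially stays at its initial point, so treat it as a sketch of the interface shape rather than a working MCMC method.

```julia
using AbstractMCMC, LogDensityProblems, Random

# Hypothetical sampler and state types (placeholder names).
struct MySampler <: AbstractMCMC.AbstractSampler end
struct MyState{V<:AbstractVector}
    params::V
end

# First step: receives an AbstractMCMC.LogDensityModel, which wraps a
# LogDensityProblems-compatible object in its `logdensity` field.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler;
    initial_params=nothing,
    kwargs...,
)
    params = if initial_params === nothing
        randn(rng, LogDensityProblems.dimension(model.logdensity))
    else
        initial_params
    end
    # The transition (first return value) is unused by Turing's external
    # sampler interface, so we can return `nothing` there.
    return nothing, MyState(params)
end

# Subsequent steps: this toy sampler never moves.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler,
    state::MyState;
    kwargs...,
)
    return nothing, MyState(state.params)
end

# Interface functions, all defined on the *state*.
AbstractMCMC.getparams(s::MyState) = s.params
AbstractMCMC.getstats(s::MyState) = (param_length=length(s.params),)
AbstractMCMC.requires_unconstrained_space(::MySampler) = true
```

Note that none of these definitions mention Turing or DynamicPPL, which is exactly the point of the rework.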

Project.toml

Lines changed: 3 additions & 3 deletions
```diff
@@ -49,11 +49,11 @@ TuringOptimExt = ["Optim", "AbstractPPL"]
 
 [compat]
 ADTypes = "1.9"
-AbstractMCMC = "5.5"
+AbstractMCMC = "5.9"
 AbstractPPL = "0.11, 0.12, 0.13"
 Accessors = "0.1"
-AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
-AdvancedMH = "0.8"
+AdvancedHMC = "0.8.3"
+AdvancedMH = "0.8.9"
 AdvancedPS = "0.7"
 AdvancedVI = "0.4"
 BangBang = "0.4.2"
```

src/mcmc/Inference.jl

Lines changed: 1 addition & 1 deletion
```diff
@@ -429,10 +429,10 @@ include("hmc.jl")
 include("mh.jl")
 include("is.jl")
 include("particle_mcmc.jl")
-include("gibbs.jl")
 include("sghmc.jl")
 include("emcee.jl")
 include("prior.jl")
+include("gibbs.jl")
 
 ################
 # Typing tools #
```

src/mcmc/external_sampler.jl

Lines changed: 68 additions & 69 deletions
```diff
@@ -1,5 +1,5 @@
 """
-    ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}
+    ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType}
 
 Represents a sampler that does not have a custom implementation of `AbstractMCMC.step(rng,
 ::DynamicPPL.Model, spl)`.
@@ -14,45 +14,59 @@ $(TYPEDFIELDS)
 If you implement a new `MySampler <: AbstractSampler` and want it to work with Turing.jl
 models, there are two options:
 
-1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. This is the
-   most powerful option and is what Turing.jl's in-house samplers do. Implementing this
-   means that you can directly call `sample(model, MySampler(), N)`.
+1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. That is to
+   say, implement `AbstractMCMC.step(rng::Random.AbstractRNG, model::DynamicPPL.Model,
+   sampler::MySampler; kwargs...)` and related methods. This is the most powerful option and
+   is what Turing.jl's in-house samplers do. Implementing this means that you can directly
+   call `sample(model, MySampler(), N)`.
 
-2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel`. This
-   struct wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
+2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel` (the
+   same signature as above except that `model::AbstractMCMC.LogDensityModel`). This struct
+   wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
    implementation does not need to know anything about Turing.jl or DynamicPPL.jl. To use
    this with Turing.jl, you will need to wrap your sampler: `sample(model,
   externalsampler(MySampler()), N)`.
 
 This section describes the latter.
 
-`MySampler` must implement the following methods:
+`MySampler` **must** implement the following methods:
 
 - `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is
-  documented in AbstractMCMC.jl)
-- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the
-  parameters from the transition returned by your sampler (i.e., the first return value of
-  `step`). There is a default implementation for this method, which is to return
-  `external_transition.θ`.
-
-  !!! note
-      In a future breaking release of Turing, this is likely to change to
-      `AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method.
-      `Turing.Inference.getparams` is technically an internal method, so the aim here is to
-      unify the interface for samplers at a higher level.
+  documented in AbstractMCMC.jl). This function must return a tuple of two elements, a
+  'transition' and a 'state'.
+
+- `AbstractMCMC.getparams(external_state)`: How to extract the parameters from the **state**
+  returned by your sampler (i.e., the **second** return value of `step`). For your sampler
+  to work with Turing.jl, this function should return a Vector of parameter values. Note that
+  this function does not need to perform any linking or unlinking; Turing.jl will take care of
+  this for you. You should return the parameters *exactly* as your sampler sees them.
+
+- `AbstractMCMC.getstats(external_state)`: Extract sampler statistics corresponding to this
+  iteration from the **state** returned by your sampler (i.e., the **second** return value
+  of `step`). For your sampler to work with Turing.jl, this function should return a
+  `NamedTuple`. If there are no statistics to return, return `NamedTuple()`.
+
+Note that `getstats` should not include log-probabilities as these will be recalculated by
+Turing automatically for you.
+
+Notice that both of these functions take the **state** as input, not the **transition**. In
+other words, the transition is completely useless for the external sampler interface. This is
+in line with long-term plans for removing transitions from AbstractMCMC.jl and only using
+states.
 
 There are a few more optional functions which you can implement to improve the integration
 with Turing.jl:
 
-- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as
-  a component in Turing's Gibbs sampler, you should make this evaluate to `true`.
-
-- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires
+- `AbstractMCMC.requires_unconstrained_space(::MySampler)`: If your sampler requires
   unconstrained space, you should return `true`. This tells Turing to perform linking on the
   VarInfo before evaluation, and ensures that the parameter values passed to your sampler
   will always be in unconstrained (Euclidean) space.
+
+- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want to disallow your sampler
+  from being used as a component in Turing's Gibbs sampler, you should make this evaluate
+  to `false`. Note that the default is `true`, so you should only need to implement this in
+  special cases.
 """
-struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
+struct ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType} <:
        AbstractSampler
     "the sampler to wrap"
     sampler::S
@@ -67,47 +81,42 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
     # Arguments
     - `sampler::AbstractSampler`: The sampler to wrap.
     - `adtype::ADTypes.AbstractADType`: The automatic differentiation (AD) backend to use.
-    - `unconstrained::Val=Val{true}()`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
+    - `unconstrained::Val`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
     """
     function ExternalSampler(
-        sampler::AbstractSampler,
-        adtype::ADTypes.AbstractADType,
-        (::Val{unconstrained})=Val(true),
+        sampler::AbstractSampler, adtype::ADTypes.AbstractADType, ::Val{unconstrained}
     ) where {unconstrained}
         if !(unconstrained isa Bool)
             throw(
                 ArgumentError("Expected Val{true} or Val{false}, got Val{$unconstrained}")
             )
         end
-        return new{typeof(sampler),typeof(adtype),unconstrained}(sampler, adtype)
+        return new{unconstrained,typeof(sampler),typeof(adtype)}(sampler, adtype)
     end
 end
 
 """
-    requires_unconstrained_space(sampler::ExternalSampler)
-
-Return `true` if the sampler requires unconstrained space, and `false` otherwise.
-"""
-function requires_unconstrained_space(
-    ::ExternalSampler{<:Any,<:Any,Unconstrained}
-) where {Unconstrained}
-    return Unconstrained
-end
-
-"""
-    externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff(), unconstrained=true)
+    externalsampler(
+        sampler::AbstractSampler;
+        adtype=AutoForwardDiff(),
+        unconstrained=AbstractMCMC.requires_unconstrained_space(sampler),
+    )
 
 Wrap a sampler so it can be used as an inference algorithm.
 
 # Arguments
 - `sampler::AbstractSampler`: The sampler to wrap.
 
 # Keyword Arguments
-- `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation (AD) backend to use.
-- `unconstrained::Bool=true`: Whether the sampler requires unconstrained space.
+- `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation
+  (AD) backend to use.
+- `unconstrained::Bool=AbstractMCMC.requires_unconstrained_space(sampler)`: Whether the
+  sampler requires unconstrained space.
 """
 function externalsampler(
-    sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained::Bool=true
+    sampler::AbstractSampler;
+    adtype=Turing.DEFAULT_ADTYPE,
+    unconstrained::Bool=AbstractMCMC.requires_unconstrained_space(sampler),
 )
     return ExternalSampler(sampler, adtype, Val(unconstrained))
 end
@@ -128,30 +137,21 @@ end
 get_varinfo(state::TuringState) = state.varinfo
 get_varinfo(state::AbstractVarInfo) = state
 
-getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
-function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
-    return getparams(model, state.transition)
-end
-getstats(transition::AdvancedHMC.Transition) = transition.stat
-
-getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
-
-# TODO: Do we also support `resume`, etc?
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
-    sampler_wrapper::ExternalSampler;
+    sampler_wrapper::ExternalSampler{unconstrained};
     initial_state=nothing,
     initial_params, # passed through from sample
     kwargs...,
-)
+) where {unconstrained}
     sampler = sampler_wrapper.sampler
 
     # Initialise varinfo with initial params and link the varinfo if needed.
     varinfo = DynamicPPL.VarInfo(model)
     _, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)
 
-    if requires_unconstrained_space(sampler_wrapper)
+    if unconstrained
         varinfo = DynamicPPL.link(varinfo, model)
     end
 
@@ -166,16 +166,17 @@ function AbstractMCMC.step(
     )
 
     # Then just call `AbstractMCMC.step` with the right arguments.
-    if initial_state === nothing
-        transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = if initial_state === nothing
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler;
             initial_params=initial_params_vector,
             kwargs...,
         )
+
     else
-        transition_inner, state_inner = AbstractMCMC.step(
+        AbstractMCMC.step(
             rng,
             AbstractMCMC.LogDensityModel(f),
             sampler,
@@ -185,13 +186,12 @@
         )
     end
 
-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end
 
@@ -206,16 +206,15 @@
     f = state.ldf
 
     # Then just call `AbstractMCMC.step` with the right arguments.
-    transition_inner, state_inner = AbstractMCMC.step(
+    _, state_inner = AbstractMCMC.step(
        rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
    )
 
-    # NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
-    # The latter uses the state rather than the transition.
-    # TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
-    new_parameters = Turing.Inference.getparams(f.model, transition_inner)
+    new_parameters = AbstractMCMC.getparams(f.model, state_inner)
     new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
+    new_stats = AbstractMCMC.getstats(state_inner)
     return (
-        Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
+        Turing.Inference.Transition(f.model, new_vi, new_stats),
+        TuringState(state_inner, new_vi, f),
     )
 end
```
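On the user side, wrapping stays the same. A hedged sketch, assuming a `MySampler` that implements the interface described in the docstring above (the model here is a toy stand-in):

```julia
using Turing  # provides @model, externalsampler, sample, AutoForwardDiff

@model function demo()
    return x ~ Normal()  # toy model; x is already unconstrained
end

# Wrap the external sampler. `unconstrained` now defaults to
# AbstractMCMC.requires_unconstrained_space(MySampler()) rather than `true`.
spl = externalsampler(MySampler(); adtype=AutoForwardDiff())

# Turing's ExternalSampler `step` methods call AbstractMCMC.getparams and
# AbstractMCMC.getstats on the state returned by MySampler to build the chain.
chain = sample(demo(), spl, 100)
```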

src/mcmc/gibbs.jl

Lines changed: 9 additions & 13 deletions
```diff
@@ -3,23 +3,19 @@
 
 Return a boolean indicating whether `spl` is a valid component for a Gibbs sampler.
 
-Defaults to `false` if no method has been defined for a particular algorithm type.
+Defaults to `true` if no method has been defined for a particular sampler.
 """
-isgibbscomponent(::AbstractSampler) = false
-
-isgibbscomponent(::ESS) = true
-isgibbscomponent(::HMC) = true
-isgibbscomponent(::HMCDA) = true
-isgibbscomponent(::NUTS) = true
-isgibbscomponent(::MH) = true
-isgibbscomponent(::PG) = true
+isgibbscomponent(::AbstractSampler) = true
 
 isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler)
-
 isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler)
-isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
-isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
-isgibbscomponent(spl) = false
+
+isgibbscomponent(::IS) = false
+isgibbscomponent(::Prior) = false
+isgibbscomponent(::Emcee) = false
+isgibbscomponent(::SGLD) = false
+isgibbscomponent(::SGHMC) = false
+isgibbscomponent(::SMC) = false
 
 function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
     return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
```
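With the flipped default, a sampler author only needs a method to opt *out* of Gibbs. A minimal sketch with a hypothetical sampler type (`MyEnsembleSampler` is a placeholder, not a real package type):

```julia
using AbstractMCMC, Turing

# Hypothetical sampler that, say, manages several coupled walkers and so
# cannot be conditioned on a subset of the model's variables.
struct MyEnsembleSampler <: AbstractMCMC.AbstractSampler end

# `isgibbscomponent` now defaults to `true`, so ordinary samplers need no
# method at all; define one returning `false` only to disallow Gibbs use.
Turing.Inference.isgibbscomponent(::MyEnsembleSampler) = false
```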

test/Project.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -40,9 +40,9 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
 
 [compat]
 ADTypes = "1"
-AbstractMCMC = "5"
+AbstractMCMC = "5.9"
 AbstractPPL = "0.11, 0.12, 0.13"
-AdvancedMH = "0.6, 0.7, 0.8"
+AdvancedMH = "0.8.9"
 AdvancedPS = "0.7"
 AdvancedVI = "0.4"
 Aqua = "0.8"
```

test/mcmc/external_sampler.jl

Lines changed: 5 additions & 9 deletions
```diff
@@ -20,16 +20,11 @@ using Turing.Inference: AdvancedHMC
 # Turing declares an interface for external samplers (see docstring for
 # ExternalSampler). We should check that implementing this interface
 # and only this interface allows us to use the sampler in Turing.
-struct MyTransition{V<:AbstractVector}
-    params::V
-end
-# Samplers need to implement `Turing.Inference.getparams`.
-Turing.Inference.getparams(::DynamicPPL.Model, t::MyTransition) = t.params
-# State doesn't matter (but we need to carry the params through to the next
-# iteration).
 struct MyState{V<:AbstractVector}
     params::V
 end
+AbstractMCMC.getparams(s::MyState) = s.params
+AbstractMCMC.getstats(s::MyState) = (param_length=length(s.params),)
 
 # externalsamplers must accept LogDensityModel inside their step function.
 # By default Turing gives the externalsampler a LDF constructed with
@@ -58,7 +53,7 @@ using Turing.Inference: AdvancedHMC
     lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, initial_params)
     @test lp isa Real
     @test grad isa AbstractVector{<:Real}
-    return MyTransition(initial_params), MyState(initial_params)
+    return nothing, MyState(initial_params)
 end
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
@@ -75,7 +70,7 @@ using Turing.Inference: AdvancedHMC
     lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
     @test lp isa Real
     @test grad isa AbstractVector{<:Real}
-    return MyTransition(params), MyState(params)
+    return nothing, MyState(params)
 end
 
 @model function test_external_sampler()
@@ -96,6 +91,7 @@ using Turing.Inference: AdvancedHMC
     @test all(chn[:lp] .== expected_logpdf)
     @test all(chn[:logprior] .== expected_logpdf)
     @test all(chn[:loglikelihood] .== 0.0)
+    @test all(chn[:param_length] .== 2)
 end
 
 function initialize_nuts(model::DynamicPPL.Model)
```
