16 changes: 16 additions & 0 deletions HISTORY.md
@@ -1,5 +1,21 @@
# 0.42.0

## External sampler interface

The interface for defining an external sampler has been reworked.
In general, implementations of external samplers no longer need to depend on Turing, because the required interface functions have been moved upstream to AbstractMCMC.jl.

In particular, you now only need to define the following functions:

- `AbstractMCMC.step(rng::Random.AbstractRNG, model::AbstractMCMC.LogDensityModel, ::MySampler; kwargs...)` (plus the corresponding method taking a `state` argument, and the `step_warmup` methods if needed)
- `AbstractMCMC.getparams(::MySamplerState) -> Vector{<:Real}`
- `AbstractMCMC.getstats(::MySamplerState) -> NamedTuple`
- `AbstractMCMC.requires_unconstrained_space(::MySampler) -> Bool` (defaults to `true`)

This means that your implementation only needs to depend on AbstractMCMC.jl. As long as the above functions are defined correctly, Turing will be able to use your external sampler, as the sketch below illustrates.
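For example, a minimal do-nothing sampler might look like the following sketch (`MySampler` and `MySamplerState` are hypothetical names; a real sampler would propose new parameters in each step):

```julia
using AbstractMCMC, Random

struct MySampler <: AbstractMCMC.AbstractSampler end

struct MySamplerState
    params::Vector{Float64}
end

# First step: initialise the state from the initial parameters supplied by Turing.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler;
    initial_params,
    kwargs...,
)
    # The first return value (the 'transition') is ignored by Turing.
    return nothing, MySamplerState(initial_params)
end

# Subsequent steps: carry the parameters over from the previous state.
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::AbstractMCMC.LogDensityModel,
    ::MySampler,
    state::MySamplerState;
    kwargs...,
)
    return nothing, MySamplerState(state.params)
end

AbstractMCMC.getparams(state::MySamplerState) = state.params
AbstractMCMC.getstats(state::MySamplerState) = NamedTuple()
AbstractMCMC.requires_unconstrained_space(::MySampler) = true
```

Such a sampler can then be used with `sample(model, externalsampler(MySampler()), N)`.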

# 0.41.0

## DynamicPPL 0.38
2 changes: 1 addition & 1 deletion Project.toml
@@ -49,7 +49,7 @@ TuringOptimExt = ["Optim", "AbstractPPL"]

[compat]
ADTypes = "1.9"
AbstractMCMC = "5.5"
AbstractMCMC = "5.9"
AbstractPPL = "0.11, 0.12, 0.13"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
135 changes: 67 additions & 68 deletions src/mcmc/external_sampler.jl
@@ -1,5 +1,5 @@
"""
ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained}
ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType}
> **Member Author commented:** Moved this type parameter earlier so that we can dispatch on it more easily.


Represents a sampler that does not have a custom implementation of `AbstractMCMC.step(rng,
::DynamicPPL.Model, spl)`.
@@ -14,45 +14,59 @@ $(TYPEDFIELDS)
If you implement a new `MySampler <: AbstractSampler` and want it to work with Turing.jl
models, there are two options:

1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. This is the
most powerful option and is what Turing.jl's in-house samplers do. Implementing this
means that you can directly call `sample(model, MySampler(), N)`.
1. Directly implement the `AbstractMCMC.step` methods for `DynamicPPL.Model`. That is to
say, implement `AbstractMCMC.step(rng::Random.AbstractRNG, model::DynamicPPL.Model,
sampler::MySampler; kwargs...)` and related methods. This is the most powerful option and
is what Turing.jl's in-house samplers do. Implementing this means that you can directly
call `sample(model, MySampler(), N)`.

2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel`. This
struct wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
2. Implement a generic `AbstractMCMC.step` method for `AbstractMCMC.LogDensityModel` (the
same signature as above except that `model::AbstractMCMC.LogDensityModel`). This struct
wraps an object that obeys the LogDensityProblems.jl interface, so your `step`
implementation does not need to know anything about Turing.jl or DynamicPPL.jl. To use
this with Turing.jl, you will need to wrap your sampler: `sample(model,
externalsampler(MySampler()), N)`.

This section describes the latter.

`MySampler` must implement the following methods:
`MySampler` **must** implement the following methods:

- `AbstractMCMC.step` (the main function for taking a step in MCMC sampling; this is
documented in AbstractMCMC.jl)
- `Turing.Inference.getparams(::DynamicPPL.Model, external_transition)`: How to extract the
parameters from the transition returned by your sampler (i.e., the first return value of
`step`). There is a default implementation for this method, which is to return
`external_transition.θ`.

!!! note
In a future breaking release of Turing, this is likely to change to
`AbstractMCMC.getparams(::DynamicPPL.Model, external_state)`, with no default method.
`Turing.Inference.getparams` is technically an internal method, so the aim here is to
unify the interface for samplers at a higher level.
documented in AbstractMCMC.jl). This function must return a tuple of two elements, a
'transition' and a 'state'.

- `AbstractMCMC.getparams(external_state)`: How to extract the parameters from the **state**
returned by your sampler (i.e., the **second** return value of `step`). For your sampler
to work with Turing.jl, this function should return a Vector of parameter values. Note that
this function does not need to perform any linking or unlinking; Turing.jl will take care of
this for you. You should return the parameters *exactly* as your sampler sees them.

- `AbstractMCMC.getstats(external_state)`: Extract sampler statistics corresponding to this
iteration from the **state** returned by your sampler (i.e., the **second** return value
of `step`). For your sampler to work with Turing.jl, this function should return a
`NamedTuple`. If there are no statistics to return, return `NamedTuple()`.

Note that `getstats` should not include log-probabilities, as these will be recalculated
automatically by Turing.

Notice that both of these functions take the **state** as input, not the **transition**;
the transition returned by `step` is ignored entirely by the external sampler interface.
This is in line with long-term plans to remove transitions from AbstractMCMC.jl and use
only states.
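
For example, a sketch assuming a hypothetical state type that stores a parameter vector
and an acceptance flag:

```julia
struct MySamplerState
    params::Vector{Float64}
    accepted::Bool
end

AbstractMCMC.getparams(state::MySamplerState) = state.params
AbstractMCMC.getstats(state::MySamplerState) = (accepted=state.accepted,)
```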

There are a few more optional functions which you can implement to improve the integration
with Turing.jl:

- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want your sampler to function as
a component in Turing's Gibbs sampler, you should make this evaluate to `true`.

- `Turing.Inference.requires_unconstrained_space(::MySampler)`: If your sampler requires
- `AbstractMCMC.requires_unconstrained_space(::MySampler)`: If your sampler requires
unconstrained space, you should return `true`. This tells Turing to perform linking on the
VarInfo before evaluation, and ensures that the parameter values passed to your sampler
will always be in unconstrained (Euclidean) space.

- `Turing.Inference.isgibbscomponent(::MySampler)`: If you want to disallow your sampler
from being used as a component in Turing's Gibbs sampler, you should make this evaluate to
`false`. Note that the default is `true`, so you should only need to implement this in
special cases.
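
For example, a sampler that works directly in constrained space and that should not be
used inside Gibbs might define the following (a sketch; `MySampler` is hypothetical):

```julia
AbstractMCMC.requires_unconstrained_space(::MySampler) = false
Turing.Inference.isgibbscomponent(::MySampler) = false
```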
"""
struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrained} <:
struct ExternalSampler{Unconstrained,S<:AbstractSampler,AD<:ADTypes.AbstractADType} <:
AbstractSampler
"the sampler to wrap"
sampler::S
@@ -67,33 +81,20 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
# Arguments
- `sampler::AbstractSampler`: The sampler to wrap.
- `adtype::ADTypes.AbstractADType`: The automatic differentiation (AD) backend to use.
- `unconstrained::Val=Val{true}()`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
- `unconstrained::Val`: Value type containing a boolean indicating whether the sampler requires unconstrained space.
"""
function ExternalSampler(
sampler::AbstractSampler,
adtype::ADTypes.AbstractADType,
(::Val{unconstrained})=Val(true),
sampler::AbstractSampler, adtype::ADTypes.AbstractADType, ::Val{unconstrained}
) where {unconstrained}
if !(unconstrained isa Bool)
throw(
ArgumentError("Expected Val{true} or Val{false}, got Val{$unconstrained}")
)
end
return new{typeof(sampler),typeof(adtype),unconstrained}(sampler, adtype)
return new{unconstrained,typeof(sampler),typeof(adtype)}(sampler, adtype)
end
end

"""
requires_unconstrained_space(sampler::ExternalSampler)

Return `true` if the sampler requires unconstrained space, and `false` otherwise.
"""
function requires_unconstrained_space(
::ExternalSampler{<:Any,<:Any,Unconstrained}
) where {Unconstrained}
return Unconstrained
end

"""
externalsampler(sampler::AbstractSampler; adtype=AutoForwardDiff())

@@ -106,10 +107,10 @@ Wrap a sampler so it can be used as an inference algorithm.
- `adtype::ADTypes.AbstractADType=ADTypes.AutoForwardDiff()`: The automatic differentiation (AD) backend to use.

Whether the sampler requires unconstrained space is determined automatically via `AbstractMCMC.requires_unconstrained_space(sampler)`.
"""
function externalsampler(
sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE, unconstrained::Bool=true
)
return ExternalSampler(sampler, adtype, Val(unconstrained))
function externalsampler(sampler::AbstractSampler; adtype=Turing.DEFAULT_ADTYPE)
return ExternalSampler(
sampler, adtype, Val(AbstractMCMC.requires_unconstrained_space(sampler))
)
end
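
# Example usage (a sketch; `MySampler` is a hypothetical sampler implementing the
# external sampler interface described above):
#
#     chain = sample(model, externalsampler(MySampler(); adtype=AutoForwardDiff()), 1000)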

# TODO(penelopeysm): Can't we clean this up somehow?
@@ -128,30 +129,22 @@ end
get_varinfo(state::TuringState) = state.varinfo
get_varinfo(state::AbstractVarInfo) = state

getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
return getparams(model, state.transition)
end
getstats(transition::AdvancedHMC.Transition) = transition.stat

getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params

# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapper::ExternalSampler;
sampler_wrapper::ExternalSampler{unconstrained};
initial_state=nothing,
initial_params, # passed through from sample
kwargs...,
)
) where {unconstrained}
sampler = sampler_wrapper.sampler

# Initialise varinfo with initial params and link the varinfo if needed.
varinfo = DynamicPPL.VarInfo(model)
_, varinfo = DynamicPPL.init!!(rng, model, varinfo, initial_params)

if requires_unconstrained_space(sampler_wrapper)
if unconstrained
varinfo = DynamicPPL.link(varinfo, model)
end

@@ -166,16 +159,17 @@ function AbstractMCMC.step(
)

# Then just call `AbstractMCMC.step` with the right arguments.
if initial_state === nothing
transition_inner, state_inner = AbstractMCMC.step(
_, state_inner = if initial_state === nothing
AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(f),
sampler;
initial_params=initial_params_vector,
kwargs...,
)

else
transition_inner, state_inner = AbstractMCMC.step(
AbstractMCMC.step(
rng,
AbstractMCMC.LogDensityModel(f),
sampler,
@@ -185,13 +179,12 @@
)
end

# NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
# The latter uses the state rather than the transition.
# TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
new_parameters = Turing.Inference.getparams(f.model, transition_inner)
new_parameters = AbstractMCMC.getparams(f.model, state_inner)
new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
new_stats = AbstractMCMC.getstats(state_inner)
return (
Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
Turing.Inference.Transition(f.model, new_vi, new_stats),
TuringState(state_inner, new_vi, f),
)
end

@@ -206,16 +199,22 @@ function AbstractMCMC.step(
f = state.ldf

# Then just call `AbstractMCMC.step` with the right arguments.
transition_inner, state_inner = AbstractMCMC.step(
_, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler, state.state; kwargs...
)

# NOTE: This is Turing.Inference.getparams, not AbstractMCMC.getparams (!!!!!)
# The latter uses the state rather than the transition.
# TODO(penelopeysm): Make this use AbstractMCMC.getparams instead
new_parameters = Turing.Inference.getparams(f.model, transition_inner)
new_parameters = AbstractMCMC.getparams(f.model, state_inner)
new_vi = DynamicPPL.unflatten(f.varinfo, new_parameters)
new_stats = AbstractMCMC.getstats(state_inner)
return (
Transition(f.model, new_vi, transition_inner), TuringState(state_inner, new_vi, f)
Turing.Inference.Transition(f.model, new_vi, new_stats),
TuringState(state_inner, new_vi, f),
)
end

# Implementation of interface for AdvancedMH and AdvancedHMC. TODO: These should be
# upstreamed to the respective packages, I'm just not doing it here to avoid having to run
# CI against three separate PR branches.
AbstractMCMC.getstats(state::AdvancedHMC.HMCState) = state.transition.stat
# Note that for AdvancedMH, transition and state are equivalent (and both named Transition)
AbstractMCMC.getstats(state::AdvancedMH.Transition) = (accepted=state.accepted,)
8 changes: 2 additions & 6 deletions src/mcmc/gibbs.jl
@@ -3,9 +3,9 @@

Return a boolean indicating whether `spl` is a valid component for a Gibbs sampler.

Defaults to `false` if no method has been defined for a particular algorithm type.
Defaults to `true` if no method has been defined for a particular sampler.
"""
isgibbscomponent(::AbstractSampler) = false
isgibbscomponent(::AbstractSampler) = true

isgibbscomponent(::ESS) = true
isgibbscomponent(::HMC) = true
@@ -15,11 +15,7 @@ isgibbscomponent(::MH) = true
isgibbscomponent(::PG) = true

isgibbscomponent(spl::RepeatSampler) = isgibbscomponent(spl.sampler)

isgibbscomponent(spl::ExternalSampler) = isgibbscomponent(spl.sampler)
isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
isgibbscomponent(spl) = false

function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
14 changes: 5 additions & 9 deletions test/mcmc/external_sampler.jl
@@ -20,16 +20,11 @@ using Turing.Inference: AdvancedHMC
# Turing declares an interface for external samplers (see docstring for
# ExternalSampler). We should check that implementing this interface
# and only this interface allows us to use the sampler in Turing.
struct MyTransition{V<:AbstractVector}
params::V
end
# Samplers need to implement `Turing.Inference.getparams`.
Turing.Inference.getparams(::DynamicPPL.Model, t::MyTransition) = t.params
# The state needs to carry the params through to the next iteration.
struct MyState{V<:AbstractVector}
params::V
end
AbstractMCMC.getparams(s::MyState) = s.params
AbstractMCMC.getstats(s::MyState) = (param_length=length(s.params),)

# externalsamplers must accept LogDensityModel inside their step function.
# By default Turing gives the externalsampler a LDF constructed with
@@ -58,7 +53,7 @@ using Turing.Inference: AdvancedHMC
lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, initial_params)
@test lp isa Real
@test grad isa AbstractVector{<:Real}
return MyTransition(initial_params), MyState(initial_params)
return nothing, MyState(initial_params)
end
function AbstractMCMC.step(
rng::Random.AbstractRNG,
@@ -75,7 +70,7 @@
lp, grad = LogDensityProblems.logdensity_and_gradient(ldf, params)
@test lp isa Real
@test grad isa AbstractVector{<:Real}
return MyTransition(params), MyState(params)
return nothing, MyState(params)
end

@model function test_external_sampler()
@@ -96,6 +91,7 @@ using Turing.Inference: AdvancedHMC
@test all(chn[:lp] .== expected_logpdf)
@test all(chn[:logprior] .== expected_logpdf)
@test all(chn[:loglikelihood] .== 0.0)
@test all(chn[:param_length] .== 2)
end

function initialize_nuts(model::DynamicPPL.Model)