Skip to content

Commit 443eaca

Browse files
authored
Allow passing kwargs on to Libtask (when it's an AbstractTuringLibtaskModel...) (#118)
* Add a different struct that can pass kwargs on to Libtask * Format
1 parent e100352 commit 443eaca

File tree

3 files changed

+31
-4
lines changed

3 files changed

+31
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.7.1"
4+
version = "0.7.2"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

ext/AdvancedPSLibtaskExt.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,31 @@ function AdvancedPS.LibtaskModel(
3737
) # Changed the API, need to take care of the RNG properly
3838
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
3939
end
40+
# TODO: Upstream this to Turing
41+
function AdvancedPS.LibtaskModel(
42+
f::AdvancedPS.AbstractTuringLibtaskModel, rng::Random.AbstractRNG
43+
)
44+
return AdvancedPS.LibtaskModel(
45+
f, Libtask.TapedTask(TapedGlobals(rng), f.fargs...; f.kwargs...)
46+
)
47+
end
48+
49+
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}
50+
51+
function to_tapedtask(
52+
newf::AdvancedPS.AbstractGenericModel, trace::LibtaskTrace, rng::Random.AbstractRNG
53+
)
54+
return Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
55+
end
56+
function to_tapedtask(
57+
newf::AdvancedPS.AbstractTuringLibtaskModel,
58+
trace::LibtaskTrace,
59+
rng::Random.AbstractRNG,
60+
)
61+
return Libtask.TapedTask(
62+
TapedGlobals(rng, get_other_global(trace)), newf.fargs...; newf.kwargs...
63+
)
64+
end
4065

4166
"""
4267
copy(model::AdvancedPS.LibtaskModel)
@@ -47,8 +72,6 @@ function Base.copy(model::AdvancedPS.LibtaskModel)
4772
return AdvancedPS.LibtaskModel(deepcopy(model.f), copy(model.ctask))
4873
end
4974

50-
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}
51-
5275
function Base.copy(trace::LibtaskTrace)
5376
newtrace = AdvancedPS.Trace(copy(trace.model), deepcopy(trace.rng))
5477
set_other_global!(newtrace, newtrace)
@@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
114137
newf = AdvancedPS.reset_model(trace.model.f)
115138
Random123.set_counter!(rng, 1)
116139

117-
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
140+
ctask = to_tapedtask(newf, trace, rng)
118141
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
119142

120143
# add backward reference

src/AdvancedPS.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end
1616
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
1717
abstract type AbstractGenericModel <: AbstractParticleModel end
1818

19+
# TODO(penelopeysm): This should be upstreamed to Turing together with anything that is
20+
# Turing-specific in LibtaskExt.
21+
abstract type AbstractTuringLibtaskModel <: AbstractGenericModel end
22+
1923
include("resampling.jl")
2024
include("rng.jl")
2125
include("model.jl")

0 commit comments

Comments
 (0)