@@ -37,6 +37,31 @@ function AdvancedPS.LibtaskModel(
3737) # Changed the API, need to take care of the RNG properly
3838 return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (TapedGlobals (rng), f, args... ))
3939end
40+ # TODO : Upstream this to Turing
41+ function AdvancedPS. LibtaskModel (
42+ f:: AdvancedPS.AbstractTuringLibtaskModel , rng:: Random.AbstractRNG
43+ )
44+ return AdvancedPS. LibtaskModel (
45+ f, Libtask. TapedTask (TapedGlobals (rng), f. fargs... ; f. kwargs... )
46+ )
47+ end
48+
49+ const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
50+
51+ function to_tapedtask (
52+ newf:: AdvancedPS.AbstractGenericModel , trace:: LibtaskTrace , rng:: Random.AbstractRNG
53+ )
54+ return Libtask. TapedTask (TapedGlobals (rng, get_other_global (trace)), newf)
55+ end
56+ function to_tapedtask (
57+ newf:: AdvancedPS.AbstractTuringLibtaskModel ,
58+ trace:: LibtaskTrace ,
59+ rng:: Random.AbstractRNG ,
60+ )
61+ return Libtask. TapedTask (
62+ TapedGlobals (rng, get_other_global (trace)), newf. fargs... ; newf. kwargs...
63+ )
64+ end
4065
4166"""
4267 copy(model::AdvancedPS.LibtaskModel)
@@ -47,8 +72,6 @@ function Base.copy(model::AdvancedPS.LibtaskModel)
4772 return AdvancedPS. LibtaskModel (deepcopy (model. f), copy (model. ctask))
4873end
4974
50- const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
51-
5275function Base. copy (trace:: LibtaskTrace )
5376 newtrace = AdvancedPS. Trace (copy (trace. model), deepcopy (trace. rng))
5477 set_other_global! (newtrace, newtrace)
@@ -114,7 +137,7 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
114137 newf = AdvancedPS. reset_model (trace. model. f)
115138 Random123. set_counter! (rng, 1 )
116139
117- ctask = Libtask . TapedTask ( TapedGlobals (rng, get_other_global ( trace)), newf )
140+ ctask = to_tapedtask (newf, trace, rng )
118141 new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
119142
120143 # add backward reference
0 commit comments