@@ -47,6 +47,14 @@ function AbstractMCMC.sample(
4747 callback = nothing ,
4848 kwargs... ,
4949)
50+ if haskey (kwargs, :nadapts )
51+ throw (
52+ ArgumentError (
53+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
54+ ),
55+ )
56+ end
57+
5058 if callback === nothing
5159 callback = HMCProgressCallback (N, progress = progress, verbose = verbose)
5260 progress = false # don't use AMCMC's progress-funtionality
@@ -78,6 +86,13 @@ function AbstractMCMC.sample(
7886 callback = nothing ,
7987 kwargs... ,
8088)
89+ if haskey (kwargs, :nadapts )
90+ throw (
91+ ArgumentError (
92+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
93+ ),
94+ )
95+ end
8196
8297 if callback === nothing
8398 callback = HMCProgressCallback (N, progress = progress, verbose = verbose)
@@ -141,8 +156,17 @@ function AbstractMCMC.step(
141156 model:: AbstractMCMC.LogDensityModel ,
142157 spl:: AbstractHMCSampler ,
143158 state:: HMCState ;
159+ n_adapts:: Int = 0 ,
144160 kwargs... ,
145161)
162+ if haskey (kwargs, :nadapts )
163+ throw (
164+ ArgumentError (
165+ " keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps." ,
166+ ),
167+ )
168+ end
169+
146170 # Compute transition.
147171 i = state. i + 1
148172 t_old = state. transition
@@ -158,7 +182,6 @@ function AbstractMCMC.step(
158182
159183 # Adapt h and spl.
160184 tstat = stat (t)
161- n_adapts = kwargs[:n_adapts ]
162185 h, κ, isadapted = adapt! (h, κ, adaptor, i, n_adapts, t. z. θ, tstat. acceptance_rate)
163186 tstat = merge (tstat, (is_adapt = isadapted,))
164187
@@ -189,8 +212,8 @@ struct HMCProgressCallback{P}
189212 " If `progress` is not specified and this is `true` some information will be logged upon completion of adaptation."
190213 verbose:: Bool
191214 " Number of divergent transitions fo far."
192- num_divergent_transitions:: Ref {Int}
193- num_divergent_transitions_during_adaption:: Ref {Int}
215+ num_divergent_transitions:: Base.RefValue {Int}
216+ num_divergent_transitions_during_adaption:: Base.RefValue {Int}
194217end
195218
196219function HMCProgressCallback (n_samples; progress = true , verbose = false )
@@ -200,7 +223,16 @@ function HMCProgressCallback(n_samples; progress = true, verbose = false)
200223 HMCProgressCallback (pm, progress, verbose, Ref (0 ), Ref (0 ))
201224end
202225
203- function (cb:: HMCProgressCallback )(rng, model, spl, t, state, i; nadapts = 0 , kwargs... )
226+ function (cb:: HMCProgressCallback )(
227+ rng,
228+ model,
229+ spl,
230+ t,
231+ state,
232+ i;
233+ n_adapts:: Int = 0 ,
234+ kwargs... ,
235+ )
204236 progress = cb. progress
205237 verbose = cb. verbose
206238 pm = cb. pm
@@ -243,8 +275,8 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw
243275 ),
244276 )
245277 # Report finish of adapation
246- elseif verbose && isadapted && i == nadapts
247- @info " Finished $nadapts adapation steps" adaptor κ. τ. integrator metric
278+ elseif verbose && isadapted && i == n_adapts
279+ @info " Finished $(n_adapts) adapation steps" adaptor κ. τ. integrator metric
248280 end
249281end
250282
0 commit comments