From 44b50c6e0cf6d9c743f462d67a41407f40aae2ad Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 10 Oct 2025 18:44:53 +0100 Subject: [PATCH 1/5] Implement some AbstractChains interface --- src/AbstractMCMC.jl | 9 +---- src/chains.jl | 61 ++++++++++++++++++++++++++++ src/experimental/chains.jl | 82 ++++++++++++++++++++++++++++++++++++++ src/interface.jl | 21 ---------- 4 files changed, 144 insertions(+), 29 deletions(-) create mode 100644 src/chains.jl create mode 100644 src/experimental/chains.jl diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index e103d5a5..a71fd435 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -22,14 +22,6 @@ export sample # Parallel sampling types export MCMCThreads, MCMCDistributed, MCMCSerial -""" - AbstractChains - -`AbstractChains` is an abstract type for an object that stores -parameter samples generated through a MCMC process. -""" -abstract type AbstractChains end - """ AbstractSampler @@ -137,6 +129,7 @@ function setparams!!(model::AbstractModel, state, params) return setparams!!(state, params) end +include("chains.jl") include("samplingstats.jl") include("logging.jl") include("interface.jl") diff --git a/src/chains.jl b/src/chains.jl new file mode 100644 index 00000000..3a3b6eb7 --- /dev/null +++ b/src/chains.jl @@ -0,0 +1,61 @@ +# AbstractChains interface +# +# NOTE: The entire interface is treated as experimental except for the AbstractChains type +# itself, along with `chainscat` and `chainsstack`. Thus, if you change any of those three, +# it is mandatory to release a breaking version. Other changes to the AbstractChains +# interface can be made in patch releases. + +""" + AbstractMCMC.AbstractChains + +An abstract type for Markov chains, i.e., a data structure which stores samples +obtained from Markov chain Monte Carlo (MCMC) sampling. + +!!! 
danger "Explicitly experimental" + + Although the abstract type `AbstractMCMC.AbstractChains` itself, along with the + functions `chainscat` and `chainsstack`, are exported and public, please note that *all + other parts of the interface remain experimental and subject to change*. In particular, + breaking changes to the interface may be introduced in formally non-breaking releases. + +Markov chains should generally have dictionary-like behaviour, where keys are mapped to +matrices of values. + +## Interface + +To implement a new subtype of `AbstractChains`, you need to define the following methods: + +- `Base.size` should return a tuple of ints (the exact meaning is left to you) +- `Base.keys` should return a list of keys +- [`AbstractMCMC.get_data`](@ref)`(chn, key)` +- [`AbstractMCMC.iter_indices`](@ref)`(chn)` +- [`AbstractMCMC.chain_indices`](@ref)`(chn)` + +You can optionally define the following methods for efficiency: + +- [`AbstractChains.niters`](@ref)`(chn)` +- [`AbstractChains.nchains`](@ref)`(chn)` +""" +abstract type AbstractChains end + +""" + chainscat(c::AbstractChains...) + +Concatenate multiple chains. + +By default, the chains are concatenated along the third dimension by calling +`cat(c...; dims=3)`. +""" +chainscat(c::AbstractChains...) = cat(c...; dims=3) + +""" + chainsstack(c::AbstractVector) + +Stack chains in `c`. + +By default, the vector of chains is returned unmodified. If `eltype(c) <: AbstractChains`, +then `reduce(chainscat, c)` is called. +""" +chainsstack(c) = c +chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) +include("experimental/chains.jl") diff --git a/src/experimental/chains.jl b/src/experimental/chains.jl new file mode 100644 index 00000000..b4e13790 --- /dev/null +++ b/src/experimental/chains.jl @@ -0,0 +1,82 @@ +module Chains + +using AbstractMCMC: AbstractChains + +""" + AbstractMCMC.Chains.get_data(chn, key) + +Obtain the data associated with `key` from the `AbstractChains` object `chn`. 
+ +This function should return an `AbstractMatrix` where the rows correspond to iterations and +columns correspond to chains. +""" +function get_data end + +""" + AbstractMCMC.Chains.iter_indices(chn) + +Obtain the indices of each iteration for the `AbstractChains` object `chn`. + +This function should return an `AbstractVector{<:Integer}`. +""" +function iter_indices end + +""" + AbstractMCMC.Chains.chain_indices(chn) + +Obtain the indices of each chain in the `AbstractChains` object `chn`. + +This function should return an `AbstractVector{<:Integer}`. +""" +function chain_indices end + +""" + AbstractMCMC.Chains.niters(chn) + +Obtain the number of iterations in the `AbstractChains` object `chn`. + +The default implementation calculates the length of `iter_indices(chn)`. You +can define your own method if you have a more efficient way of obtaining this information. +""" +niters(c::AbstractChains) = length(iter_indices(c)) + +""" + AbstractMCMC.Chains.nchains(chn) + +Obtain the number of chains in the `AbstractChains` object `chn`. + +The default implementation calculates the length of `chain_indices(chn)`. You +can define your own method if you have a more efficient way of obtaining this information. +""" +nchains(c::AbstractChains) = length(chain_indices(c)) + +""" + AbstractMCMC.Chains.test_interface(chn) + +Test that the `AbstractChains` object `chn` implements the required interface. +""" +function test_interface end # Extended in TestExt + +# Plotting functions; to be extended by individual chain libraries +function autocorplot end +function autocorplot! end +function energyplot end +function energyplot! end +function forestplot end +function forestplot! end +function meanplot end +function meanplot! end +function mixeddensity end +function mixeddensity! end +function ppcplot end +function ppcplot! end +function ridgelineplot end +function ridgelineplot! end +function traceplot end +function traceplot! 
end +# Note that other functions are provided by other libraries. In particular: +# Plots.histogram +# StatsPlots.density +# StatsPlots.cornerplot + +end # AbstractMCMC.Chains diff --git a/src/interface.jl b/src/interface.jl index 902424d2..98c4d858 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,24 +1,3 @@ -""" - chainscat(c::AbstractChains...) - -Concatenate multiple chains. - -By default, the chains are concatenated along the third dimension by calling -`cat(c...; dims=3)`. -""" -chainscat(c::AbstractChains...) = cat(c...; dims=3) - -""" - chainsstack(c::AbstractVector) - -Stack chains in `c`. - -By default, the vector of chains is returned unmodified. If `eltype(c) <: AbstractChains`, -then `reduce(chainscat, c)` is called. -""" -chainsstack(c) = c -chainsstack(c::AbstractVector{<:AbstractChains}) = reduce(chainscat, c) - """ bundle_samples(samples, model, sampler, state, chain_type[; kwargs...]) From 2045d4fb6463b61a749391aab72eb4d9514e5dc6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 10 Oct 2025 18:45:17 +0100 Subject: [PATCH 2/5] Bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index cc7d8b8d..7c40ca91 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probabilistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." 
-version = "5.8.2" +version = "5.9.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 873df97ff28a619e62c6f2783532b837c6467102 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 10 Oct 2025 19:21:07 +0100 Subject: [PATCH 3/5] Implement experimental part of AbstractChains interface --- Project.toml | 2 ++ src/experimental/chains.jl | 40 +++++++++++++++++++++++++++++++++++--- test/chains.jl | 26 +++++++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 test/chains.jl diff --git a/Project.toml b/Project.toml index 7c40ca91..5bd32bd5 100644 --- a/Project.toml +++ b/Project.toml @@ -17,6 +17,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" @@ -29,6 +30,7 @@ LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" StatsBase = "0.32, 0.33, 0.34" TerminalLoggers = "0.1" +Test = "1" Transducers = "0.4.30" UUIDs = "<0.0.1, 1" julia = "1.6" diff --git a/src/experimental/chains.jl b/src/experimental/chains.jl index b4e13790..25b3e2ec 100644 --- a/src/experimental/chains.jl +++ b/src/experimental/chains.jl @@ -1,6 +1,7 @@ module Chains -using AbstractMCMC: AbstractChains +using AbstractMCMC: AbstractMCMC, AbstractChains +using Test """ AbstractMCMC.Chains.get_data(chn, key) @@ -55,7 +56,40 @@ nchains(c::AbstractChains) = length(chain_indices(c)) Test that the `AbstractChains` object `chn` implements the required interface. 
""" -function test_interface end # Extended in TestExt +function test_interface(chn::AbstractChains) + # TODO: Test chainscat, chainsstack + + @testset "Base.size, AbstractMCMC.Chains.niters, AbstractMCMC.Chains.nchains" begin + @test size(chn) isa NTuple{N,Int} where {N} + @test AbstractMCMC.Chains.niters(chn) isa Int + @test AbstractMCMC.Chains.nchains(chn) isa Int + end + + @testset "Base.keys" begin + @test collect(keys(chn)) isa AbstractVector + end + + @testset "AbstractMCMC.Chains.get_data" begin + for k in keys(chn) + data = AbstractMCMC.Chains.get_data(chn, k) + @test data isa AbstractMatrix + @test size(data) == + (AbstractMCMC.Chains.niters(chn), AbstractMCMC.Chains.nchains(chn)) + end + end + + @testset "AbstractMCMC.Chains.iter_indices" begin + ii = AbstractMCMC.Chains.iter_indices(chn) + @test ii isa AbstractVector{<:Integer} + @test length(ii) == AbstractMCMC.Chains.niters(chn) + end + + @testset "AbstractMCMC.Chains.chain_indices" begin + ci = AbstractMCMC.Chains.chain_indices(chn) + @test ci isa AbstractVector{<:Integer} + @test length(ci) == AbstractMCMC.Chains.nchains(chn) + end +end # Plotting functions; to be extended by individual chain libraries function autocorplot end @@ -76,7 +110,7 @@ function traceplot end function traceplot! end # Note that other functions are provided by other libraries. In particular: # Plots.histogram -# StatsPlots.density +# Plots.density # StatsPlots.cornerplot end # AbstractMCMC.Chains diff --git a/test/chains.jl b/test/chains.jl new file mode 100644 index 00000000..38c7e5e9 --- /dev/null +++ b/test/chains.jl @@ -0,0 +1,26 @@ +module AbstractMCMCChainsTests + +using AbstractMCMC: AbstractMCMC +using Test + +# This is a test mock: it minimally satisfies the AbstractChains interface. We use this to +# test our `test_interface` function, i.e., to ensure that something that satisfies the +# interface passes the test. 
+# See: https://invenia.github.io/blog/2020/11/06/interfacetesting/ +struct MockChain <: AbstractMCMC.AbstractChains + iter_indices::Vector{Int} + chain_indices::Vector{Int} + data::Dict{Symbol,Matrix{Float64}} +end +const MOCK = MockChain(1:10, 1:3, Dict(:param1 => rand(10, 3), :param2 => rand(10, 3))) +AbstractMCMC.Chains.iter_indices(c::MockChain) = c.iter_indices +AbstractMCMC.Chains.chain_indices(c::MockChain) = c.chain_indices +Base.size(c::MockChain) = (AbstractMCMC.Chains.niters(c), AbstractMCMC.Chains.nchains(c)) +Base.keys(c::MockChain) = keys(c.data) +AbstractMCMC.Chains.get_data(c::MockChain, k::Symbol) = c.data[k] + +@testset "AbstractChains interface" begin + AbstractMCMC.Chains.test_interface(MOCK) +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 909ae8b3..64fcfe85 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,4 +24,5 @@ include("utils.jl") include("stepper.jl") include("transducer.jl") include("logdensityproblems.jl") + include("chains.jl") end From a587aacbe3c729381075204ccff2348b0966230e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 10 Oct 2025 19:29:27 +0100 Subject: [PATCH 4/5] fix docs --- docs/src/api.md | 13 ++++++++++++- src/chains.jl | 10 +++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 94b006ab..b11f3c17 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -144,13 +144,24 @@ AbstractMCMC defines the abstract type `AbstractChains` for Markov chains. AbstractMCMC.AbstractChains ``` -For chains of this type, AbstractMCMC defines the following two methods. +For chains of this type, AbstractMCMC defines the following two **public** methods. ```@docs AbstractMCMC.chainscat AbstractMCMC.chainsstack ``` +The following interface methods are considered experimental and may change even in formally non-breaking releases. 
+ +```@docs +AbstractMCMC.Chains.get_data +AbstractMCMC.Chains.iter_indices +AbstractMCMC.Chains.chain_indices +AbstractMCMC.Chains.niters +AbstractMCMC.Chains.nchains +AbstractMCMC.Chains.test_interface +``` + ## Interacting with states of samplers To make it a bit easier to interact with some arbitrary sampler state, we encourage implementations of `AbstractSampler` to implement the following methods: diff --git a/src/chains.jl b/src/chains.jl index 3a3b6eb7..fd57ba6e 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -27,14 +27,14 @@ To implement a new subtype of `AbstractChains`, you need to define the following - `Base.size` should return a tuple of ints (the exact meaning is left to you) - `Base.keys` should return a list of keys -- [`AbstractMCMC.get_data`](@ref)`(chn, key)` -- [`AbstractMCMC.iter_indices`](@ref)`(chn)` -- [`AbstractMCMC.chain_indices`](@ref)`(chn)` +- [`AbstractMCMC.Chains.get_data`](@ref)`(chn, key)` +- [`AbstractMCMC.Chains.iter_indices`](@ref)`(chn)` +- [`AbstractMCMC.Chains.chain_indices`](@ref)`(chn)` You can optionally define the following methods for efficiency: -- [`AbstractChains.niters`](@ref)`(chn)` -- [`AbstractChains.nchains`](@ref)`(chn)` +- [`AbstractMCMC.Chains.niters`](@ref)`(chn)` +- [`AbstractMCMC.Chains.nchains`](@ref)`(chn)` """ abstract type AbstractChains end From 2ac37ad60938e0608663b8a6e4ab46f188e8b2e2 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 15 Oct 2025 20:11:49 +0100 Subject: [PATCH 5/5] Extend docstrings for `iter_indices` and `chain_indices` --- docs/Project.toml | 1 + src/experimental/chains.jl | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index f74dfb58..d5fc343e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/experimental/chains.jl 
b/src/experimental/chains.jl index 25b3e2ec..29ce8aa3 100644 --- a/src/experimental/chains.jl +++ b/src/experimental/chains.jl @@ -18,6 +18,9 @@ function get_data end Obtain the indices of each iteration for the `AbstractChains` object `chn`. +For example, if `chn` contains 1000 samples, but 1000 warmup steps and a thinning factor of +2 were used, then this function should return `1001:2:3000` (or an equivalent vector). + This function should return an `AbstractVector{<:Integer}`. """ function iter_indices end @@ -27,6 +30,10 @@ function iter_indices end Obtain the indices of each chain in the `AbstractChains` object `chn`. +If there is no special numbering associated with chains, then this function can simply +return `1:nchains(chn)`. However, this function provides the flexibility to have +non-standard chain numbering (e.g. if chains are combined from multiple sources). + This function should return an `AbstractVector{<:Integer}`. """ function chain_indices end