diff --git a/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl b/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl index 9300c763..58c0c8fb 100644 --- a/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl +++ b/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl @@ -6,10 +6,28 @@ using Random: AbstractRNG # Should use Module.function directly in the code instead of doing this import Distributions: support, pdf, mode, mean +using Distributions: DiscreteUnivariateDistribution, Distribution +using Distributions: VariateForm, Multivariate, Matrixvariate, Univariate +using Distributions: ValueSupport, Discrete import Random: rand using UnicodePlots: barplot +""" +Try to guess the Distributions.VariateForm for a distribution based on the sample type. +""" +function infer_variate_form(T::Type) + if T <: AbstractVector + return Multivariate + elseif T <: AbstractMatrix + return Matrixvariate + elseif T <: Number + return Univariate + else + return VariateForm + end +end + export weighted_iterator include("weighted_iteration.jl") diff --git a/lib/POMDPTools/src/POMDPDistributions/bool.jl b/lib/POMDPTools/src/POMDPDistributions/bool.jl index 5114645b..550f306a 100644 --- a/lib/POMDPTools/src/POMDPDistributions/bool.jl +++ b/lib/POMDPTools/src/POMDPDistributions/bool.jl @@ -5,13 +5,15 @@ Create a distribution over Boolean values (`true` or `false`). `p_true` is the probability of the `true` outcome; the probability of `false` is 1-`p_true`. """ -struct BoolDistribution +struct BoolDistribution <: DiscreteUnivariateDistribution p::Float64 # probability of true end -pdf(d::BoolDistribution, s::Bool) = s ? d.p : 1.0-d.p +pdf(d::BoolDistribution, s::Real) = convert(Bool, s) ? d.p : 1.0-d.p +Distributions.logpdf(d::BoolDistribution, s) = log(pdf(d, s)) rand(rng::AbstractRNG, s::Random.SamplerTrivial{BoolDistribution}) = rand(rng) <= s[].p +rand(rng::AbstractRNG, d::BoolDistribution) = rand(rng) <= d.p Base.iterate(d::BoolDistribution) = ((d.p, true), true) function Base.iterate(d::BoolDistribution, state::Bool) diff --git a/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl b/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl index dec98b48..4fde9485 100644 --- a/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl +++ b/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl @@ -7,7 +7,7 @@ Create a sparse categorical distribution. This is optimized for value iteration with a fast implementation of `weighted_iterator`. Both `pdf` and `rand` are order n. """ -struct SparseCat{V, P} +struct SparseCat{V, P, F} <: Distribution{F, Discrete} vals::V probs::P end @@ -23,13 +23,16 @@ function SparseCat(v, p::AbstractArray) SparseCat(v, cp) end # the method above gets all arrays *except* ones that have a numeric eltype, which are handled below -SparseCat(v, p::AbstractArray{<:Number}) = SparseCat{typeof(v), typeof(p)}(v, p) +SparseCat(v, p::AbstractArray{<:Number}) = SparseCat{typeof(v), typeof(p), infer_variate_form(eltype(v))}(v, p) -function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) - d = s[] +SparseCat(v, p) = SparseCat{typeof(v), typeof(p), infer_variate_form(eltype(v))}(v, p) + +rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) = rand(rng, s[]) + +function rand(rng::AbstractRNG, d::SparseCat) r = sum(d.probs)*rand(rng) tot = zero(eltype(d.probs)) - for (v, p) in d + for (v, p) in weighted_iterator(d) tot += p if r < tot return v @@ -47,9 +50,18 @@ function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) error("Error sampling from SparseCat distribution with vals $(d.vals) and probs $(d.probs)") # try to help with type stability end +Distributions.sampler(d::SparseCat) = Random.SamplerTrivial(d) +Random.Sampler(::AbstractRNG, d::SparseCat, repetition::Union{Val{1}, Val{Inf}}) = Random.SamplerTrivial(d) + +# to resolve ambiguity between pdf(::UnivariateDistribution, ::Real) and pdf(::SparseCat, ::Any) +pdf(d::SparseCat, s) = _pdf(d, s) +pdf(d::SparseCat, s::Real) = _pdf(d, s) + +Distributions.logpdf(d::SparseCat, x) = log(pdf(d, x)) + # slow linear search :( -function pdf(d::SparseCat, s) - for (v, p) in d +function _pdf(d::SparseCat, s) + for (v, p) in weighted_iterator(d) if v == s return p end @@ -57,7 +69,7 @@ function pdf(d::SparseCat, s) return zero(eltype(d.probs)) end -function pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray} +function _pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray} for (i,v) in enumerate(d.vals) if v == s return d.probs[i] @@ -67,19 +79,25 @@ function pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray} end - support(d::SparseCat) = d.vals -weighted_iterator(d::SparseCat) = d +struct SparseCatIterator{D<:SparseCat} + d::D +end + + +weighted_iterator(d::SparseCat) = SparseCatIterator(d) # iterator for general SparseCat # this has some type stability problems -function Base.iterate(d::SparseCat) +function Base.iterate(i::SparseCatIterator) + d = i.d val, vstate = iterate(d.vals) prob, pstate = iterate(d.probs) return ((val=>prob), (vstate, pstate)) end -function Base.iterate(d::SparseCat, dstate::Tuple) +function Base.iterate(i::SparseCatIterator, dstate::Tuple) + d = i.d vstate, pstate = dstate vnext = iterate(d.vals, vstate) pnext = iterate(d.probs, pstate) @@ -94,21 +112,22 @@ end # iterator for SparseCat with indexed members const Indexed = Union{AbstractArray, Tuple, NamedTuple} -function Base.iterate(d::SparseCat{V,P}, state::Integer=1) where {V<:Indexed, P<:Indexed} - if state > length(d) +function Base.iterate(i::SparseCatIterator{<:SparseCat{<:Indexed,<:Indexed}}, state::Integer=1) + if state > length(i) return nothing end - return (d.vals[state]=>d.probs[state], state+1) + return (i.d.vals[state]=>i.d.probs[state], state+1) end -Base.length(d::SparseCat) = min(length(d.vals), length(d.probs)) -Base.eltype(D::Type{SparseCat{V,P}}) where {V, P} = Pair{eltype(V), eltype(P)} -sampletype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V) -Random.gentype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V) +Base.length(i::SparseCatIterator) = min(length(i.d.vals), length(i.d.probs)) +Base.eltype(D::Type{SparseCatIterator{SparseCat{V,P,F}}}) where {V,P,F} = Pair{eltype(V), eltype(P)} + +sampletype(D::Type{SparseCat{V,P,F}}) where {V,P,F} = eltype(V) +Random.gentype(D::Type{SparseCat{V,P,F}}) where {V,P,F} = eltype(V) function mean(d::SparseCat) vsum = zero(eltype(d.vals)) - for (v, p) in d + for (v, p) in weighted_iterator(d) vsum += v*p end return vsum/sum(d.probs) @@ -117,7 +136,7 @@ end function mode(d::SparseCat) bestp = zero(eltype(d.probs)) bestv = first(d.vals) - for (v, p) in d + for (v, p) in weighted_iterator(d) if p >= bestp bestp = p bestv = v diff --git a/lib/POMDPTools/src/Policies/playback.jl b/lib/POMDPTools/src/Policies/playback.jl index a0853060..57479a3e 100644 --- a/lib/POMDPTools/src/Policies/playback.jl +++ b/lib/POMDPTools/src/Policies/playback.jl @@ -21,7 +21,11 @@ mutable struct PlaybackPolicy{A<:AbstractArray, P<:Policy, V<:AbstractArray{<:Re end # Constructor for the PlaybackPolicy -PlaybackPolicy(actions::AbstractArray, backup_policy::Policy; logpdfs::AbstractArray{<:Real} = Float64[]) = PlaybackPolicy(actions, backup_policy, logpdfs, 1) +function PlaybackPolicy(actions::AbstractArray, + backup_policy::Policy = FunctionPolicy(s->error("PlaybackPolicy out of actions.")); + logpdfs::AbstractArray{<:Real} = Float64[]) + return PlaybackPolicy(actions, backup_policy, logpdfs, 1) +end # Action selection for the PlaybackPolicy function POMDPs.action(p::PlaybackPolicy, s) @@ -41,5 +45,3 @@ function Distributions.logpdf(p::PlaybackPolicy, h) return sum(p.logpdfs[1:N]) end end - - diff --git a/lib/POMDPTools/test/distributions/test_distributions_jl_integration.jl b/lib/POMDPTools/test/distributions/test_distributions_jl_integration.jl new file mode 100644 index 00000000..072e3e10 --- /dev/null +++ b/lib/POMDPTools/test/distributions/test_distributions_jl_integration.jl @@ -0,0 +1,10 @@ +@test POMDPDistributions.infer_variate_form(typeof([1 2; 3 4])) == Distributions.Matrixvariate +@test POMDPDistributions.infer_variate_form(typeof([1, 2])) == Distributions.Multivariate +@test POMDPDistributions.infer_variate_form(typeof(1)) == Distributions.Univariate +@test POMDPDistributions.infer_variate_form(Any) == Distributions.VariateForm + +p = product_distribution([SparseCat([1, 2, 3], [0.5, 0.2, 0.3]), BoolDistribution(1.0)]) +@test rand(p) isa AbstractVector +@test pdf(p, [1, 1]) == 0.5 + +@test_broken p = Product([SparseCat([:a,:b,:c], [0.5, 0.2, 0.3]), BoolDistribution(1.0)]) diff --git a/lib/POMDPTools/test/policies/test_playback_policy.jl b/lib/POMDPTools/test/policies/test_playback_policy.jl index 1ea5d951..2dc47468 100644 --- a/lib/POMDPTools/test/policies/test_playback_policy.jl +++ b/lib/POMDPTools/test/policies/test_playback_policy.jl @@ -13,6 +13,12 @@ playback = PlaybackPolicy(collect(action_hist(hist)), RandomPolicy(mdp)) hist2 = simulate(HistoryRecorder(), mdp, playback, GWPos(3,3)) @test hist == hist2 +## Test with default error policy +playback = PlaybackPolicy(collect(action_hist(hist))) +@test all(playback.actions .== action_hist(hist)) +hist3 = simulate(HistoryRecorder(), mdp, playback, GWPos(3,3)) +@test_throws ErrorException action(playback, GWPos(3,3)) + ## Test log probability Distributions.logpdf(p::RandomPolicy, h) = length(h)*log(1. / length(actions(p.problem))) playback = PlaybackPolicy(collect(action_hist(hist)), RandomPolicy(mdp), logpdfs = -ones(length(hist))) diff --git a/lib/POMDPTools/test/runtests.jl b/lib/POMDPTools/test/runtests.jl index 3d5e8cea..d2ffa260 100644 --- a/lib/POMDPTools/test/runtests.jl +++ b/lib/POMDPTools/test/runtests.jl @@ -14,6 +14,7 @@ using SparseArrays: sparse import CommonRLInterface +using Distributions: Distributions, product_distribution, Product @testset "POMDPTools.jl" begin @testset "POMDPDistributions" begin @@ -23,6 +24,7 @@ import CommonRLInterface include("distributions/test_pretty_printing.jl") include("distributions/test_sparse_cat.jl") include("distributions/test_uniform.jl") + include("distributions/test_distributions_jl_integration.jl") end @testset "ModelTools" begin