From 0859ae28be116004bd3d32d022954ea82062d7bd Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Mon, 24 Apr 2023 17:17:54 -0600 Subject: [PATCH 1/6] Update CI.yml dev POMDPTools in docs --- .github/workflows/CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7b2c84dd..01ec0bd4 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,6 +35,7 @@ jobs: julia --project=docs -e ' using Pkg Pkg.develop(PackageSpec(path=pwd())) + Pkg.develop(path="lib/POMDPTools") Pkg.instantiate()' - run: julia --project=docs docs/make.jl env: From b02a0e96664a35ddc53ddd6f3726b1191d4a1495 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Tue, 25 Apr 2023 09:55:11 -0600 Subject: [PATCH 2/6] fix docs build --- .github/workflows/CI.yml | 2 +- docs/Project.toml | 1 + lib/POMDPTools/src/ModelTools/ModelTools.jl | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 01ec0bd4..c615e865 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -35,7 +35,7 @@ jobs: julia --project=docs -e ' using Pkg Pkg.develop(PackageSpec(path=pwd())) - Pkg.develop(path="lib/POMDPTools") + Pkg.develop(PackageSpec(path="lib/POMDPTools")) Pkg.instantiate()' - run: julia --project=docs docs/make.jl env: diff --git a/docs/Project.toml b/docs/Project.toml index a4010ef5..7b0715ae 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,4 +6,5 @@ NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" +POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" QuickPOMDPs = "8af83fb2-a731-493c-9049-9e19dbce6165" diff --git a/lib/POMDPTools/src/ModelTools/ModelTools.jl b/lib/POMDPTools/src/ModelTools/ModelTools.jl index d92550cf..f8d6340d 100644 --- a/lib/POMDPTools/src/ModelTools/ModelTools.jl +++ b/lib/POMDPTools/src/ModelTools/ModelTools.jl @@ -16,7 +16,7 @@ import Base: == using ..POMDPDistributions # import Distributions: pdf, mode, mean, support -# import POMDPLinter: @POMDP_require +import POMDPLinter: @POMDP_require export render From f5149e92f83d847e5e0948d8d7e2f4381c864e49 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Tue, 25 Apr 2023 09:57:10 -0600 Subject: [PATCH 3/6] removed POMDPModelTools from docs Project --- docs/Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index 7b0715ae..127d6a12 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,7 +3,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d" NamedTupleTools = "d9ec5142-1e00-5aa0-9d6a-321866360f50" -POMDPModelTools = "08074719-1b2a-587c-a292-00f91cc44415" POMDPModels = "355abbd5-f08e-5560-ac9e-8b5f2592a0ca" POMDPTools = "7588e00f-9cae-40de-98dc-e0c70c48cdd7" POMDPs = "a93abf59-7444-517b-a68a-c42f96afdd7d" From 6429e308ae9e33f7c92d6f2417bf197a3cfd6a33 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Fri, 19 May 2023 18:56:00 -0700 Subject: [PATCH 4/6] added quick and dirty default policy to PlaybackPolicy. someone make it better! --- lib/POMDPTools/src/Policies/playback.jl | 8 +++++--- lib/POMDPTools/test/policies/test_playback_policy.jl | 6 ++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/lib/POMDPTools/src/Policies/playback.jl b/lib/POMDPTools/src/Policies/playback.jl index a0853060..57479a3e 100644 --- a/lib/POMDPTools/src/Policies/playback.jl +++ b/lib/POMDPTools/src/Policies/playback.jl @@ -21,7 +21,11 @@ mutable struct PlaybackPolicy{A<:AbstractArray, P<:Policy, V<:AbstractArray{<:Re end # Constructor for the PlaybackPolicy -PlaybackPolicy(actions::AbstractArray, backup_policy::Policy; logpdfs::AbstractArray{<:Real} = Float64[]) = PlaybackPolicy(actions, backup_policy, logpdfs, 1) +function PlaybackPolicy(actions::AbstractArray, + backup_policy::Policy = FunctionPolicy(s->error("PlaybackPolicy out of actions.")); + logpdfs::AbstractArray{<:Real} = Float64[]) + return PlaybackPolicy(actions, backup_policy, logpdfs, 1) +end # Action selection for the PlaybackPolicy function POMDPs.action(p::PlaybackPolicy, s) @@ -41,5 +45,3 @@ function Distributions.logpdf(p::PlaybackPolicy, h) return sum(p.logpdfs[1:N]) end end - - diff --git a/lib/POMDPTools/test/policies/test_playback_policy.jl b/lib/POMDPTools/test/policies/test_playback_policy.jl index 1ea5d951..2dc47468 100644 --- a/lib/POMDPTools/test/policies/test_playback_policy.jl +++ b/lib/POMDPTools/test/policies/test_playback_policy.jl @@ -13,6 +13,12 @@ playback = PlaybackPolicy(collect(action_hist(hist)), RandomPolicy(mdp)) hist2 = simulate(HistoryRecorder(), mdp, playback, GWPos(3,3)) @test hist == hist2 +## Test with default error policy +playback = PlaybackPolicy(collect(action_hist(hist))) +@test all(playback.actions .== action_hist(hist)) +hist3 = simulate(HistoryRecorder(), mdp, playback, GWPos(3,3)) +@test_throws ErrorException action(playback, GWPos(3,3)) + ## Test log probability Distributions.logpdf(p::RandomPolicy, h) = length(h)*log(1. / length(actions(p.problem))) playback = PlaybackPolicy(collect(action_hist(hist)), RandomPolicy(mdp), logpdfs = -ones(length(hist))) From 51ba55c4097e82741cbacab466628a1413e5b8d6 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Tue, 23 May 2023 10:22:31 -0700 Subject: [PATCH 5/6] before SparseCatIterator --- .../POMDPDistributions/POMDPDistributions.jl | 18 ++++++++++++++++++ lib/POMDPTools/src/POMDPDistributions/bool.jl | 3 ++- .../src/POMDPDistributions/sparse_cat.jl | 17 ++++++++++++----- lib/POMDPTools/test/runtests.jl | 3 +++ 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl b/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl index 9300c763..58c0c8fb 100644 --- a/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl +++ b/lib/POMDPTools/src/POMDPDistributions/POMDPDistributions.jl @@ -6,10 +6,28 @@ using Random: AbstractRNG # Should use Module.function directly in the code instead of doing this import Distributions: support, pdf, mode, mean +using Distributions: DiscreteUnivariateDistribution, Distribution +using Distributions: VariateForm, Multivariate, Matrixvariate, Univariate +using Distributions: ValueSupport, Discrete import Random: rand using UnicodePlots: barplot +""" +Try to guess the Distributions.VariateForm for a distribution based on the sample type. +""" +function infer_variate_form(T::Type) + if T <: AbstractVector + return Multivariate + elseif T <: AbstractMatrix + return Matrixvariate + elseif T <: Number + return Univariate + else + return VariateForm + end +end + export weighted_iterator include("weighted_iteration.jl") diff --git a/lib/POMDPTools/src/POMDPDistributions/bool.jl b/lib/POMDPTools/src/POMDPDistributions/bool.jl index 5114645b..5065fb37 100644 --- a/lib/POMDPTools/src/POMDPDistributions/bool.jl +++ b/lib/POMDPTools/src/POMDPDistributions/bool.jl @@ -5,13 +5,14 @@ Create a distribution over Boolean values (`true` or `false`). `p_true` is the probability of the `true` outcome; the probability of `false` is 1-`p_true`. """ -struct BoolDistribution +struct BoolDistribution <: DiscreteUnivariateDistribution p::Float64 # probability of true end pdf(d::BoolDistribution, s::Bool) = s ? d.p : 1.0-d.p rand(rng::AbstractRNG, s::Random.SamplerTrivial{BoolDistribution}) = rand(rng) <= s[].p +rand(rng::AbstractRNG, d::BoolDistribution) = rand(rng) <= d.p Base.iterate(d::BoolDistribution) = ((d.p, true), true) function Base.iterate(d::BoolDistribution, state::Bool) diff --git a/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl b/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl index dec98b48..badb3913 100644 --- a/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl +++ b/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl @@ -7,7 +7,7 @@ Create a sparse categorical distribution. This is optimized for value iteration with a fast implementation of `weighted_iterator`. Both `pdf` and `rand` are order n. """ -struct SparseCat{V, P} +struct SparseCat{V, P, F} <: Distribution{F, Discrete} vals::V probs::P end @@ -23,7 +23,9 @@ function SparseCat(v, p::AbstractArray) SparseCat(v, cp) end # the method above gets all arrays *except* ones that have a numeric eltype, which are handled below -SparseCat(v, p::AbstractArray{<:Number}) = SparseCat{typeof(v), typeof(p)}(v, p) +SparseCat(v, p::AbstractArray{<:Number}) = SparseCat{typeof(v), typeof(p), infer_variate_form(eltype(v))}(v, p) + +SparseCat(v, p) = SparseCat{typeof(v), typeof(p), infer_variate_form(eltype(v))}(v, p) function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) d = s[] @@ -47,8 +49,14 @@ function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) error("Error sampling from SparseCat distribution with vals $(d.vals) and probs $(d.probs)") # try to help with type stability end +rand(rng::AbstractRNG, d::SparseCat) = rand(rng, Random.SamplerTrivial(d)) + +# to resolve ambiguity between pdf(::UnivariateDistribution, ::Real) and pdf(::SparseCat, ::Any) +pdf(d::SparseCat, s) = _pdf(d, s) +pdf(d::SparseCat, s::Real) = _pdf(d, s) + # slow linear search :( -function pdf(d::SparseCat, s) +function _pdf(d::SparseCat, s) for (v, p) in d if v == s return p @@ -57,7 +65,7 @@ function pdf(d::SparseCat, s) return zero(eltype(d.probs)) end -function pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray} +function _pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray} for (i,v) in enumerate(d.vals) if v == s return d.probs[i] @@ -67,7 +75,6 @@ function pdf(d::SparseCat{V,P}, s) where {V<:AbstractArray, P<:AbstractArray} end - support(d::SparseCat) = d.vals weighted_iterator(d::SparseCat) = d diff --git a/lib/POMDPTools/test/runtests.jl b/lib/POMDPTools/test/runtests.jl index 3d5e8cea..fd5f1fe3 100644 --- a/lib/POMDPTools/test/runtests.jl +++ b/lib/POMDPTools/test/runtests.jl @@ -14,6 +14,8 @@ using SparseArrays: sparse import CommonRLInterface +import Distributions: Product + @testset "POMDPTools.jl" begin @testset "POMDPDistributions" begin @@ -23,6 +25,7 @@ import CommonRLInterface include("distributions/test_pretty_printing.jl") include("distributions/test_sparse_cat.jl") include("distributions/test_uniform.jl") + include("distributions/test_distributions_jl_integration.jl") end @testset "ModelTools" begin From 69d1145bfe3a3f40df7a91ee210b00d51f239813 Mon Sep 17 00:00:00 2001 From: Zachary Sunberg Date: Tue, 23 May 2023 12:53:29 -0700 Subject: [PATCH 6/6] finished basic sparse_cat, bool --- lib/POMDPTools/src/POMDPDistributions/bool.jl | 3 +- .../src/POMDPDistributions/sparse_cat.jl | 46 ++++++++++++------- .../test_distributions_jl_integration.jl | 10 ++++ lib/POMDPTools/test/runtests.jl | 3 +- 4 files changed, 42 insertions(+), 20 deletions(-) create mode 100644 lib/POMDPTools/test/distributions/test_distributions_jl_integration.jl diff --git a/lib/POMDPTools/src/POMDPDistributions/bool.jl b/lib/POMDPTools/src/POMDPDistributions/bool.jl index 5065fb37..550f306a 100644 --- a/lib/POMDPTools/src/POMDPDistributions/bool.jl +++ b/lib/POMDPTools/src/POMDPDistributions/bool.jl @@ -9,7 +9,8 @@ struct BoolDistribution <: DiscreteUnivariateDistribution p::Float64 # probability of true end -pdf(d::BoolDistribution, s::Bool) = s ? d.p : 1.0-d.p +pdf(d::BoolDistribution, s::Real) = convert(Bool, s) ? d.p : 1.0-d.p +Distributions.logpdf(d::BoolDistribution, s) = log(pdf(d, s)) rand(rng::AbstractRNG, s::Random.SamplerTrivial{BoolDistribution}) = rand(rng) <= s[].p rand(rng::AbstractRNG, d::BoolDistribution) = rand(rng) <= d.p diff --git a/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl b/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl index badb3913..4fde9485 100644 --- a/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl +++ b/lib/POMDPTools/src/POMDPDistributions/sparse_cat.jl @@ -27,11 +27,12 @@ SparseCat(v, p::AbstractArray{<:Number}) = SparseCat{typeof(v), typeof(p), infer SparseCat(v, p) = SparseCat{typeof(v), typeof(p), infer_variate_form(eltype(v))}(v, p) -function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) - d = s[] +rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) = rand(rng, s[]) + +function rand(rng::AbstractRNG, d::SparseCat) r = sum(d.probs)*rand(rng) tot = zero(eltype(d.probs)) - for (v, p) in d + for (v, p) in weighted_iterator(d) tot += p if r < tot return v @@ -49,15 +50,18 @@ function rand(rng::AbstractRNG, s::Random.SamplerTrivial{<:SparseCat}) error("Error sampling from SparseCat distribution with vals $(d.vals) and probs $(d.probs)") # try to help with type stability end -rand(rng::AbstractRNG, d::SparseCat) = rand(rng, Random.SamplerTrivial(d)) +Distributions.sampler(d::SparseCat) = Random.SamplerTrivial(d) +Random.Sampler(::AbstractRNG, d::SparseCat, repetition::Union{Val{1}, Val{Inf}}) = Random.SamplerTrivial(d) # to resolve ambiguity between pdf(::UnivariateDistribution, ::Real) and pdf(::SparseCat, ::Any) pdf(d::SparseCat, s) = _pdf(d, s) pdf(d::SparseCat, s::Real) = _pdf(d, s) +Distributions.logpdf(d::SparseCat, x) = log(pdf(d, x)) + # slow linear search :( function _pdf(d::SparseCat, s) - for (v, p) in d + for (v, p) in weighted_iterator(d) if v == s return p end @@ -77,16 +81,23 @@ end support(d::SparseCat) = d.vals -weighted_iterator(d::SparseCat) = d +struct SparseCatIterator{D<:SparseCat} + d::D +end + + +weighted_iterator(d::SparseCat) = SparseCatIterator(d) # iterator for general SparseCat # this has some type stability problems -function Base.iterate(d::SparseCat) +function Base.iterate(i::SparseCatIterator) + d = i.d val, vstate = iterate(d.vals) prob, pstate = iterate(d.probs) return ((val=>prob), (vstate, pstate)) end -function Base.iterate(d::SparseCat, dstate::Tuple) +function Base.iterate(i::SparseCatIterator, dstate::Tuple) + d = i.d vstate, pstate = dstate vnext = iterate(d.vals, vstate) pnext = iterate(d.probs, pstate) @@ -101,21 +112,22 @@ end # iterator for SparseCat with indexed members const Indexed = Union{AbstractArray, Tuple, NamedTuple} -function Base.iterate(d::SparseCat{V,P}, state::Integer=1) where {V<:Indexed, P<:Indexed} - if state > length(d) +function Base.iterate(i::SparseCatIterator{<:SparseCat{<:Indexed,<:Indexed}}, state::Integer=1) + if state > length(i) return nothing end - return (d.vals[state]=>d.probs[state], state+1) + return (i.d.vals[state]=>i.d.probs[state], state+1) end -Base.length(d::SparseCat) = min(length(d.vals), length(d.probs)) -Base.eltype(D::Type{SparseCat{V,P}}) where {V, P} = Pair{eltype(V), eltype(P)} -sampletype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V) -Random.gentype(D::Type{SparseCat{V,P}}) where {V, P} = eltype(V) +Base.length(i::SparseCatIterator) = min(length(i.d.vals), length(i.d.probs)) +Base.eltype(D::Type{SparseCatIterator{SparseCat{V,P,F}}}) where {V,P,F} = Pair{eltype(V), eltype(P)} + +sampletype(D::Type{SparseCat{V,P,F}}) where {V,P,F} = eltype(V) +Random.gentype(D::Type{SparseCat{V,P,F}}) where {V,P,F} = eltype(V) function mean(d::SparseCat) vsum = zero(eltype(d.vals)) - for (v, p) in d + for (v, p) in weighted_iterator(d) vsum += v*p end return vsum/sum(d.probs) @@ -124,7 +136,7 @@ end function mode(d::SparseCat) bestp = zero(eltype(d.probs)) bestv = first(d.vals) - for (v, p) in d + for (v, p) in weighted_iterator(d) if p >= bestp bestp = p bestv = v diff --git a/lib/POMDPTools/test/distributions/test_distributions_jl_integration.jl b/lib/POMDPTools/test/distributions/test_distributions_jl_integration.jl new file mode 100644 index 00000000..072e3e10 --- /dev/null +++ b/lib/POMDPTools/test/distributions/test_distributions_jl_integration.jl @@ -0,0 +1,10 @@ +@test POMDPDistributions.infer_variate_form(typeof([1 2; 3 4])) == Distributions.Matrixvariate +@test POMDPDistributions.infer_variate_form(typeof([1, 2])) == Distributions.Multivariate +@test POMDPDistributions.infer_variate_form(typeof(1)) == Distributions.Univariate +@test POMDPDistributions.infer_variate_form(Any) == Distributions.VariateForm + +p = product_distribution([SparseCat([1, 2, 3], [0.5, 0.2, 0.3]), BoolDistribution(1.0)]) +@test rand(p) isa AbstractVector +@test pdf(p, [1, 1]) == 0.5 + +@test_broken p = Product([SparseCat([:a,:b,:c], [0.5, 0.2, 0.3]), BoolDistribution(1.0)]) diff --git a/lib/POMDPTools/test/runtests.jl b/lib/POMDPTools/test/runtests.jl index fd5f1fe3..d2ffa260 100644 --- a/lib/POMDPTools/test/runtests.jl +++ b/lib/POMDPTools/test/runtests.jl @@ -14,8 +14,7 @@ using SparseArrays: sparse import CommonRLInterface -import Distributions: Product - +using Distributions: Distributions, product_distribution, Product @testset "POMDPTools.jl" begin @testset "POMDPDistributions" begin