diff --git a/src/constraints/form/form_point_mass.jl b/src/constraints/form/form_point_mass.jl index 6546edd44..42b400f12 100644 --- a/src/constraints/form/form_point_mass.jl +++ b/src/constraints/form/form_point_mass.jl @@ -73,6 +73,9 @@ ReactiveMP.constrain_form(pmconstraint::PointMassFormConstraint, distribution) = # There is no need to call the optimizer on a `Distribution` object since they should have a well defined `mode` ReactiveMP.constrain_form(::PointMassFormConstraint, distribution::Distribution) = PointMass(mode(distribution)) +# Except for the DirichletCollection. For this one, we will use the mean instead: +ReactiveMP.constrain_form(::PointMassFormConstraint, distribution::DirichletCollection) = PointMass(mean(distribution)) + """ default_point_mass_form_constraint_optimizer(::Type{<:VariateType}, ::Type{<:ValueSupport}, constraint::PointMassFormConstraint, distribution) diff --git a/test/constraints/form/form_point_mass_tests.jl b/test/constraints/form/form_point_mass_tests.jl index d94ac8638..1c14e85bc 100644 --- a/test/constraints/form/form_point_mass_tests.jl +++ b/test/constraints/form/form_point_mass_tests.jl @@ -125,4 +125,12 @@ @test mean(result.posteriors[:θ]) ≈ p atol = 1e-1 end end + + Distributions.mode(d::DirichletCollection) = error("`mode` should not be called on DirichletCollection.") + + @testset "`DirichletCollection` exception (mode is not defined)" begin + constraint = PointMassFormConstraint() + d = DirichletCollection(ones(3, 3)) + opt = constrain_form(constraint, d) + end end