diff --git a/src/sampling.jl b/src/sampling.jl index a256dd0c..de0d37f0 100644 --- a/src/sampling.jl +++ b/src/sampling.jl @@ -603,6 +603,9 @@ sample(wv::AbstractWeights) = sample(default_rng(), wv) sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)] sample(a::AbstractArray, wv::AbstractWeights) = sample(default_rng(), a, wv) +# Specialization for `UnitWeights` +sample(rng::AbstractRNG, wv::UnitWeights) = rand(rng, 1:length(wv)) + """ direct_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray) @@ -633,6 +636,14 @@ end direct_sample!(a::AbstractArray, wv::AbstractWeights, x::AbstractArray) = direct_sample!(default_rng(), a, wv, x) +# Specialization for `UnitWeights` +function direct_sample!(rng::AbstractRNG, a::AbstractArray, wv::UnitWeights, x::AbstractArray) + if length(a) != length(wv) + throw(DimensionMismatch(lazy"Number of samples ($(length(a))) and sample weights ($(length(wv))) must be equal.")) + end + return direct_sample!(rng, a, x) +end + """ alias_sample!([rng], a::AbstractArray, wv::AbstractWeights, x::AbstractArray) @@ -741,7 +752,7 @@ function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, # calculate keys for all items keys = randexp(rng, n) for i in 1:n - @inbounds keys[i] = wv.values[i]/keys[i] + @inbounds keys[i] = wv[i]/keys[i] end # return items with largest keys @@ -787,7 +798,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, s = 0 @inbounds for _s in 1:n s = _s - w = wv.values[s] + w = wv[s] w < 0 && error("Negative weight found in weight vector at index $s") if w > 0 i += 1 @@ -802,7 +813,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, @inbounds threshold = pq[1].first @inbounds for i in s+1:n - w = wv.values[i] + w = wv[i] w < 0 && error("Negative weight found in weight vector at index $i") w > 0 || continue key = w/randexp(rng) @@ -861,7 +872,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, s = 0 @inbounds for _s in 1:n s = _s - w = wv.values[s] + w = wv[s] w < 0 && error("Negative weight found in weight vector at index $s") if w > 0 i += 1 @@ -877,7 +888,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, X = threshold*randexp(rng) @inbounds for i in s+1:n - w = wv.values[i] + w = wv[i] w < 0 && error("Negative weight found in weight vector at index $i") w > 0 || continue X -= w @@ -958,6 +969,14 @@ sample(a::AbstractArray, wv::AbstractWeights, dims::Dims; replace::Bool=true, ordered::Bool=false) = sample(default_rng(), a, wv, dims; replace=replace, ordered=ordered) +# Specialization for `UnitWeights` +function sample!(rng::AbstractRNG, a::AbstractArray, wv::UnitWeights, x::AbstractArray; replace::Bool=true, ordered::Bool=false) + if length(a) != length(wv) + throw(DimensionMismatch(lazy"Number of samples ($(length(a))) and sample weights ($(length(wv))) must be equal.")) + end + return sample!(rng, a, x; replace, ordered) +end + # wsample interface """ diff --git a/test/sampling.jl b/test/sampling.jl index cd2ff096..a0d9bd51 100644 --- a/test/sampling.jl +++ b/test/sampling.jl @@ -297,3 +297,51 @@ end end end end + +# Custom unit weights without `values` field +struct YAUnitWeights <: StatsBase.AbstractWeights{Int, Int, Vector{Int}} + n::Int +end +Base.sum(wv::YAUnitWeights) = wv.n +Base.length(wv::YAUnitWeights) = wv.n +Base.isempty(wv::YAUnitWeights) = iszero(wv.n) +Base.size(wv::YAUnitWeights) = (wv.n,) +Base.axes(wv::YAUnitWeights) = (Base.OneTo(wv.n),) +function Base.getindex(wv::YAUnitWeights, i::Int) + @boundscheck checkbounds(wv, i) + return 1 +end + +@testset "issue #950" begin + # Sampling with unit weights behaves the same as sampling without weights + Random.seed!(123) + xs = sample(1:100, uweights(100), 10; replace=false) + Random.seed!(123) + @test xs == sample(1:100, 10; replace=false) + + Random.seed!(123) + x = sample(uweights(100)) + Random.seed!(123) + @test x == sample(1:100) + + Random.seed!(123) + xs = direct_sample!(1:100, uweights(100), Vector{Int}(undef, 10)) + Random.seed!(123) + @test xs == direct_sample!(1:100, Vector{Int}(undef, 10)) + + # Errors + @test_throws DimensionMismatch("Number of samples (100) and sample weights (99) must be equal.") sample(1:100, uweights(99), 10; replace=false) + @test_throws DimensionMismatch("Number of samples (80) and sample weights (53) must be equal.") direct_sample!(1:80, uweights(53), Vector{Int}(undef, 10)) + + # Custom units don't error and behave the same as sampling with `Weights` + Random.seed!(123) + xs = sample(1:100, YAUnitWeights(100), 10; replace=false) + Random.seed!(123) + @test xs == sample(1:100, weights(ones(Int, 100)), 10; replace=false) + for f in (StatsBase.efraimidis_a_wsample_norep!, StatsBase.efraimidis_ares_wsample_norep!, StatsBase.efraimidis_aexpj_wsample_norep!) + Random.seed!(123) + xs = f(1:100, YAUnitWeights(100), Vector{Int}(undef, 10)) + Random.seed!(123) + @test xs == f(1:100, weights(ones(Int, 100)), Vector{Int}(undef, 10)) + end +end