-
Notifications
You must be signed in to change notification settings - Fork 195
Fix some issues with sampling #879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 3 commits
6c4c119
d8dfe60
d56a40a
4a1e261
faa481f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,39 @@ | ||
using Base: mightalias | ||
|
||
if isdefined(Base, :require_one_based_indexing) # TODO: use this directly once we require Julia 1.2+ | ||
using Base: require_one_based_indexing | ||
else | ||
require_one_based_indexing(xs...) = | ||
any((!) ∘ isone ∘ firstindex, xs) && throw(ArgumentError("non 1-based arrays are not supported")) | ||
end | ||
|
||
function _validate_sample_inputs(input::AbstractArray, output::AbstractArray, replace::Bool) | ||
mightalias(input, output) && | ||
throw(ArgumentError("destination array must not share memory with the source array")) | ||
require_one_based_indexing(input, output) | ||
n = length(input) | ||
k = length(output) | ||
if !replace && k > n | ||
throw(DimensionMismatch("cannot draw $k samples of $n values without replacement")) | ||
end | ||
return (n, k) | ||
end | ||
|
||
function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights, | ||
output::AbstractArray, replace::Bool) | ||
mightalias(output, weights) && | ||
throw(ArgumentError("destination array must not share memory with weights array")) | ||
_validate_sample_inputs(input, weights) | ||
return _validate_sample_inputs(input, output, replace) | ||
end | ||
|
||
function _validate_sample_inputs(input::AbstractArray, weights::AbstractWeights) | ||
require_one_based_indexing(weights) | ||
n = length(input) | ||
nw = length(weights) | ||
nw == n || throw(DimensionMismatch("source and weight arrays must have the same length, got $n and $nw")) | ||
return n | ||
end | ||
|
||
########################################################### | ||
# | ||
|
@@ -10,16 +46,15 @@ using Random: Sampler, Random.GLOBAL_RNG | |
### Algorithms for sampling with replacement | ||
|
||
function direct_sample!(rng::AbstractRNG, a::UnitRange, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
s = Sampler(rng, 1:length(a)) | ||
n, k = _validate_sample_inputs(a, x, true) | ||
s = Sampler(rng, 1:n) | ||
b = a[1] - 1 | ||
if b == 0 | ||
for i = 1:length(x) | ||
for i = 1:k | ||
@inbounds x[i] = rand(rng, s) | ||
end | ||
else | ||
for i = 1:length(x) | ||
for i = 1:k | ||
@inbounds x[i] = b + rand(rng, s) | ||
end | ||
end | ||
|
@@ -36,12 +71,9 @@ and set `x[j] = a[i]`, with `n=length(a)` and `k=length(x)`. | |
This algorithm consumes `k` random numbers. | ||
""" | ||
function direct_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
s = Sampler(rng, 1:length(a)) | ||
for i = 1:length(x) | ||
n, k = _validate_sample_inputs(a, x, true) | ||
s = Sampler(rng, 1:n) | ||
for i = 1:k | ||
ararslan marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
@inbounds x[i] = a[rand(rng, s)] | ||
end | ||
return x | ||
|
@@ -61,11 +93,7 @@ storeindices(n, k, T) = false | |
|
||
# order results of a sampler that does not order automatically | ||
function sample_ordered!(sampler!, rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n, k = length(a), length(x) | ||
n, k = _validate_sample_inputs(a, x, true) | ||
# todo: if eltype(x) <: Real && eltype(a) <: Real, | ||
# in some cases it might be faster to check | ||
# issorted(a) to see if we can just sort x | ||
|
@@ -140,13 +168,7 @@ memory space. Suitable for the case where memory is tight. | |
""" | ||
function knuths_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; | ||
initshuffle::Bool=true) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
# initialize | ||
for i = 1:k | ||
|
@@ -200,13 +222,7 @@ faster than Knuth's algorithm especially when `n` is greater than `k`. | |
It is ``O(n)`` for initialization, plus ``O(k)`` for random shuffling | ||
""" | ||
function fisher_yates_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
inds = Vector{Int}(undef, n) | ||
for i = 1:n | ||
|
@@ -240,13 +256,7 @@ However, if `k` is large and approaches ``n``, the rejection rate would increase | |
drastically, resulting in poorer performance. | ||
""" | ||
function self_avoid_sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
s = Set{Int}() | ||
sizehint!(s, k) | ||
|
@@ -282,13 +292,7 @@ This algorithm consumes ``O(n)`` random numbers, with `n=length(a)`. | |
The outputs are ordered. | ||
""" | ||
function seqsample_a!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
i = 0 | ||
j = 0 | ||
|
@@ -324,13 +328,7 @@ This algorithm consumes ``O(k^2)`` random numbers, with `k=length(x)`. | |
The outputs are ordered. | ||
""" | ||
function seqsample_c!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
n = length(a) | ||
k = length(x) | ||
k <= n || error("length(x) should not exceed length(a)") | ||
n, k = _validate_sample_inputs(a, x, false) | ||
|
||
i = 0 | ||
j = 0 | ||
|
@@ -370,13 +368,7 @@ This algorithm consumes ``O(k)`` random numbers, with `k=length(x)`. | |
The outputs are ordered. | ||
""" | ||
function seqsample_d!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
N = length(a) | ||
n = length(x) | ||
n <= N || error("length(x) should not exceed length(a)") | ||
N, n = _validate_sample_inputs(a, x, false) | ||
|
||
i = 0 | ||
j = 0 | ||
|
@@ -485,10 +477,7 @@ nor share memory with them, or the result may be incorrect. | |
""" | ||
function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; | ||
replace::Bool=true, ordered::Bool=false) | ||
1 == firstindex(a) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, x, replace) | ||
k == 0 && return x | ||
|
||
if replace # with replacement | ||
|
@@ -499,8 +488,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, x::AbstractArray; | |
end | ||
|
||
else # without replacement | ||
k <= n || error("Cannot draw more samples without replacement.") | ||
|
||
if ordered | ||
if n > 10 * k * k | ||
seqsample_c!(rng, a, x) | ||
|
@@ -582,8 +569,7 @@ Optionally specify a random number generator `rng` as the first argument | |
(defaults to `Random.GLOBAL_RNG`). | ||
""" | ||
function sample(rng::AbstractRNG, wv::AbstractWeights) | ||
1 == firstindex(wv) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
require_one_based_indexing(wv) | ||
t = rand(rng) * sum(wv) | ||
n = length(wv) | ||
i = 1 | ||
|
@@ -596,7 +582,10 @@ function sample(rng::AbstractRNG, wv::AbstractWeights) | |
end | ||
sample(wv::AbstractWeights) = sample(Random.GLOBAL_RNG, wv) | ||
|
||
sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) = a[sample(rng, wv)] | ||
function sample(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights) | ||
_validate_sample_inputs(a, wv) | ||
return a[sample(rng, wv)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's weird that this line isn't tested. |
||
end | ||
sample(a::AbstractArray, wv::AbstractWeights) = sample(Random.GLOBAL_RNG, a, wv) | ||
|
||
""" | ||
|
@@ -613,15 +602,8 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm: | |
""" | ||
function direct_sample!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) | ||
for i = 1:length(x) | ||
_, k = _validate_sample_inputs(a, wv, x, true) | ||
for i = 1:k | ||
x[i] = a[sample(rng, wv)] | ||
end | ||
return x | ||
|
@@ -702,14 +684,7 @@ Noting `k=length(x)` and `n=length(a)`, this algorithm takes ``O(n \\log n)`` ti | |
for building the alias table, and then ``O(1)`` to draw each sample. It consumes ``2 k`` random numbers. | ||
""" | ||
function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) | ||
n, k = _validate_sample_inputs(a, wv, x, true) | ||
|
||
# create alias table | ||
ap = Vector{Float64}(undef, n) | ||
|
@@ -718,7 +693,7 @@ function alias_sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, | |
|
||
# sampling | ||
s = Sampler(rng, 1:n) | ||
for i = 1:length(x) | ||
for i = 1:k | ||
j = rand(rng, s) | ||
x[i] = rand(rng) < ap[j] ? a[j] : a[alias[j]] | ||
end | ||
|
@@ -740,15 +715,8 @@ and has overall time complexity ``O(n k)``. | |
""" | ||
function naive_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("Inconsistent lengths.")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
k > 0 || return x | ||
|
||
w = Vector{Float64}(undef, n) | ||
copyto!(w, wv) | ||
|
@@ -786,20 +754,13 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. | |
""" | ||
function efraimidis_a_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
|
||
k > 0 || return x | ||
|
||
# calculate keys for all items | ||
keys = randexp(rng, n) | ||
for i in 1:n | ||
@inbounds keys[i] = wv.values[i]/keys[i] | ||
@inbounds keys[i] = wv[i]/keys[i] | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end | ||
|
||
# return items with largest keys | ||
|
@@ -827,15 +788,7 @@ processing time to draw ``k`` elements. It consumes ``n`` random numbers. | |
""" | ||
function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
devmotion marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
k > 0 || return x | ||
|
||
# initialize priority queue | ||
|
@@ -844,7 +797,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
s = 0 | ||
@inbounds for _s in 1:n | ||
s = _s | ||
w = wv.values[s] | ||
w = wv[s] | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
w < 0 && error("Negative weight found in weight vector at index $s") | ||
if w > 0 | ||
i += 1 | ||
|
@@ -859,7 +812,7 @@ function efraimidis_ares_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
@inbounds threshold = pq[1].first | ||
|
||
@inbounds for i in s+1:n | ||
w = wv.values[i] | ||
w = wv[i] | ||
w < 0 && error("Negative weight found in weight vector at index $i") | ||
w > 0 || continue | ||
key = w/randexp(rng) | ||
|
@@ -900,15 +853,7 @@ processing time to draw ``k`` elements. It consumes ``O(k \\log(n / k))`` random | |
function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | ||
wv::AbstractWeights, x::AbstractArray; | ||
ordered::Bool=false) | ||
Base.mightalias(a, x) && | ||
throw(ArgumentError("output array x must not share memory with input array a")) | ||
Base.mightalias(x, wv) && | ||
throw(ArgumentError("output array x must not share memory with weights array wv")) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
length(wv) == n || throw(DimensionMismatch("a and wv must be of same length (got $n and $(length(wv))).")) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, false) | ||
k > 0 || return x | ||
|
||
# initialize priority queue | ||
|
@@ -917,7 +862,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
s = 0 | ||
@inbounds for _s in 1:n | ||
s = _s | ||
w = wv.values[s] | ||
w = wv[s] | ||
w < 0 && error("Negative weight found in weight vector at index $s") | ||
if w > 0 | ||
i += 1 | ||
|
@@ -933,7 +878,7 @@ function efraimidis_aexpj_wsample_norep!(rng::AbstractRNG, a::AbstractArray, | |
X = threshold*randexp(rng) | ||
|
||
@inbounds for i in s+1:n | ||
w = wv.values[i] | ||
w = wv[i] | ||
w < 0 && error("Negative weight found in weight vector at index $i") | ||
w > 0 || continue | ||
X -= w | ||
|
@@ -968,10 +913,8 @@ efraimidis_aexpj_wsample_norep!(a::AbstractArray, wv::AbstractWeights, x::Abstra | |
|
||
function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::AbstractArray; | ||
replace::Bool=true, ordered::Bool=false) | ||
1 == firstindex(a) == firstindex(wv) == firstindex(x) || | ||
throw(ArgumentError("non 1-based arrays are not supported")) | ||
n = length(a) | ||
k = length(x) | ||
n, k = _validate_sample_inputs(a, wv, x, replace) | ||
k > 0 || return x | ||
|
||
if replace | ||
if ordered | ||
|
@@ -991,7 +934,6 @@ function sample!(rng::AbstractRNG, a::AbstractArray, wv::AbstractWeights, x::Abs | |
end | ||
end | ||
else | ||
k <= n || error("Cannot draw $k samples from $n samples without replacement.") | ||
efraimidis_aexpj_wsample_norep!(rng, a, wv, x; ordered=ordered) | ||
end | ||
return x | ||
|
Uh oh!
There was an error while loading. Please reload this page.