Skip to content

Commit 8f3e26d

Browse files
authored
Overload find*, searchsorted*, and StatsBase.sample (#121)
* Overload find*, searchsorted*, and StatsBase.sample * Update Project.toml * add minimum/maximum/extrema * unify with quasimatrix maximum * Update quasireducedim.jl * Update Project.toml * increase coverage * Update ci.yml * Update test_calculus.jl
1 parent f004b85 commit 8f3e26d

15 files changed

+182
-33
lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
name: CI
22
on:
3-
- push
4-
- pull_request
3+
push:
4+
branches:
5+
- master
6+
paths-ignore:
7+
- 'LICENSE'
8+
- 'README.md'
9+
- '.github/workflows/TagBot.yml'
10+
pull_request:
11+
paths-ignore:
12+
- 'LICENSE'
13+
- 'README.md'
14+
- '.github/workflows/TagBot.yml'
515
jobs:
616
test:
717
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}

Project.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "QuasiArrays"
22
uuid = "c4ea9172-b204-11e9-377d-29865faadc5c"
33
authors = ["Sheehan Olver <solver@mac.com>"]
4-
version = "0.12.2"
4+
version = "0.13"
55

66
[deps]
77
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
@@ -14,25 +14,30 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1414

1515
[weakdeps]
1616
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
17+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1718

1819
[extensions]
1920
QuasiArraysSparseArraysExt = "SparseArrays"
21+
QuasiArraysStatsBaseExt = "StatsBase"
2022

2123

2224
[compat]
2325
ArrayLayouts = "1"
2426
DomainSets = "0.7.6"
2527
FillArrays = "1"
26-
LazyArrays = "1.2, 2"
28+
LazyArrays = "2"
29+
Random = "1.0"
2730
StaticArrays = "1"
28-
julia = "1.6"
31+
StatsBase = "0.34"
32+
julia = "1.10"
2933

3034
[extras]
3135
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
3236
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
3337
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3438
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
39+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3540
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3641

3742
[targets]
38-
test = ["Base64", "IntervalSets", "Random", "SparseArrays", "Test"]
43+
test = ["Base64", "IntervalSets", "Random", "SparseArrays", "StatsBase", "Test"]

ext/QuasiArraysStatsBaseExt.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
module QuasiArraysStatsBaseExt
2+
using QuasiArrays, StatsBase
3+
import StatsBase: sample, AbstractRNG
4+
import QuasiArrays: sample_layout, MemoryLayout
5+
6+
"""
7+
sample([rng], w::AbstractQuasiArray)
8+
9+
Sample a single random element from `axes(w,1)` weighted according to `w`.
10+
"""
11+
sample(w::AbstractQuasiArray) = sample_layout(MemoryLayout(w), w)
12+
sample(rng::AbstractRNG, w::AbstractQuasiArray) = sample_layout(MemoryLayout(w), rng, w)
13+
14+
"""
15+
sample([rng], w::AbstractQuasiArray, n::Integer)
16+
17+
Sample a n random elements from `axes(w,1)` weighted according to `w`.
18+
"""
19+
sample(w::AbstractQuasiArray, n::Integer) = sample_layout(MemoryLayout(w), w, n)
20+
sample(rng::AbstractRNG, w::AbstractQuasiArray, n::Integer) = sample_layout(MemoryLayout(w), rng, w, n)
21+
22+
function sample_layout(_, f::AbstractQuasiVector)
23+
g = cumsum(f)
24+
searchsortedfirst(g/last(g), rand())
25+
end
26+
27+
function sample_layout(_, rng::AbstractRNG, f::AbstractQuasiVector)
28+
g = cumsum(f)
29+
searchsortedfirst(g/last(g), rand())
30+
end
31+
32+
function sample_layout(_, f::AbstractQuasiVector, n::Integer)
33+
g = cumsum(f)
34+
searchsortedfirst.(Ref(g/last(g)), rand(n))
35+
end
36+
37+
function sample_layout(_, rng::AbstractRNG, f::AbstractQuasiVector, n::Integer)
38+
g = cumsum(f)
39+
searchsortedfirst.(Ref(g/last(g)), rand(rng, n))
40+
end
41+
42+
function sample_layout(_, f::AbstractQuasiMatrix, n...)
43+
@assert size(f,2) == 1 # TODO generalise
44+
sample(f[:,1], n...)
45+
end
46+
47+
function sample_layout(_, rng::AbstractRNG, f::AbstractQuasiMatrix, n...)
48+
@assert size(f,2) == 1 # TODO generalise
49+
sample(rng, f[:,1], n...)
50+
end
51+
52+
end

src/QuasiArrays.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ import Base: Slice, IdentityUnitRange, ScalarIndex, RangeIndex, view, viewindexi
1616
parentindices, reverse, ndims, checkbounds, uncolon,
1717
maybeview, unsafe_view, checkindex, checkbounds_indices,
1818
throw_boundserror, rdims, replace_in_print_matrix, show, summary,
19-
hcat, vcat, hvcat, isassigned
19+
hcat, vcat, hvcat, isassigned, searchsortedfirst, searchsortedlast, searchsorted,
20+
findall, findfirst, findlast, minimum, maximum, extrema
2021
import Base: *, /, \, +, -, ^, inv
2122
import Base: exp, log, sqrt,
2223
cos, sin, tan, csc, sec, cot,
@@ -27,7 +28,8 @@ import Base: Array, Matrix, Vector
2728
import Base: union, intersect, sort, sort!
2829
import Base: conj, real, imag
2930
# reducedim.jl imports
30-
import Base: prod, sum, cumsum, diff, add_sum, mul_prod, mapreduce, max, min, count, _count, any, _any, all, _all, _sum, _prod, _mapreduce, reduced_index, check_reducedims, mapfoldl_impl
31+
import Base: prod, sum, cumsum, diff, add_sum, mul_prod, mapreduce, max, min, count, _count, any, _any, all, _all, _sum, _prod, _mapreduce, reduced_index, check_reducedims, mapfoldl_impl,
32+
_minimum, _maximum, _extrema, _extrema_rf
3133
import Base: BitInteger, IEEEFloat, uniontypes, _InitialValue, safe_tail, reducedim1, _simple_count
3234

3335
import Base: ones, zeros, one, zero, fill
@@ -91,8 +93,8 @@ include("quasibroadcast.jl")
9193
include("abstractquasiarraymath.jl")
9294
include("quasireducedim.jl")
9395

94-
9596
include("quasiarray.jl")
97+
include("quasisort.jl")
9698
include("quasiarraymath.jl")
9799

98100
include("lazyquasiarrays.jl")
@@ -129,9 +131,12 @@ function isapprox(x::AbstractQuasiArray, y::AbstractQuasiArray;
129131
end
130132
end
131133

132-
if !isdefined(Base, :get_extension)
133-
include("../ext/QuasiArraysSparseArraysExt.jl")
134-
end
134+
135+
###
136+
# extension support
137+
###
138+
139+
function sample_layout end
135140

136141

137142
end

src/calculus.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
# sum/cumsum
33
###
44

5-
# support overloading sum by MemoryLayout
6-
_sum(V::AbstractQuasiArray, dims) = sum_layout(MemoryLayout(V), V, dims)
7-
_sum(V::AbstractQuasiArray, ::Colon) = sum_layout(MemoryLayout(V), V, :)
85

96
_cumsum(A, dims) = cumsum_layout(MemoryLayout(A), A, dims)
107
cumsum(A::AbstractQuasiArray; dims::Integer=1) = _cumsum(A, dims)

src/indices.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,8 @@ to_indices(A::AbstractQuasiArray, inds, I::Tuple{Any,Vararg{Any}}) =
118118
(_uncolon(inds, I), to_indices(A, _cutdim(inds, I[1]), tail(I))...)
119119

120120
_cutdim(inds, I1) = safe_tail(inds)
121+
_uncolon(inds, I) = uncolon(inds)
121122

122-
if VERSION < v"1.9-"
123-
_uncolon(inds, I) = uncolon(inds, I)
124-
else
125-
_uncolon(inds, I) = uncolon(inds)
126-
end
127123
LinearIndices(A::AbstractQuasiArray) = LinearIndices(axes(A))
128124

129125

@@ -184,9 +180,6 @@ first(S::Inclusion) = first(S.domain)
184180
last(S::Inclusion) = last(S.domain)
185181
size(S::Inclusion) = (cardinality(S.domain),)
186182
length(S::Inclusion) = cardinality(S.domain)
187-
if VERSION < v"1.7-"
188-
Base.unsafe_length(S::Inclusion) = length(S)
189-
end
190183
cardinality(S::Inclusion) = cardinality(S.domain)
191184
measure(x) = cardinality(x) # TODO: Inclusion(0:0.5:1) should have
192185
getindex(S::Inclusion{T}, i::T) where T = (@_inline_meta; @boundscheck checkbounds(S, i); convert(T,i))

src/quasifill.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,8 @@ sum(x::AbstractQuasiFill) = getindex_value(x)*measure(axes(x,1))
264264
sum(x::QuasiZeros) = getindex_value(x)
265265

266266
# define `sum(::Callable, ::AbstractQuasiFill)` to avoid method ambiguity errors on Julia 1.0
267-
sum(f, x::AbstractQuasiFill) = _sum(f, x)
268-
sum(f::Base.Callable, x::AbstractQuasiFill) = _sum(f, x)
269-
_sum(f, x::AbstractQuasiFill) = measure(x) * f(getindex_value(x))
267+
_sum(f, x::AbstractQuasiFill, ::Colon) = measure(x) * f(getindex_value(x))
268+
_sum(f, x::AbstractQuasiFill, dims) = measure(x) * f(getindex_value(x))
270269

271270

272271
#########

src/quasireducedim.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,22 @@ count!(f, r::AbstractQuasiArray, A::AbstractQuasiArray; init::Bool=true) =
167167

168168
for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod, :mul_prod),
169169
(:maximum, :_maximum, :max), (:minimum, :_minimum, :min)]
170+
fname_layout = Symbol(string(fname) * "_layout")
170171
@eval begin
171172
# User-facing methods with keyword arguments
172173
@inline ($fname)(a::AbstractQuasiArray; dims=:, kw...) = ($_fname)(a, dims; kw...)
173174
@inline ($fname)(f, a::AbstractQuasiArray; dims=:, kw...) = ($_fname)(f, a, dims; kw...)
175+
@inline $_fname(a::AbstractQuasiArray, dims::Colon) = $fname_layout(MemoryLayout(a), a, dims)
176+
@inline $_fname(a::AbstractQuasiArray, dims) = $fname_layout(MemoryLayout(a), a, dims)
177+
@inline $_fname(f, a::AbstractQuasiArray, dims::Colon) = $fname_layout(MemoryLayout(a), f, a, dims)
178+
@inline $_fname(f, a::AbstractQuasiArray, dims) = $fname_layout(MemoryLayout(a), f, a, dims)
179+
@inline $fname_layout(lay, A, dims; kw...) = mapreduce(identity, $(op), A; dims=dims, kw...)
180+
@inline $fname_layout(lay, f, A, dims; kw...) = mapreduce(f, $(op), A; dims=dims, kw...)
174181
end
175182
end
176183

184+
extrema(f::AbstractQuasiVector) = extrema_layout(MemoryLayout(f), f)
185+
177186

178187

179188
any(a::AbstractQuasiArray; dims=:) = _any(a, dims)

src/quasisort.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
for func in (:findall, :findfirst, :findlast)
2+
func_layout = Symbol(string(func) * "_layout")
3+
@eval begin
4+
$func(f::Function, v::AbstractQuasiVector; kwds...) = $func_layout(MemoryLayout(v), f, v; kwds...)
5+
function $func_layout(lay, f, v::QuasiVector; kwds...)
6+
inds = $func(f, parent(v); kwds...)
7+
inds isa Nothing && return nothing
8+
v.axes[1][inds]
9+
end
10+
end
11+
end
12+
13+
for func in (:searchsortedfirst, :searchsorted, :searchsortedlast)
14+
func_layout = Symbol(string(func) * "_layout")
15+
@eval begin
16+
$func(f::AbstractQuasiVector, x; kwds...) = $func_layout(MemoryLayout(f), f, x; kwds...)
17+
function $func_layout(lay, f::QuasiVector, x; kwds...)
18+
inds = $func(parent(f), x; kwds...)
19+
f.axes[1][inds]
20+
end
21+
end
22+
end
23+
24+
extrema_layout(::QuasiArrayLayout, f) = extrema(parent(f)) # Hack that assumes QuasiArray

test/runtests.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ include("test_quasipermutedims.jl")
88
include("test_quasireducedim.jl")
99
include("test_quasireshapedarray.jl")
1010

11+
include("test_quasisort.jl")
12+
1113
include("test_dense.jl")
1214
include("test_quasiadjtrans.jl")
1315
include("test_continuous.jl")
@@ -22,4 +24,5 @@ include("test_quasikron.jl")
2224
include("test_ldiv.jl")
2325
include("test_quasilazy.jl")
2426

25-
include("test_sparsequasi.jl")
27+
include("test_sparsequasi.jl")
28+
include("test_statsbaseext.jl")

0 commit comments

Comments
 (0)