Skip to content

Commit a8a4f0d

Browse files
committed
Update Sample docs, export Fold folds
1 parent 65dde34 commit a8a4f0d

File tree

1 file changed

+58
-71
lines changed

1 file changed

+58
-71
lines changed

Statistics/Sample.hs

Lines changed: 58 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{-# LANGUAGE FlexibleContexts #-}
22
-- |
33
-- Module : Statistics.Sample
4-
-- Copyright : (c) 2008 Don Stewart, 2009 Bryan O'Sullivan
4+
-- Copyright : (c) 2008 Don Stewart, 2009 Bryan O'Sullivan, 2025 Alex Mason
55
-- License : BSD3
66
--
77
-- Maintainer : bos@serpentine.com
@@ -10,6 +10,9 @@
1010
--
1111
-- Commonly used sample statistics, also known as descriptive
1212
-- statistics.
13+
--
14+
-- To apply these to structures other than `Vector`s, see "Statistics.Sample.Fold"
15+
-- which contains the definitions of all algorithms in this module.
1316

1417
module Statistics.Sample
1518
(
@@ -57,7 +60,11 @@ module Statistics.Sample
5760
, correlation
5861
, covariance2
5962
, correlation2
63+
64+
-- * Helpers
6065
, pair
66+
, foldf
67+
, pairFold
6168
-- * References
6269
-- $references
6370
) where
@@ -75,18 +82,29 @@ import Prelude hiding ((^), sum)
7582
import qualified Control.Foldl as F
7683
import qualified Statistics.Sample.Fold as FF
7784

78-
ffold :: G.Vector v a => F.Fold a b -> v a -> b
79-
ffold f v = F.purely_ (\s i -> G.foldl s i v) f
80-
{-# INLINE ffold #-}
81-
{-# SPECIALIZE ffold :: F.Fold Double b -> U.Vector Double -> b #-}
82-
{-# SPECIALIZE ffold :: F.Fold Double b -> V.Vector Double -> b #-}
85+
-- | Apply a 'Fold' to a 'Vector', with specialisations
86+
-- for [Unboxed]("Data.Vector.Unboxed")
87+
-- and [Boxed]("Data.Vector") vectors.
88+
foldf :: G.Vector v a => F.Fold a b -> v a -> b
89+
foldf f v = F.purely_ (\s i -> G.foldl s i v) f
90+
{-# INLINE foldf #-}
91+
{-# SPECIALIZE foldf :: F.Fold Double b -> U.Vector Double -> b #-}
92+
{-# SPECIALIZE foldf :: F.Fold Double b -> V.Vector Double -> b #-}
8393

8494
data P a = P !a !Int
8595

86-
pfold :: (G.Vector v a, G.Vector v b) => String -> F.Fold (a,b) c -> v a -> v b -> c
87-
pfold err fld va vb
96+
-- | Apply a 'Fold' to a pair of 'Vector's, the vectors must be of the same
97+
-- length and will 'error' if they are not. Because the lengths must be known,
98+
-- this function is not subject to fusion.
99+
100+
-- TODO: Can this be made fuse by using Stream? Instead of checking
101+
-- length at the end, error if the end of one vector is reached before
102+
-- the other. This is the optimistic alternative to the pessimistic
103+
-- initial length check.
104+
pairFold :: (G.Vector v a, G.Vector v b) => String -> F.Fold (a,b) c -> v a -> v b -> c
105+
pairFold err fld va vb
88106
| la /= lb = error err
89-
| otherwise = ffold foldPair va
107+
| otherwise = foldf foldPair va
90108
where la = G.length va
91109
lb = G.length vb
92110
-- Pattern match here so we don't end up forcing arguments
@@ -95,29 +113,29 @@ pfold err fld va vb
95113
(F.Fold step ini end) -> F.Fold step' (P ini 0) end' where
96114
end' (P a _) = end a
97115
step' (P x n) a = P (step x (a, G.unsafeIndex vb n)) (n+1)
98-
{-# INLINE pfold #-}
99-
{-# SPECIALIZE pfold :: String -> F.Fold (Double, Double) b -> U.Vector Double -> U.Vector Double -> b #-}
100-
{-# SPECIALIZE pfold :: String -> F.Fold (Double, Double) b -> V.Vector Double -> V.Vector Double -> b #-}
116+
{-# INLINE pairFold #-}
117+
{-# SPECIALIZE pairFold :: String -> F.Fold (Double, Double) b -> U.Vector Double -> U.Vector Double -> b #-}
118+
{-# SPECIALIZE pairFold :: String -> F.Fold (Double, Double) b -> V.Vector Double -> V.Vector Double -> b #-}
101119

102120

103121
-- | /O(n)/ Range. The difference between the largest and smallest
104122
-- elements of a sample.
105123
range :: (G.Vector v Double) => v Double -> Double
106-
range = ffold FF.range
124+
range = foldf FF.range
107125
{-# INLINE range #-}
108126

109127
-- | /O(n)/ Compute expectation of function over for sample. This is
110128
-- simply @mean . map f@ but won't create intermediate vector.
111129
expectation :: (G.Vector v a) => (a -> Double) -> v a -> Double
112-
expectation f xs = ffold (FF.expectation f) xs
130+
expectation f xs = foldf (FF.expectation f) xs
113131
{-# INLINE expectation #-}
114132

115133
-- | /O(n)/ Arithmetic mean. This uses Kahan-Babuška-Neumaier
116134
-- summation, so is more accurate than 'welfordMean' unless the input
117135
-- values are very large. This function is not subject to stream
118136
-- fusion.
119137
mean :: (G.Vector v Double) => v Double -> Double
120-
mean = ffold FF.mean
138+
mean = foldf FF.mean
121139
{-# SPECIALIZE mean :: U.Vector Double -> Double #-}
122140
{-# SPECIALIZE mean :: V.Vector Double -> Double #-}
123141

@@ -127,25 +145,25 @@ mean = ffold FF.mean
127145
-- Compared to 'mean', this loses a surprising amount of precision
128146
-- unless the inputs are very large.
129147
welfordMean :: (G.Vector v Double) => v Double -> Double
130-
welfordMean = ffold FF.welfordMean
148+
welfordMean = foldf FF.welfordMean
131149
{-# SPECIALIZE welfordMean :: U.Vector Double -> Double #-}
132150
{-# SPECIALIZE welfordMean :: V.Vector Double -> Double #-}
133151

134152
-- | /O(n)/ Arithmetic mean for weighted sample. It uses a single-pass
135153
-- algorithm analogous to the one used by 'welfordMean'.
136154
meanWeighted :: (G.Vector v (Double,Double)) => v (Double,Double) -> Double
137-
meanWeighted = ffold FF.meanWeighted
155+
meanWeighted = foldf FF.meanWeighted
138156
{-# INLINE meanWeighted #-}
139157

140158
-- | /O(n)/ Harmonic mean. This algorithm performs a single pass over
141159
-- the sample.
142160
harmonicMean :: (G.Vector v Double) => v Double -> Double
143-
harmonicMean = ffold FF.harmonicMean
161+
harmonicMean = foldf FF.harmonicMean
144162
{-# INLINE harmonicMean #-}
145163

146164
-- | /O(n)/ Geometric mean of a sample containing no negative values.
147165
geometricMean :: (G.Vector v Double) => v Double -> Double
148-
geometricMean = ffold FF.geometricMean
166+
geometricMean = foldf FF.geometricMean
149167
{-# INLINE geometricMean #-}
150168

151169
-- | Compute the /k/th central moment of a sample. The central moment
@@ -157,7 +175,7 @@ geometricMean = ffold FF.geometricMean
157175
-- For samples containing many values very close to the mean, this
158176
-- function is subject to inaccuracy due to catastrophic cancellation.
159177
centralMoment :: (G.Vector v Double) => Int -> v Double -> Double
160-
centralMoment a xs = ffold (FF.centralMoment a m) xs
178+
centralMoment a xs = foldf (FF.centralMoment a m) xs
161179
where
162180
m = mean xs
163181
{-# SPECIALIZE centralMoment :: Int -> U.Vector Double -> Double #-}
@@ -171,7 +189,7 @@ centralMoment a xs = ffold (FF.centralMoment a m) xs
171189
-- For samples containing many values very close to the mean, this
172190
-- function is subject to inaccuracy due to catastrophic cancellation.
173191
centralMoments :: (G.Vector v Double) => Int -> Int -> v Double -> (Double, Double)
174-
centralMoments a b xs = ffold (FF.centralMoments a b m) xs
192+
centralMoments a b xs = foldf (FF.centralMoments a b m) xs
175193
where m = mean xs
176194

177195

@@ -198,7 +216,7 @@ centralMoments a b xs = ffold (FF.centralMoments a b m) xs
198216
-- For samples containing many values very close to the mean, this
199217
-- function is subject to inaccuracy due to catastrophic cancellation.
200218
skewness :: (G.Vector v Double) => v Double -> Double
201-
skewness xs = ffold (FF.skewness m) xs
219+
skewness xs = foldf (FF.skewness m) xs
202220
where m = mean xs
203221
{-# SPECIALIZE skewness :: U.Vector Double -> Double #-}
204222
{-# SPECIALIZE skewness :: V.Vector Double -> Double #-}
@@ -217,7 +235,7 @@ skewness xs = ffold (FF.skewness m) xs
217235
-- For samples containing many values very close to the mean, this
218236
-- function is subject to inaccuracy due to catastrophic cancellation.
219237
kurtosis :: (G.Vector v Double) => v Double -> Double
220-
kurtosis xs = ffold (FF.kurtosis m) xs
238+
kurtosis xs = foldf (FF.kurtosis m) xs
221239
where m = mean xs
222240
{-# SPECIALIZE kurtosis :: U.Vector Double -> Double #-}
223241
{-# SPECIALIZE kurtosis :: V.Vector Double -> Double #-}
@@ -239,7 +257,7 @@ kurtosis xs = ffold (FF.kurtosis m) xs
239257
-- | Maximum likelihood estimate of a sample's variance. Also known
240258
-- as the population variance, where the denominator is /n/.
241259
variance :: (G.Vector v Double) => v Double -> Double
242-
variance samp = ffold (FF.variance m) samp
260+
variance samp = foldf (FF.variance m) samp
243261
where m = mean samp
244262
{-# SPECIALIZE variance :: U.Vector Double -> Double #-}
245263
{-# SPECIALIZE variance :: V.Vector Double -> Double #-}
@@ -248,7 +266,7 @@ variance samp = ffold (FF.variance m) samp
248266
-- | Unbiased estimate of a sample's variance. Also known as the
249267
-- sample variance, where the denominator is /n/-1.
250268
varianceUnbiased :: (G.Vector v Double) => v Double -> Double
251-
varianceUnbiased samp = ffold (FF.varianceUnbiased m) samp
269+
varianceUnbiased samp = foldf (FF.varianceUnbiased m) samp
252270
where m = mean samp
253271
{-# SPECIALIZE varianceUnbiased :: U.Vector Double -> Double #-}
254272
{-# SPECIALIZE varianceUnbiased :: V.Vector Double -> Double #-}
@@ -258,7 +276,7 @@ varianceUnbiased samp = ffold (FF.varianceUnbiased m) samp
258276
-- since it will calculate mean only once.
259277
meanVariance :: (G.Vector v Double) => v Double -> (Double,Double)
260278
meanVariance samp
261-
| n > 1 = (m, ffold (FF.robustSumVar m) samp / fromIntegral n)
279+
| n > 1 = (m, foldf (FF.robustSumVar m) samp / fromIntegral n)
262280
| otherwise = (m, 0)
263281
where
264282
n = G.length samp
@@ -271,7 +289,7 @@ meanVariance samp
271289
-- since it will calculate mean only once.
272290
meanVarianceUnb :: (G.Vector v Double) => v Double -> (Double,Double)
273291
meanVarianceUnb samp
274-
| n > 1 = (m, ffold (FF.robustSumVar m) samp / fromIntegral (n-1))
292+
| n > 1 = (m, foldf (FF.robustSumVar m) samp / fromIntegral (n-1))
275293
| otherwise = (m, 0)
276294
where
277295
n = G.length samp
@@ -282,21 +300,21 @@ meanVarianceUnb samp
282300
-- | Standard deviation. This is simply the square root of the
283301
-- unbiased estimate of the variance.
284302
stdDev :: (G.Vector v Double) => v Double -> Double
285-
stdDev samp = ffold (FF.stdDev m) samp
303+
stdDev samp = foldf (FF.stdDev m) samp
286304
where m = mean samp
287305
{-# SPECIALIZE stdDev :: U.Vector Double -> Double #-}
288306
{-# SPECIALIZE stdDev :: V.Vector Double -> Double #-}
289307

290308
-- | Standard error of the mean. This is the standard deviation
291309
-- divided by the square root of the sample size.
292310
stdErrMean :: (G.Vector v Double) => v Double -> Double
293-
stdErrMean samp = ffold (FF.stdErrMean m) samp
311+
stdErrMean samp = foldf (FF.stdErrMean m) samp
294312
where m = mean samp
295313
{-# SPECIALIZE stdErrMean :: U.Vector Double -> Double #-}
296314
{-# SPECIALIZE stdErrMean :: V.Vector Double -> Double #-}
297315

298316
robustSumVarWeighted :: (G.Vector v (Double,Double)) => v (Double,Double) -> FF.V
299-
robustSumVarWeighted samp = ffold (fmap fini $ FF.robustSumVarWeighted m) samp
317+
robustSumVarWeighted samp = foldf (fmap fini $ FF.robustSumVarWeighted m) samp
300318
where m = meanWeighted samp
301319
fini (FF.V a b) = FF.V a b
302320
{-# INLINE robustSumVarWeighted #-}
@@ -322,33 +340,31 @@ varianceWeighted samp
322340
-- mean, Knuth's algorithm gives inaccurate results due to
323341
-- catastrophic cancellation.
324342

325-
-- fastVar :: (G.Vector v Double) => v Double -> FF.T1
326-
-- fastVar = ffold FF.fastVar
327343

328344
-- | Maximum likelihood estimate of a sample's variance.
329345
fastVariance :: (G.Vector v Double) => v Double -> Double
330-
fastVariance = ffold FF.fastVariance
346+
fastVariance = foldf FF.fastVariance
331347
{-# INLINE fastVariance #-}
332348

333349
-- | Unbiased estimate of a sample's variance.
334350
fastVarianceUnbiased :: (G.Vector v Double) => v Double -> Double
335-
fastVarianceUnbiased = ffold FF.fastVarianceUnbiased
351+
fastVarianceUnbiased = foldf FF.fastVarianceUnbiased
336352
{-# INLINE fastVarianceUnbiased #-}
337353

338-
-- | Standard deviation. This is simply the square root of the
354+
-- | Standard deviation. This is simply the square root of the
339355
-- maximum likelihood estimate of the variance.
340356
fastStdDev :: (G.Vector v Double) => v Double -> Double
341-
fastStdDev = ffold FF.fastStdDev
357+
fastStdDev = foldf FF.fastStdDev
342358
{-# INLINE fastStdDev #-}
343359

344360
-- | Covariance of sample of pairs. For empty sample it's set to
345361
-- zero
346362
covariance :: (G.Vector v (Double,Double))
347363
=> v (Double,Double)
348364
-> Double
349-
covariance xy = ffold (FF.covariance (muX, muY)) xy
365+
covariance xy = foldf (FF.covariance (muX, muY)) xy
350366
where
351-
FF.V muX muY = ffold (FF.biExpectation fst snd) xy
367+
FF.V muX muY = foldf (FF.biExpectation fst snd) xy
352368
{-# SPECIALIZE covariance :: U.Vector (Double,Double) -> Double #-}
353369
{-# SPECIALIZE covariance :: V.Vector (Double,Double) -> Double #-}
354370

@@ -357,9 +373,9 @@ covariance xy = ffold (FF.covariance (muX, muY)) xy
357373
correlation :: (G.Vector v (Double,Double))
358374
=> v (Double,Double)
359375
-> Double
360-
correlation xy = ffold (FF.correlation (muX, muY)) xy
376+
correlation xy = foldf (FF.correlation (muX, muY)) xy
361377
where
362-
FF.V muX muY = ffold (FF.biExpectation fst snd) xy
378+
FF.V muX muY = foldf (FF.biExpectation fst snd) xy
363379

364380
{-# SPECIALIZE correlation :: U.Vector (Double,Double) -> Double #-}
365381
{-# SPECIALIZE correlation :: V.Vector (Double,Double) -> Double #-}
@@ -372,7 +388,7 @@ covariance2 :: (G.Vector v Double)
372388
-> v Double
373389
-> Double
374390
covariance2 xs ys =
375-
pfold "Statistics.Sample.covariance2: both samples must have same length"
391+
pairFold "Statistics.Sample.covariance2: both samples must have same length"
376392
(FF.covariance (muX, muY))
377393
xs ys
378394
where
@@ -389,12 +405,9 @@ correlation2 :: (G.Vector v Double)
389405
-> v Double
390406
-> Double
391407
correlation2 xs ys =
392-
pfold "Statistics.Sample.correlation2: both samples must have same length"
408+
pairFold "Statistics.Sample.correlation2: both samples must have same length"
393409
(FF.correlation (muX, muY))
394410
xs ys
395-
-- | nx /= ny = error $ "Statistics.Sample.correlation2: both samples must have same length"
396-
-- | nx == 0 = 0
397-
-- | otherwise = cov / sqrt (varX * varY)
398411
where
399412
muX = mean xs
400413
muY = mean ys
@@ -411,32 +424,6 @@ pair va vb
411424
| otherwise = error "Statistics.Sample.pair: vector must have same length"
412425
{-# INLINE pair #-}
413426

414-
------------------------------------------------------------------------
415-
-- Helper code. Monomorphic unpacked accumulators.
416-
417-
-- don't support polymorphism, as we can't get unboxed returns if we use it.
418-
419-
{-
420-
421-
Consider this core:
422-
423-
with data T a = T !a !Int
424-
425-
$wfold :: Double#
426-
-> Int#
427-
-> Int#
428-
-> (# Double, Int# #)
429-
430-
and without,
431-
432-
$wfold :: Double#
433-
-> Int#
434-
-> Int#
435-
-> (# Double#, Int# #)
436-
437-
yielding to boxed returns and heap checks.
438-
439-
-}
440427

441428
-- $references
442429
--

0 commit comments

Comments
 (0)