Skip to content

Commit 5e7fb87

Browse files
Change classifyWith to work with scans
1 parent f8afdaf commit 5e7fb87

File tree

2 files changed

+79
-47
lines changed
  • benchmark/Streamly/Benchmark/Data
  • core/src/Streamly/Internal/Data

2 files changed

+79
-47
lines changed

benchmark/Streamly/Benchmark/Data/Fold.hs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ classifyWith ::
190190
(Monad m, Ord k, Num a) => (a -> k) -> SerialT m a -> m (Map k a)
191191
classifyWith f = S.fold (FL.classifyWith f FL.sum)
192192

193+
{-# INLINE classifyScanWith #-}
194+
classifyScanWith ::
195+
(Monad m, Ord k, Num a) => (a -> k) -> SerialT m a -> m ()
196+
classifyScanWith f = S.drain . S.postscan (FL.classifyScanWith f FL.sum)
197+
193198
-------------------------------------------------------------------------------
194199
-- unzip
195200
-------------------------------------------------------------------------------
@@ -338,6 +343,8 @@ o_1_space_serial_composition value =
338343
$ demuxDefaultWith fn mp
339344
, benchIOSink value "demuxWith [sum, length]" $ demuxWith fn mp
340345
, benchIOSink value "classifyWith sum" $ classifyWith (fst . fn)
346+
, benchIOSink value "classifyScanWith sum"
347+
$ classifyScanWith (fst . fn)
341348
]
342349
]
343350

core/src/Streamly/Internal/Data/Fold.hs

Lines changed: 72 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ module Streamly.Internal.Data.Fold
7474
-- $toListRev
7575
, toStream
7676
, toStreamRev
77+
, toMap
7778

7879
-- ** Terminating Folds
7980
, drainN
@@ -199,6 +200,7 @@ module Streamly.Internal.Data.Fold
199200
-- in individual output buckets using the given fold.
200201
, classify
201202
, classifyWith
203+
, classifyScanWith
202204
-- , classifyWithSel
203205
-- , classifyWithMin
204206

@@ -258,6 +260,7 @@ import Streamly.Internal.Data.Tuple.Strict (Tuple'(..), Tuple3'(..))
258260
import Streamly.Internal.Data.Stream.Serial (SerialT(..))
259261

260262
import qualified Data.Map.Strict as Map
263+
import qualified Data.Set as Set
261264
import qualified Streamly.Internal.Data.Pipe.Type as Pipe
262265
-- import qualified Streamly.Internal.Data.Stream.IsStream.Enumeration as Stream
263266
import qualified Prelude
@@ -1524,10 +1527,57 @@ demuxDefault :: (Monad m, Ord k)
15241527
=> Map k (Fold m a b) -> Fold m (k, a) b -> Fold m (k, a) (Map k b, b)
15251528
demuxDefault = demuxDefaultWith id
15261529

1527-
-- TODO If the data is large we may need a map/hashmap in pinned memory instead
1528-
-- of a regular Map. That may require a serializable constraint though. We can
1529-
-- have another API for that.
1530-
--
1530+
{-# INLINE classifyScanWith #-}
1531+
classifyScanWith :: (Monad m, Ord k) =>
1532+
-- Note: we need to return the Map itself to display the in-progress values
1533+
-- e.g. to implement top. We could possibly create a separate abstraction
1534+
-- for that use case. We return an action because we want it to be lazy so
1535+
-- that the downstream consumers can choose to process or discard it.
1536+
(a -> k) -> Fold m a b -> Fold m a (m (Map k b), Maybe (k, b))
1537+
classifyScanWith f (Fold step1 initial1 extract1) =
1538+
fmap extract $ foldlM' step initial
1539+
1540+
where
1541+
1542+
initial = return $ Tuple3' Map.empty Set.empty Nothing
1543+
1544+
{-# INLINE initFold #-}
1545+
initFold kv set k a = do
1546+
x <- initial1
1547+
case x of
1548+
Partial s -> do
1549+
r <- step1 s a
1550+
return
1551+
$ case r of
1552+
Partial s1 ->
1553+
Tuple3' (Map.insert k s1 kv) set Nothing
1554+
Done b ->
1555+
Tuple3' kv set (Just (k, b))
1556+
Done b -> return (Tuple3' kv (Set.insert k set) (Just (k, b)))
1557+
1558+
step (Tuple3' kv set _) a = do
1559+
let k = f a
1560+
case Map.lookup k kv of
1561+
Nothing -> do
1562+
if Set.member k set
1563+
then return (Tuple3' kv set Nothing)
1564+
else initFold kv set k a
1565+
Just s -> do
1566+
r <- step1 s a
1567+
return
1568+
$ case r of
1569+
Partial s1 ->
1570+
Tuple3' (Map.insert k s1 kv) set Nothing
1571+
Done b ->
1572+
let kv1 = Map.delete k kv
1573+
in Tuple3' kv1 (Set.insert k set) (Just (k, b))
1574+
1575+
extract (Tuple3' kv _ x) = (Prelude.mapM extract1 kv, x)
1576+
1577+
{-# INLINE toMap #-}
1578+
toMap :: (Monad m, Ord k) => Fold m (k, a) (Map k a)
1579+
toMap = foldl' (\kv (k, v) -> Map.insert k v kv) Map.empty
1580+
15311581
-- | Split the input stream based on a key field and fold each split using the
15321582
-- given fold. Useful for map/reduce, bucketizing the input in different bins
15331583
-- or for generating histograms.
@@ -1538,55 +1588,30 @@ demuxDefault = demuxDefaultWith id
15381588
-- :}
15391589
-- fromList [("ONE",[1.0,1.1]),("TWO",[2.0,2.2])]
15401590
--
1541-
-- If the classifier fold stops for a particular key any further inputs in that
1542-
-- bucket are ignored.
1591+
-- Once the classifier fold terminates for a particular key any further inputs
1592+
-- in that bucket are ignored.
1593+
--
1594+
-- Space used is proportional to the number of keys seen till now and
1595+
-- monotonically increases because it stores whether a key has been seen or
1596+
-- not.
15431597
--
15441598
-- /Stops: never/
15451599
--
15461600
-- /Pre-release/
15471601
--
15481602
{-# INLINE classifyWith #-}
1549-
classifyWith :: (Monad m, Ord k) => (a -> k) -> Fold m a b -> Fold m a (Map k b)
1550-
classifyWith f (Fold step1 initial1 extract1) =
1551-
rmapM extract $ foldlM' step initial
1552-
1553-
where
1554-
1555-
initial = return Map.empty
1556-
1557-
step kv a =
1558-
case Map.lookup k kv of
1559-
Nothing -> do
1560-
x <- initial1
1561-
case x of
1562-
Partial s -> do
1563-
r <- step1 s a
1564-
return
1565-
$ flip (Map.insert k) kv
1566-
$ case r of
1567-
Partial s1 -> Left' s1
1568-
Done b -> Right' b
1569-
Done b -> return $ Map.insert k (Right' b) kv
1570-
Just x -> do
1571-
case x of
1572-
Left' s -> do
1573-
r <- step1 s a
1574-
return
1575-
$ flip (Map.insert k) kv
1576-
$ case r of
1577-
Partial s1 -> Left' s1
1578-
Done b -> Right' b
1579-
Right' _ -> return kv
1580-
1581-
where
1582-
1583-
k = f a
1584-
1585-
extract =
1586-
Prelude.mapM
1587-
(\case
1588-
Left' s -> extract1 s
1589-
Right' b -> return b)
1603+
classifyWith :: (Monad m, Ord k) =>
1604+
(a -> k) -> Fold m a b -> Fold m a (Map k b)
1605+
classifyWith f fld =
1606+
let
1607+
classifier = classifyScanWith f fld
1608+
getMap Nothing = pure Map.empty
1609+
getMap (Just action) = action
1610+
aggregator =
1611+
teeWith Map.union
1612+
(rmapM getMap $ lmap fst last)
1613+
(lmap snd $ catMaybes toMap)
1614+
in postscan classifier aggregator
15901615

15911616
-- | Given an input stream of key value pairs and a fold for values, fold all
15921617
-- the values belonging to each key. Useful for map/reduce, bucketizing the

0 commit comments

Comments
 (0)