@@ -38,6 +38,9 @@ module Clash.Normalize.Transformations.DEC
3838 ) where
3939
4040import Control.Concurrent.Supply (splitSupply )
41+ #if !MIN_VERSION_base(4,18,0)
42+ import Control.Applicative (liftA2 )
43+ #endif
4144import Control.Lens ((^.) , _1 )
4245import qualified Control.Lens as Lens
4346import qualified Control.Monad as Monad
@@ -56,6 +59,7 @@ import qualified Data.Map.Strict as Map
5659import qualified Data.Maybe as Maybe
5760import Data.Monoid (All (.. ))
5861import qualified Data.Text as Text
62+ import Data.Text.Extra (showt )
5963import GHC.Stack (HasCallStack )
6064import qualified Language.Haskell.TH as TH
6165
@@ -72,21 +76,22 @@ import Constants (mAX_TUPLE_SIZE)
7276#endif
7377
7478-- internal
75- import Clash.Core.DataCon (DataCon )
79+ import Clash.Core.DataCon (DataCon )
7680import Clash.Core.Evaluator.Types (whnf' )
7781import Clash.Core.FreeVars
7882 (termFreeVars' , typeFreeVars' , localVarsDoNotOccurIn )
7983import Clash.Core.HasType
8084import Clash.Core.Literal (Literal (.. ))
81- import Clash.Core.Name (nameOcc )
85+ import Clash.Core.Name (OccName , nameOcc )
86+ import Clash.Core.Pretty (showPpr )
8287import Clash.Core.Term
8388 ( Alt , LetBinding , Pat (.. ), PrimInfo (.. ), Term (.. ), TickInfo (.. )
8489 , collectArgs , collectArgsTicks , mkApps , mkTicks , patIds , stripTicks )
8590import Clash.Core.TyCon (TyConMap , TyConName , tyConDataCons )
8691import Clash.Core.Type
8792 (Type , TypeView (.. ), isPolyFunTy , mkTyConApp , splitFunForallTy , tyView )
8893import Clash.Core.Util (mkInternalVar , mkSelectorCase , sccLetBindings )
89- import Clash.Core.Var (isGlobalId , isLocalId , varName )
94+ import Clash.Core.Var (Id , isGlobalId , isLocalId , varName )
9095import Clash.Core.VarEnv
9196 ( InScopeSet , elemInScopeSet , extendInScopeSet , extendInScopeSetList
9297 , notElemInScopeSet , unionInScope )
@@ -138,6 +143,24 @@ import qualified GHC.Prim
138143-- B -> f_out
139144-- C -> h x
140145-- @
146+ --
147+ -- Though that's a lie. It actually converts it into:
148+ --
149+ -- @
150+ -- let f_tupIn = case x of {A -> (3,y); B -> (x,x)}
151+ -- f_arg0 = case f_tupIn of (l,_) -> l
152+ -- f_arg1 = case f_tupIn of (_,r) -> r
153+ -- f_out = f f_arg0 f_arg1
154+ -- in case x of
155+ -- A -> f_out
156+ -- B -> f_out
157+ -- C -> h x
158+ -- @
159+ --
160+ -- In order to share the expression that's in the subject of the case expression,
161+ -- and to share the /decoder/ circuit that logic synthesis will create to map the
162+ -- bits of the subject expression to the bits needed to make the selection in the
163+ -- multiplexer.
141164disjointExpressionConsolidation :: HasCallStack => NormRewrite
142165disjointExpressionConsolidation ctx@ (TransformContext isCtx _) e@ (Case _scrut _ty _alts@ (_: _: _)) = do
143166 -- Collect all (the applications of) global binders (and certain primitives)
@@ -150,11 +173,13 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
150173 else do
151174 -- For every to-lift expression create (the generalization of):
152175 --
153- -- let fargs = case x of {A -> (3,y); B -> (x,x)}
154- -- in f (fst fargs) (snd fargs)
176+ -- let f_tupIn = case x of {A -> (3,y); B -> (x,x)}
177+ -- f_arg0 = case f_tupIn of (l,_) -> l
178+ -- f_arg1 = case f_tupIn of (_,r) -> r
179+ -- in f f_arg0 f_arg0
155180 --
156- -- the let-expression is not created when `f` has only one (selectable)
157- -- argument
181+ -- if an argument is non-representable, the case-expression is inlined,
182+ -- and no let-binding will be created for it.
158183 --
159184 -- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine
160185 -- whether expressions reference variables from the context, or
@@ -190,11 +215,8 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
190215 -- Make the let-binder for the lifted expressions
191216 mkFunOut tcm isN ((fun,_),(eLifted,_)) = do
192217 let ty = inferCoreTypeOf tcm eLifted
193- nm = case collectArgs fun of
194- (Var v,_) -> nameOcc (varName v)
195- (Prim p,_) -> primName p
196- _ -> " complex_expression_"
197- nm1 = last (Text. splitOn " ." nm) `Text.append` " Out"
218+ nm = decFunName fun
219+ nm1 = nm `Text.append` " _out"
198220 nm2 <- mkInternalVar isN nm1 ty
199221 return (extendInScopeSet isN nm2,nm2)
200222
@@ -249,12 +271,26 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
249271disjointExpressionConsolidation _ e = return e
250272{-# SCC disjointExpressionConsolidation #-}
251273
274+ decFunName :: Term -> OccName
275+ decFunName fun = last . Text. splitOn " ." $ case collectArgs fun of
276+ (Var v, _) -> nameOcc (varName v)
277+ (Prim p, _) -> primName p
278+ _ -> " complex_expression"
279+
252280data CaseTree a
253281 = Leaf a
254282 | LB [LetBinding ] (CaseTree a )
255283 | Branch Term [(Pat ,CaseTree a )]
256284 deriving (Eq ,Show ,Functor ,Foldable )
257285
286+ instance Applicative CaseTree where
287+ pure = Leaf
288+ liftA2 f (Leaf a) (Leaf b) = Leaf (f a b)
289+ liftA2 f (LB lb c1) (LB _ c2) = LB lb (liftA2 f c1 c2)
290+ liftA2 f (Branch scrut alts1) (Branch _ alts2) =
291+ Branch scrut (zipWith (\ (p1,a1) (_,a2) -> (p1,liftA2 f a1 a2)) alts1 alts2)
292+ liftA2 _ _ _ = error " CaseTree.liftA2: internal error, this should not happen."
293+
258294-- | Test if a 'CaseTree' collected from an expression indicates that
259295-- application of a global binder is disjoint: occur in separate branches of a
260296-- case-expression.
@@ -269,18 +305,6 @@ isDisjoint ct = go ct
269305 go (Branch _ [(_,x)]) = go x
270306 go b@ (Branch _ (_: _: _)) = allEqual (map Either. rights (Foldable. toList b))
271307
272- -- Remove empty branches from a 'CaseTree'
273- removeEmpty :: Eq a => CaseTree [a ] -> CaseTree [a ]
274- removeEmpty l@ (Leaf _) = l
275- removeEmpty (LB lb ct) =
276- case removeEmpty ct of
277- Leaf [] -> Leaf []
278- ct' -> LB lb ct'
279- removeEmpty (Branch s bs) =
280- case filter ((/= (Leaf [] )) . snd ) (map (second removeEmpty) bs) of
281- [] -> Leaf []
282- bs' -> Branch s bs'
283-
284308-- | Test if all elements in a list are equal to each other.
285309allEqual :: Eq a => [a ] -> Bool
286310allEqual [] = True
@@ -464,90 +488,94 @@ collectGlobalsLbs is0 substitution seen lbs = do
464488-- function-position\", return a let-expression: where the let-binding holds
465489-- a case-expression selecting between the distinct arguments of the case-tree,
466490-- and the body is an application of the term applied to the shared arguments of
467- -- the case tree, and projections of let-binding corresponding to the distinct
468- -- argument positions.
491+ -- the case tree, and variable references to the created let-bindings.
492+ --
493+ -- case-expressions whose type would be non-representable are not let-bound,
494+ -- but occur directly in the argument position of the application in the body
495+ -- of the let-expression.
469496mkDisjointGroup
470497 :: InScopeSet
471498 -- ^ Variables in scope at the very top of the case-tree, i.e., the original
472499 -- expression
473- -> (Term ,([Term ],CaseTree [( Either Term Type ) ]))
500+ -> (Term ,([Term ],CaseTree [Either Term Type ]))
474501 -- ^ Case-tree of arguments belonging to the applied term.
475502 -> NormalizeSession (Term ,[Term ])
476503mkDisjointGroup inScope (fun,(seen,cs)) = do
477504 tcm <- Lens. view tcCache
478- let argss = Foldable. toList cs
479- argssT = zip [0 .. ] (List. transpose argss)
480- (sharedT,distinctT) = List. partition (areShared tcm inScope . fmap (first stripTicks) . snd ) argssT
481- -- TODO: find a better solution than "maybe undefined fst . uncons"
482- shared = map (second (maybe (error " impossible" ) fst . List. uncons)) sharedT
483- distinct = map (Either. lefts) (List. transpose (map snd distinctT))
484- cs' = fmap (zip [0 .. ]) cs
485- cs'' = removeEmpty
486- $ fmap (Either. lefts . map snd )
487- (if null shared
488- then cs'
489- else fmap (filter (`notElem` shared)) cs')
490- (distinctCaseM,distinctProjections) <- case distinct of
491- -- only shared arguments: do nothing.
492- [] -> return (Nothing ,[] )
493- -- Create selectors and projections
494- (uc: _) -> do
495- let argTys = map (inferCoreTypeOf tcm) uc
496- disJointSelProj inScope argTys cs''
497- let newArgs = mkDJArgs 0 shared distinctProjections
498- case distinctCaseM of
499- Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
500- Nothing -> return (mkApps fun newArgs, seen)
501-
502- -- | Create a single selector for all the representable distinct arguments by
503- -- selecting between tuples. This selector is only ('Just') created when the
504- -- number of representable uncommmon arguments is larger than one, otherwise it
505- -- is not ('Nothing').
506- --
507- -- It also returns:
508- --
509- -- * For all the non-representable distinct arguments: a selector
510- -- * For all the representable distinct arguments: a projection out of the tuple
511- -- created by the larger selector. If this larger selector does not exist, a
512- -- single selector is created for the single representable distinct argument.
505+ let funName = decFunName fun
506+ argLen = case Foldable. toList cs of
507+ [] -> error " mkDisjointGroup: no disjoint groups"
508+ l: _ -> length l
509+ csT :: [CaseTree (Either Term Type )] -- "Transposed" 'CaseTree [Either Term Type]'
510+ csT = map (\ i -> fmap (!! i) cs) [0 .. (argLen- 1 )] -- sequenceA does the wrong thing
511+ (lbs,newArgs) <- List. mapAccumRM (\ lbs (c,pos) -> do
512+ let cL = Foldable. toList c
513+ case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of
514+ (Right ty: _, True ) ->
515+ return (lbs,Right ty)
516+ (Right _: _, False ) ->
517+ error (" mkDisjointGroup: non-equal type arguments: " <>
518+ showPpr (Either. rights cL))
519+ (Left tm: _, True ) ->
520+ return (lbs,Left tm)
521+ (Left tm: _, False ) -> do
522+ let ty = inferCoreTypeOf tcm tm
523+ let err = error (" mkDisjointGroup: mixed type and term arguments: " <> show cL)
524+ (lbM,arg) <- disJointSelProj inScope ty (Either. fromLeft err <$> c) funName pos
525+ case lbM of
526+ Just lb -> return (lb: lbs, Left arg)
527+ _ -> return (lbs, Left arg)
528+ ([] , _) ->
529+ error " mkDisjointGroup: no arguments"
530+ ) [] (zip csT [0 .. ])
531+ let funApp = mkApps fun newArgs
532+ tupTcm <- Lens. view tupleTcCache
533+ case lbs of
534+ [] ->
535+ return (funApp, seen)
536+ [(v,(ty,ct))] -> do
537+ let e = genCase tcm tupTcm ty [ty] (fmap (: [] ) ct)
538+ return (Letrec [(v,e)] funApp, seen)
539+ _ -> do
540+ let (vs,zs) = unzip lbs
541+ csL :: [CaseTree Term ]
542+ (tys,csL) = unzip zs
543+ csLT :: CaseTree [Term ]
544+ csLT = fmap ($ [] ) (foldr1 (liftA2 (.) ) (fmap (fmap (:) ) csL))
545+ bigTupTy = mkBigTupTy tcm tupTcm tys
546+ ct = genCase tcm tupTcm bigTupTy tys csLT
547+ tupIn <- mkInternalVar inScope (funName <> " _tupIn" ) bigTupTy
548+ projections <-
549+ Monad. zipWithM (\ v n ->
550+ (v,) <$> mkBigTupSelector inScope tcm tupTcm (Var tupIn) tys n)
551+ vs [0 .. ]
552+ return (Letrec ((tupIn,ct): projections) funApp, seen)
553+
554+ -- | Create a selector for the case-tree of the argument. If the argument is
555+ -- representable create a let-binding for the created selector, and return
556+ -- a variable reference to this let-binding. If the argument is not representable
557+ -- return the selector directly.
513558disJointSelProj
514559 :: InScopeSet
515- -> [Type ]
516- -- ^ Types of the arguments
517- -> CaseTree [Term ]
518- -- The case-tree of arguments
519- -> NormalizeSession (Maybe LetBinding ,[Term ])
520- disJointSelProj _ _ (Leaf [] ) = return (Nothing ,[] )
521- disJointSelProj inScope argTys cs = do
522- tcm <- Lens. view tcCache
560+ -> Type
561+ -- ^ Types of the argument
562+ -> CaseTree Term
563+ -- ^ The case-tree of argument
564+ -> OccName
565+ -- ^ Name of the lifted function
566+ -> Word
567+ -- ^ Position of the argument
568+ -> NormalizeSession (Maybe (Id , (Type , CaseTree Term )),Term )
569+ disJointSelProj inScope argTy cs funName argN = do
570+ tcm <- Lens. view tcCache
523571 tupTcm <- Lens. view tupleTcCache
524- let maxIndex = length argTys - 1
525- css = map (\ i -> fmap ((: [] ) . (!! i)) cs) [0 .. maxIndex]
526- (untran,tran) <- List. partitionM (isUntranslatableType False . snd ) (zip [0 .. ] argTys)
527- let untranCs = map (css!! ) (map fst untran)
528- untranSels = zipWith (\ (_,ty) cs' -> genCase tcm tupTcm ty [ty] cs')
529- untran untranCs
530- (lbM,projs) <- case tran of
531- [] -> return (Nothing ,[] )
532- [(i,ty)] -> return (Nothing ,[genCase tcm tupTcm ty [ty] (css!! i)])
533- tys -> do
534- let m = length tys
535- (tyIxs,tys') = unzip tys
536- tupTy = mkBigTupTy tcm tupTcm tys'
537- cs' = fmap (\ es -> map (es !! ) tyIxs) cs
538- djCase = genCase tcm tupTcm tupTy tys' cs'
539- scrutId <- mkInternalVar inScope " tupIn" tupTy
540- projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0 .. m- 1 ]
541- return (Just (scrutId,djCase),projections)
542- let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
543-
544- return (lbM,selProjs)
545- where
546- tranOrUnTran _ [] projs = projs
547- tranOrUnTran _ sels [] = map snd sels
548- tranOrUnTran n ((ut,s): uts) (p: projs)
549- | n == ut = s : tranOrUnTran (n+ 1 ) uts (p: projs)
550- | otherwise = p : tranOrUnTran (n+ 1 ) ((ut,s): uts) projs
572+ let sel = genCase tcm tupTcm argTy [argTy] (fmap (: [] ) cs)
573+ untran <- isUntranslatableType False argTy
574+ case untran of
575+ True -> return (Nothing , sel)
576+ False -> do
577+ argId <- mkInternalVar inScope (funName <> " _arg" <> showt argN) argTy
578+ return (Just (argId,(argTy,cs)), Var argId)
551579
552580-- | Arguments are shared between invocations if:
553581--
@@ -579,18 +607,6 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs)
579607 _ -> False
580608 isProof _ = False
581609
582- -- | Create a list of arguments given a map of positions to common arguments,
583- -- and a list of arguments
584- mkDJArgs :: Int -- ^ Current position
585- -> [(Int ,Either Term Type )] -- ^ map from position to common argument
586- -> [Term ] -- ^ (projections for) distinct arguments
587- -> [Either Term Type ]
588- mkDJArgs _ cms [] = map snd cms
589- mkDJArgs _ [] uncms = map Left uncms
590- mkDJArgs n ((m,x): cms) (y: uncms)
591- | n == m = x : mkDJArgs (n+ 1 ) cms (y: uncms)
592- | otherwise = Left y : mkDJArgs (n+ 1 ) ((m,x): cms) uncms
593-
594610-- | Create a case-expression that selects between the distinct arguments given
595611-- a case-tree
596612genCase :: TyConMap
0 commit comments