Skip to content

Commit 5927123

Browse files
authored
Refactor DEC transformation (#2668)
The previous code was a big mess where we partioned arguments into shared and non-shared and then filtered the case-tree depending on whether they were part of the shared arguments or not. But then with the normalisation of type arguments, the second filter did not work properly. This then resulted in shared arguments becoming part of the tuple in the alternatives of the case-expression for the non-shared arguments. The new code is also more robust in the sense that shared and non-shared arguments no longer need to be partioned (shared occur left-most, non-shared occur right-most). They can now be interleaved. The old code would also generate bad Core if ever type and term arguments occured interleaved, this is no longer the case for the new code. Fixes #2628
1 parent f946617 commit 5927123

File tree

5 files changed

+297
-109
lines changed

5 files changed

+297
-109
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
FIXED: Clash no longer errors out in the netlist generation stage when a polymorphic function is applied to type X in one alternative of a case-statement and applied to a newtype wrapper of type X in a different alternative. See [#2828](https://github.com/clash-lang/clash-compiler/issues/2628)

clash-lib/src/Clash/Normalize/Transformations/DEC.hs

Lines changed: 125 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ module Clash.Normalize.Transformations.DEC
3838
) where
3939

4040
import Control.Concurrent.Supply (splitSupply)
41+
#if !MIN_VERSION_base(4,18,0)
42+
import Control.Applicative (liftA2)
43+
#endif
4144
import Control.Lens ((^.), _1)
4245
import qualified Control.Lens as Lens
4346
import qualified Control.Monad as Monad
@@ -56,6 +59,7 @@ import qualified Data.Map.Strict as Map
5659
import qualified Data.Maybe as Maybe
5760
import Data.Monoid (All(..))
5861
import qualified Data.Text as Text
62+
import Data.Text.Extra (showt)
5963
import GHC.Stack (HasCallStack)
6064
import qualified Language.Haskell.TH as TH
6165

@@ -72,21 +76,22 @@ import Constants (mAX_TUPLE_SIZE)
7276
#endif
7377

7478
-- internal
75-
import Clash.Core.DataCon (DataCon)
79+
import Clash.Core.DataCon (DataCon)
7680
import Clash.Core.Evaluator.Types (whnf')
7781
import Clash.Core.FreeVars
7882
(termFreeVars', typeFreeVars', localVarsDoNotOccurIn)
7983
import Clash.Core.HasType
8084
import Clash.Core.Literal (Literal(..))
81-
import Clash.Core.Name (nameOcc)
85+
import Clash.Core.Name (OccName, nameOcc)
86+
import Clash.Core.Pretty (showPpr)
8287
import Clash.Core.Term
8388
( Alt, LetBinding, Pat(..), PrimInfo(..), Term(..), TickInfo(..)
8489
, collectArgs, collectArgsTicks, mkApps, mkTicks, patIds, stripTicks)
8590
import Clash.Core.TyCon (TyConMap, TyConName, tyConDataCons)
8691
import Clash.Core.Type
8792
(Type, TypeView (..), isPolyFunTy, mkTyConApp, splitFunForallTy, tyView)
8893
import Clash.Core.Util (mkInternalVar, mkSelectorCase, sccLetBindings)
89-
import Clash.Core.Var (isGlobalId, isLocalId, varName)
94+
import Clash.Core.Var (Id, isGlobalId, isLocalId, varName)
9095
import Clash.Core.VarEnv
9196
( InScopeSet, elemInScopeSet, extendInScopeSet, extendInScopeSetList
9297
, notElemInScopeSet, unionInScope)
@@ -138,6 +143,24 @@ import qualified GHC.Prim
138143
-- B -> f_out
139144
-- C -> h x
140145
-- @
146+
--
147+
-- Though that's a lie. It actually converts it into:
148+
--
149+
-- @
150+
-- let f_tupIn = case x of {A -> (3,y); B -> (x,x)}
151+
-- f_arg0 = case f_tupIn of (l,_) -> l
152+
-- f_arg1 = case f_tupIn of (_,r) -> r
153+
-- f_out = f f_arg0 f_arg1
154+
-- in case x of
155+
-- A -> f_out
156+
-- B -> f_out
157+
-- C -> h x
158+
-- @
159+
--
160+
-- In order to share the expression that's in the subject of the case expression,
161+
-- and to share the /decoder/ circuit that logic synthesis will create to map the
162+
-- bits of the subject expression to the bits needed to make the selection in the
163+
-- multiplexer.
141164
disjointExpressionConsolidation :: HasCallStack => NormRewrite
142165
disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _ty _alts@(_:_:_)) = do
143166
-- Collect all (the applications of) global binders (and certain primitives)
@@ -150,11 +173,13 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
150173
else do
151174
-- For every to-lift expression create (the generalization of):
152175
--
153-
-- let fargs = case x of {A -> (3,y); B -> (x,x)}
154-
-- in f (fst fargs) (snd fargs)
176+
-- let f_tupIn = case x of {A -> (3,y); B -> (x,x)}
177+
-- f_arg0 = case f_tupIn of (l,_) -> l
178+
-- f_arg1 = case f_tupIn of (_,r) -> r
179+
-- in f f_arg0 f_arg0
155180
--
156-
-- the let-expression is not created when `f` has only one (selectable)
157-
-- argument
181+
-- if an argument is non-representable, the case-expression is inlined,
182+
-- and no let-binding will be created for it.
158183
--
159184
-- NB: mkDisJointGroup needs the context InScopeSet, isCtx, to determine
160185
-- whether expressions reference variables from the context, or
@@ -190,11 +215,8 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
190215
-- Make the let-binder for the lifted expressions
191216
mkFunOut tcm isN ((fun,_),(eLifted,_)) = do
192217
let ty = inferCoreTypeOf tcm eLifted
193-
nm = case collectArgs fun of
194-
(Var v,_) -> nameOcc (varName v)
195-
(Prim p,_) -> primName p
196-
_ -> "complex_expression_"
197-
nm1 = last (Text.splitOn "." nm) `Text.append` "Out"
218+
nm = decFunName fun
219+
nm1 = nm `Text.append` "_out"
198220
nm2 <- mkInternalVar isN nm1 ty
199221
return (extendInScopeSet isN nm2,nm2)
200222

@@ -249,12 +271,26 @@ disjointExpressionConsolidation ctx@(TransformContext isCtx _) e@(Case _scrut _t
249271
disjointExpressionConsolidation _ e = return e
250272
{-# SCC disjointExpressionConsolidation #-}
251273

274+
decFunName :: Term -> OccName
275+
decFunName fun = last . Text.splitOn "." $ case collectArgs fun of
276+
(Var v, _) -> nameOcc (varName v)
277+
(Prim p, _) -> primName p
278+
_ -> "complex_expression"
279+
252280
data CaseTree a
253281
= Leaf a
254282
| LB [LetBinding] (CaseTree a)
255283
| Branch Term [(Pat,CaseTree a)]
256284
deriving (Eq,Show,Functor,Foldable)
257285

286+
instance Applicative CaseTree where
287+
pure = Leaf
288+
liftA2 f (Leaf a) (Leaf b) = Leaf (f a b)
289+
liftA2 f (LB lb c1) (LB _ c2) = LB lb (liftA2 f c1 c2)
290+
liftA2 f (Branch scrut alts1) (Branch _ alts2) =
291+
Branch scrut (zipWith (\(p1,a1) (_,a2) -> (p1,liftA2 f a1 a2)) alts1 alts2)
292+
liftA2 _ _ _ = error "CaseTree.liftA2: internal error, this should not happen."
293+
258294
-- | Test if a 'CaseTree' collected from an expression indicates that
259295
-- application of a global binder is disjoint: occur in separate branches of a
260296
-- case-expression.
@@ -269,18 +305,6 @@ isDisjoint ct = go ct
269305
go (Branch _ [(_,x)]) = go x
270306
go b@(Branch _ (_:_:_)) = allEqual (map Either.rights (Foldable.toList b))
271307

272-
-- Remove empty branches from a 'CaseTree'
273-
removeEmpty :: Eq a => CaseTree [a] -> CaseTree [a]
274-
removeEmpty l@(Leaf _) = l
275-
removeEmpty (LB lb ct) =
276-
case removeEmpty ct of
277-
Leaf [] -> Leaf []
278-
ct' -> LB lb ct'
279-
removeEmpty (Branch s bs) =
280-
case filter ((/= (Leaf [])) . snd) (map (second removeEmpty) bs) of
281-
[] -> Leaf []
282-
bs' -> Branch s bs'
283-
284308
-- | Test if all elements in a list are equal to each other.
285309
allEqual :: Eq a => [a] -> Bool
286310
allEqual [] = True
@@ -464,90 +488,94 @@ collectGlobalsLbs is0 substitution seen lbs = do
464488
-- function-position\", return a let-expression: where the let-binding holds
465489
-- a case-expression selecting between the distinct arguments of the case-tree,
466490
-- and the body is an application of the term applied to the shared arguments of
467-
-- the case tree, and projections of let-binding corresponding to the distinct
468-
-- argument positions.
491+
-- the case tree, and variable references to the created let-bindings.
492+
--
493+
-- case-expressions whose type would be non-representable are not let-bound,
494+
-- but occur directly in the argument position of the application in the body
495+
-- of the let-expression.
469496
mkDisjointGroup
470497
:: InScopeSet
471498
-- ^ Variables in scope at the very top of the case-tree, i.e., the original
472499
-- expression
473-
-> (Term,([Term],CaseTree [(Either Term Type)]))
500+
-> (Term,([Term],CaseTree [Either Term Type]))
474501
-- ^ Case-tree of arguments belonging to the applied term.
475502
-> NormalizeSession (Term,[Term])
476503
mkDisjointGroup inScope (fun,(seen,cs)) = do
477504
tcm <- Lens.view tcCache
478-
let argss = Foldable.toList cs
479-
argssT = zip [0..] (List.transpose argss)
480-
(sharedT,distinctT) = List.partition (areShared tcm inScope . fmap (first stripTicks) . snd) argssT
481-
-- TODO: find a better solution than "maybe undefined fst . uncons"
482-
shared = map (second (maybe (error "impossible") fst . List.uncons)) sharedT
483-
distinct = map (Either.lefts) (List.transpose (map snd distinctT))
484-
cs' = fmap (zip [0..]) cs
485-
cs'' = removeEmpty
486-
$ fmap (Either.lefts . map snd)
487-
(if null shared
488-
then cs'
489-
else fmap (filter (`notElem` shared)) cs')
490-
(distinctCaseM,distinctProjections) <- case distinct of
491-
-- only shared arguments: do nothing.
492-
[] -> return (Nothing,[])
493-
-- Create selectors and projections
494-
(uc:_) -> do
495-
let argTys = map (inferCoreTypeOf tcm) uc
496-
disJointSelProj inScope argTys cs''
497-
let newArgs = mkDJArgs 0 shared distinctProjections
498-
case distinctCaseM of
499-
Just lb -> return (Letrec [lb] (mkApps fun newArgs), seen)
500-
Nothing -> return (mkApps fun newArgs, seen)
501-
502-
-- | Create a single selector for all the representable distinct arguments by
503-
-- selecting between tuples. This selector is only ('Just') created when the
504-
-- number of representable uncommmon arguments is larger than one, otherwise it
505-
-- is not ('Nothing').
506-
--
507-
-- It also returns:
508-
--
509-
-- * For all the non-representable distinct arguments: a selector
510-
-- * For all the representable distinct arguments: a projection out of the tuple
511-
-- created by the larger selector. If this larger selector does not exist, a
512-
-- single selector is created for the single representable distinct argument.
505+
let funName = decFunName fun
506+
argLen = case Foldable.toList cs of
507+
[] -> error "mkDisjointGroup: no disjoint groups"
508+
l:_ -> length l
509+
csT :: [CaseTree (Either Term Type)] -- "Transposed" 'CaseTree [Either Term Type]'
510+
csT = map (\i -> fmap (!!i) cs) [0..(argLen-1)] -- sequenceA does the wrong thing
511+
(lbs,newArgs) <- List.mapAccumRM (\lbs (c,pos) -> do
512+
let cL = Foldable.toList c
513+
case (cL, areShared tcm inScope (fmap (first stripTicks) cL)) of
514+
(Right ty:_, True) ->
515+
return (lbs,Right ty)
516+
(Right _:_, False) ->
517+
error ("mkDisjointGroup: non-equal type arguments: " <>
518+
showPpr (Either.rights cL))
519+
(Left tm:_, True) ->
520+
return (lbs,Left tm)
521+
(Left tm:_, False) -> do
522+
let ty = inferCoreTypeOf tcm tm
523+
let err = error ("mkDisjointGroup: mixed type and term arguments: " <> show cL)
524+
(lbM,arg) <- disJointSelProj inScope ty (Either.fromLeft err <$> c) funName pos
525+
case lbM of
526+
Just lb -> return (lb:lbs, Left arg)
527+
_ -> return (lbs, Left arg)
528+
([], _) ->
529+
error "mkDisjointGroup: no arguments"
530+
) [] (zip csT [0..])
531+
let funApp = mkApps fun newArgs
532+
tupTcm <- Lens.view tupleTcCache
533+
case lbs of
534+
[] ->
535+
return (funApp, seen)
536+
[(v,(ty,ct))] -> do
537+
let e = genCase tcm tupTcm ty [ty] (fmap (:[]) ct)
538+
return (Letrec [(v,e)] funApp, seen)
539+
_ -> do
540+
let (vs,zs) = unzip lbs
541+
csL :: [CaseTree Term]
542+
(tys,csL) = unzip zs
543+
csLT :: CaseTree [Term]
544+
csLT = fmap ($ []) (foldr1 (liftA2 (.)) (fmap (fmap (:)) csL))
545+
bigTupTy = mkBigTupTy tcm tupTcm tys
546+
ct = genCase tcm tupTcm bigTupTy tys csLT
547+
tupIn <- mkInternalVar inScope (funName <> "_tupIn") bigTupTy
548+
projections <-
549+
Monad.zipWithM (\v n ->
550+
(v,) <$> mkBigTupSelector inScope tcm tupTcm (Var tupIn) tys n)
551+
vs [0..]
552+
return (Letrec ((tupIn,ct):projections) funApp, seen)
553+
554+
-- | Create a selector for the case-tree of the argument. If the argument is
555+
-- representable create a let-binding for the created selector, and return
556+
-- a variable reference to this let-binding. If the argument is not representable
557+
-- return the selector directly.
513558
disJointSelProj
514559
:: InScopeSet
515-
-> [Type]
516-
-- ^ Types of the arguments
517-
-> CaseTree [Term]
518-
-- The case-tree of arguments
519-
-> NormalizeSession (Maybe LetBinding,[Term])
520-
disJointSelProj _ _ (Leaf []) = return (Nothing,[])
521-
disJointSelProj inScope argTys cs = do
522-
tcm <- Lens.view tcCache
560+
-> Type
561+
-- ^ Types of the argument
562+
-> CaseTree Term
563+
-- ^ The case-tree of argument
564+
-> OccName
565+
-- ^ Name of the lifted function
566+
-> Word
567+
-- ^ Position of the argument
568+
-> NormalizeSession (Maybe (Id, (Type, CaseTree Term)),Term)
569+
disJointSelProj inScope argTy cs funName argN = do
570+
tcm <- Lens.view tcCache
523571
tupTcm <- Lens.view tupleTcCache
524-
let maxIndex = length argTys - 1
525-
css = map (\i -> fmap ((:[]) . (!!i)) cs) [0..maxIndex]
526-
(untran,tran) <- List.partitionM (isUntranslatableType False . snd) (zip [0..] argTys)
527-
let untranCs = map (css!!) (map fst untran)
528-
untranSels = zipWith (\(_,ty) cs' -> genCase tcm tupTcm ty [ty] cs')
529-
untran untranCs
530-
(lbM,projs) <- case tran of
531-
[] -> return (Nothing,[])
532-
[(i,ty)] -> return (Nothing,[genCase tcm tupTcm ty [ty] (css!!i)])
533-
tys -> do
534-
let m = length tys
535-
(tyIxs,tys') = unzip tys
536-
tupTy = mkBigTupTy tcm tupTcm tys'
537-
cs' = fmap (\es -> map (es !!) tyIxs) cs
538-
djCase = genCase tcm tupTcm tupTy tys' cs'
539-
scrutId <- mkInternalVar inScope "tupIn" tupTy
540-
projections <- mapM (mkBigTupSelector inScope tcm tupTcm (Var scrutId) tys') [0..m-1]
541-
return (Just (scrutId,djCase),projections)
542-
let selProjs = tranOrUnTran 0 (zip (map fst untran) untranSels) projs
543-
544-
return (lbM,selProjs)
545-
where
546-
tranOrUnTran _ [] projs = projs
547-
tranOrUnTran _ sels [] = map snd sels
548-
tranOrUnTran n ((ut,s):uts) (p:projs)
549-
| n == ut = s : tranOrUnTran (n+1) uts (p:projs)
550-
| otherwise = p : tranOrUnTran (n+1) ((ut,s):uts) projs
572+
let sel = genCase tcm tupTcm argTy [argTy] (fmap (:[]) cs)
573+
untran <- isUntranslatableType False argTy
574+
case untran of
575+
True -> return (Nothing, sel)
576+
False -> do
577+
argId <- mkInternalVar inScope (funName <> "_arg" <> showt argN) argTy
578+
return (Just (argId,(argTy,cs)), Var argId)
551579

552580
-- | Arguments are shared between invocations if:
553581
--
@@ -579,18 +607,6 @@ areShared tcm inScope xs@(x:_) = noFV1 && (isProof x || allEqual xs)
579607
_ -> False
580608
isProof _ = False
581609

582-
-- | Create a list of arguments given a map of positions to common arguments,
583-
-- and a list of arguments
584-
mkDJArgs :: Int -- ^ Current position
585-
-> [(Int,Either Term Type)] -- ^ map from position to common argument
586-
-> [Term] -- ^ (projections for) distinct arguments
587-
-> [Either Term Type]
588-
mkDJArgs _ cms [] = map snd cms
589-
mkDJArgs _ [] uncms = map Left uncms
590-
mkDJArgs n ((m,x):cms) (y:uncms)
591-
| n == m = x : mkDJArgs (n+1) cms (y:uncms)
592-
| otherwise = Left y : mkDJArgs (n+1) ((m,x):cms) uncms
593-
594610
-- | Create a case-expression that selects between the distinct arguments given
595611
-- a case-tree
596612
genCase :: TyConMap

clash-lib/src/Data/List/Extra.hs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
module Data.List.Extra
55
( partitionM
66
, mapAccumLM
7+
, mapAccumRM
78
, iterateNM
89
, (<:>)
910
, indexMaybe
@@ -46,6 +47,19 @@ mapAccumLM f acc (x:xs) = do
4647
(acc'',ys) <- mapAccumLM f acc' xs
4748
return (acc'',y:ys)
4849

50+
-- | Monadic version of 'Data.List.mapAccumR'
51+
mapAccumRM
52+
:: Monad m
53+
=> (acc -> x -> m (acc,y))
54+
-> acc
55+
-> [x]
56+
-> m (acc,[y])
57+
mapAccumRM _ acc [] = return (acc,[])
58+
mapAccumRM f acc (x:xs) = do
59+
(acc1,ys) <- mapAccumRM f acc xs
60+
(acc2,y) <- f acc1 x
61+
return (acc2,y:ys)
62+
4963
-- | Monadic version of 'iterate'. A carbon copy ('iterateM') would not
5064
-- terminate, hence the first argument.
5165
iterateNM

tests/Main.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,7 @@ runClashTest = defaultMain $ clashTestRoot
633633
, runTest "T2593" def{hdlSim=[]}
634634
, runTest "T2623CaseConFVs" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]}
635635
, runTest "T2781" def{hdlLoad=[],hdlSim=[],hdlTargets=[VHDL]}
636+
, runTest "T2628" def{hdlTargets=[VHDL], buildTargets=BuildSpecific ["TACacheServerStep"], hdlSim=[]}
636637
] <>
637638
if compiledWith == Cabal then
638639
-- This tests fails without environment files present, which are only

0 commit comments

Comments
 (0)