@@ -37,6 +37,7 @@ import Data.Bifunctor (bimap)
3737import Data.Either (lefts )
3838import qualified Data.List as List
3939import qualified Data.Map as Map
40+ import Data.Maybe (fromMaybe )
4041import qualified Data.HashMap.Strict as HashMapS
4142import Data.Text (Text )
4243import qualified Data.Text as Text
@@ -51,7 +52,7 @@ import Clash.Core.FreeVars
5152import Clash.Core.Name (Name (nameOcc ,nameUniq ))
5253import Clash.Core.Pretty (showPpr )
5354import Clash.Core.Subst
54- (deShadowTerm , extendTvSubstList , mkSubst , substTm )
55+ (deShadowTerm , extendTvSubstList , mkSubst , substTm , extendIdSubstList )
5556import Clash.Core.Term
5657 (Context , CoreContext (AppArg ), PrimInfo (.. ), Term (.. ), WorkInfo (.. ),
5758 TickInfo (NameMod ), NameMod (PrefixName ), collectArgs , collectArgsTicks )
@@ -407,7 +408,7 @@ normalizeTopLvlBndr isTop nm (nm',sp,inl,tm) = makeCachedU nm (extra.normalized)
407408 -- into a loop. Deshadowing freshens all the bindings
408409 -- to avoid this.
409410 let tm1 = deShadowTerm emptyInScopeSet tm
410- tm2 = if isTop then substWithTyEq [] [] tm1 else tm1
411+ tm2 = if isTop then fromMaybe tm1 ( substWithTyEq [] [] [] tm1) else tm1
411412 old <- Lens. use curFun
412413 tm3 <- rewriteExpr (" normalization" ,normalization) (nmS,tm2) (nm',sp)
413414 curFun .= old
@@ -428,18 +429,24 @@ normalizeTopLvlBndr isTop nm (nm',sp,inl,tm) = makeCachedU nm (extra.normalized)
428429substWithTyEq
429430 :: [TyVar ]
430431 -> [(TyVar ,Type )]
432+ -> [Id ]
431433 -> Term
432- -> Term
433- substWithTyEq tvs cvs (TyLam tv e) = substWithTyEq (tv: tvs) cvs e
434- substWithTyEq tvs cvs (Lam v e)
434+ -> Maybe Term
435+ -- ^ 'Nothing' if 'substWithTyEq' didn't have to substitute anything
436+ substWithTyEq tvs cvs ids_ (TyLam tv e) = substWithTyEq (tv: tvs) cvs ids_ e
437+ substWithTyEq tvs cvs ids_ (Lam v e)
435438 | TyConApp (nameUniq -> tcUniq) [_,VarTy tv, ty] <- tyView (varType v)
436439 , tcUniq == getKey eqTyConKey
437440 , tv `elem` tvs
438- = substWithTyEq (tvs List. \\ [tv]) ((tv,ty): cvs) e
439- substWithTyEq tvs cvs@ (_: _) e =
440- let e1 = List. foldl' (flip TyLam ) e tvs
441- in substTm " substWithTyEq" (extendTvSubstList (mkSubst emptyInScopeSet) cvs) e1
442- substWithTyEq tvs _ e = List. foldl' (flip TyLam ) e tvs
441+ = substWithTyEq (tvs List. \\ [tv]) ((tv,ty): cvs) (v: ids_) e
442+ substWithTyEq tvs cvs@ (_: _) ids_ e =
443+ let
444+ e1 = List. foldl' (flip TyLam ) e tvs
445+ subst0 = extendTvSubstList (mkSubst emptyInScopeSet) cvs
446+ subst1 = extendIdSubstList subst0 [(v, removedTm (varType v)) | v <- ids_]
447+ in
448+ Just (substTm " substWithTyEq" subst1 e1)
449+ substWithTyEq _ _ _ _ = Nothing
443450
444451-- | Rewrite a term according to the provided transformation
445452rewriteExpr :: (String ,NormRewrite ) -- ^ Transformation to apply
0 commit comments