{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
module GHC.Stg.EnforceEpt.Rewrite (rewriteTopBinds, rewriteOpApp)
where
import GHC.Prelude
import GHC.Builtin.PrimOps ( PrimOp(..) )
import GHC.Types.Basic ( CbvMark (..), isMarkedCbv
, TopLevelFlag(..), isTopLevel )
import GHC.Types.Id
import GHC.Types.Name
import GHC.Types.Unique.Supply
import GHC.Types.Unique.FM
import GHC.Types.RepType
import GHC.Types.Var.Set
import GHC.Unit.Types
import GHC.Core.DataCon
import GHC.Core ( AltCon(..) )
import GHC.Core.Type
import GHC.StgToCmm.Types
import GHC.StgToCmm.Closure (importedIdLFInfo)
import GHC.Stg.Utils
import GHC.Stg.Syntax as StgSyn
import GHC.Data.Maybe
import GHC.Utils.Panic
import GHC.Utils.Outputable
import GHC.Utils.Monad.State.Strict
import GHC.Utils.Misc
import GHC.Stg.EnforceEpt.Types
import Control.Monad
-- | The rewrite monad.
--
-- A strict 'State' computation whose state is a four-tuple of:
--
--   1. the map from 'Id's to their 'TagSig',
--   2. a 'UniqSupply' used to hand out fresh uniques,
--   3. a 'Module' (consulted by 'isTagged' to tell local from imported ids),
--   4. an 'IdSet' of local ids (consulted by 'fvArgs' when computing the free
--      variables of a freshly built closure).
newtype RM a = RM { unRM :: (State (UniqFM Id TagSig, UniqSupply, Module, IdSet) a) }
    deriving (Functor, Monad, Applicative)
-- | Uniques are drawn from the 'UniqSupply' stored in the state: the supply is
-- split, one half is returned and the other half is written back.
instance MonadUnique RM where
    getUniqueSupplyM = RM $ do
        (m, us, mod, lcls) <- get
        let (us1, us2) = splitUniqSupply us
        put (m, us2, mod, lcls)
        return us1
-- | Read the 'Id' -> 'TagSig' map (first component of the state).
getMap :: RM (UniqFM Id TagSig)
getMap = RM $ ((\(fst,_,_,_) -> fst) <$> get)
-- | Replace the 'Id' -> 'TagSig' map, leaving the other state components
-- untouched. The new map is forced (bang pattern) before it is stored.
setMap :: (UniqFM Id TagSig) -> RM ()
setMap !m = RM $ do
    (_,us,mod,lcls) <- get
    put (m, us,mod,lcls)
-- | Read the 'Module' stored in the state (third component).
getMod :: RM Module
getMod = RM $ ( (\(_,_,thrd,_) -> thrd) <$> get)
-- | Read the set of local ids stored in the state (fourth component).
getFVs :: RM IdSet
getFVs = RM $ ((\(_,_,_,lcls) -> lcls) <$> get)
-- | Replace the set of local ids (fourth state component), leaving the other
-- components untouched. The new set is forced (bang pattern) before it is stored.
setFVs :: IdSet -> RM ()
setFVs !fvs = RM $ do
    (tag_map,us,mod,_lcls) <- get
    put (tag_map, us,mod,fvs)
-- | Run an action with the binder(s) of the given binding in scope.
--
-- A non-recursive binding delegates to 'withBinder'; a recursive group
-- delegates to 'withBinders' with all of the group's binders. The right-hand
-- sides of the binding are ignored.
withBind :: TopLevelFlag -> GenStgBinding 'InferTaggedBinders -> RM a -> RM a
withBind top_flag (StgNonRec bnd _) cont = withBinder top_flag bnd cont
withBind top_flag (StgRec binds) cont = do
    let (bnds,_rhss) = unzip binds :: ([(Id, TagSig)], [GenStgRhs 'InferTaggedBinders])
    withBinders top_flag bnds cont
-- | Add the binder signature(s) of a binding to the tag map.
--
-- Unlike 'withBinder' / 'withBinders' the previous map is not restored
-- afterwards, so the additions stay visible to everything that runs later.
addTopBind :: GenStgBinding 'InferTaggedBinders -> RM ()
addTopBind (StgNonRec (id, tag) _) = do
    s <- getMap
    setMap $ addToUFM s id tag
addTopBind (StgRec binds) = do
    let (bnds,_rhss) = unzip binds
    !s <- getMap
    setMap $! addListToUFM s bnds
-- | Run an action with one extra binder signature in the tag map.
--
-- The signature is added before the action runs and the previous map is
-- restored afterwards. For a non-top-level binder the action is additionally
-- wrapped in 'withLcl', so the binder is part of the local id set while the
-- action runs.
withBinder :: TopLevelFlag -> (Id, TagSig) ->  RM a -> RM a
withBinder top_flag (id,sig) cont = do
    oldMap <- getMap
    setMap $ addToUFM oldMap id sig
    a <- if isTopLevel top_flag
            then cont
            else withLcl id cont
    setMap oldMap
    return a
-- | Run an action with several extra binder signatures in the tag map.
--
-- The previous tag map is restored once the action has finished. For
-- 'NotTopLevel' binders the ids are also added to the local id set for the
-- duration of the action, and that set is restored afterwards as well.
withBinders :: TopLevelFlag -> [(Id, TagSig)] -> RM a -> RM a
withBinders TopLevel sigs cont = do
    oldMap <- getMap
    setMap $ addListToUFM oldMap sigs
    a <- cont
    setMap oldMap
    return a
withBinders NotTopLevel sigs cont = do
    oldMap <- getMap
    oldFvs <- getFVs
    setMap $ addListToUFM oldMap sigs
    setFVs $ extendVarSetList oldFvs (map fst sigs)
    a <- cont
    setMap oldMap
    setFVs oldFvs
    return a
-- | Run an action with every id of the given 'DIdSet' added to the local id
-- set; the previous set is restored afterwards.
--
-- The extended set and the action's result are both forced before continuing.
withClosureLcls :: DIdSet -> RM a -> RM a
withClosureLcls fvs act = do
    old_fvs <- getFVs
    -- The fold only inserts into a set, so the order in which the ids are
    -- visited does not affect the result.
    let !fvs' = nonDetStrictFoldDVarSet (flip extendVarSet) old_fvs fvs
    setFVs fvs'
    !r <- act
    setFVs old_fvs
    return r
-- | Run an action with a single id added to the local id set; the previous
-- set is restored afterwards.
--
-- The extended set and the action's result are both forced before continuing.
withLcl :: Id -> RM a -> RM a
withLcl fv act = do
    old_fvs <- getFVs
    let !fvs' = extendVarSet old_fvs fv
    setFVs fvs'
    !r <- act
    setFVs old_fvs
    return r
-- | Decide whether an 'Id' is treated as tagged.
--
--   * Dead-end ids ('isDeadEndId') are never treated as tagged.
--   * Ids whose name is local to, or from, the current module are tagged when
--     their type is definitely unlifted; otherwise their 'TagSig' is looked up
--     in the tag map and everything except 'TagDunno' counts as tagged.
--   * For all other (imported) ids the answer is derived from the
--     'LambdaFormInfo' returned by 'importedIdLFInfo'.
isTagged :: Id -> RM Bool
isTagged v
    | isDeadEndId v = pure False
    | otherwise = do
        this_mod <- getMod
        -- Fallback signature for ids missing from the map: 'TagDunno', guarded
        -- by an assertion that the current module is an interactive one.
        let lookupDefault v = assertPpr (isInteractiveModule this_mod)
                                        (text "unknown Id:" <> ppr this_mod <+> ppr v)
                                        (TagSig TagDunno)
        case nameIsLocalOrFrom this_mod (idName v) of
            True
                | definitelyUnliftedType (idType v)
                -> return True
                | otherwise -> do
                    !s <- getMap
                    let !sig = lookupWithDefaultUFM s (lookupDefault v) v
                    return $ case sig of
                        TagSig info ->
                            case info of
                                TagDunno -> False
                                TagProper -> True
                                TagTagged -> True
                                TagTuple _ -> True
            False
                -> return $!
                    -- Imported id: decide purely from its lambda-form info.
                    case importedIdLFInfo v of
                        LFReEntrant {}
                            -> True
                        LFThunk {}
                            -> False
                        LFCon {}
                            -> True
                        LFUnknown {}
                            -> False
                        LFUnlifted {}
                            -> True
                        LFLetNoEscape {}
                            -> True
-- | Decide whether an argument is treated as tagged: literals always are,
-- variables are checked with 'isTagged'.
isArgTagged :: StgArg -> RM Bool
isArgTagged (StgLitArg _) = return True
isArgTagged (StgVarArg v) = isTagged v
-- | Make a copy of an 'Id' that has been passed through 'localiseId' and given
-- a fresh unique drawn from the monad's supply.
mkLocalArgId :: Id -> RM Id
mkLocalArgId id = do
    !u <- getUniqueM
    return $! setIdUnique (localiseId id) u
-- | Rewrite all top-level bindings of a module.
--
-- Runs 'rewriteTop' over the bindings in order, starting from an empty tag
-- map, the supplied unique supply and module, and an empty local id set.
rewriteTopBinds :: Module -> UniqSupply -> [GenStgTopBinding 'InferTaggedBinders] -> [TgStgTopBinding]
rewriteTopBinds mod us binds =
    let doBinds = mapM rewriteTop binds
    in evalState (unRM doBinds) (mempty, us, mod, mempty)
-- | Rewrite a single top-level binding.
--
-- String literals are passed through unchanged. For lifted bindings the
-- binder signatures are first added to the tag map with 'addTopBind' (and so
-- remain visible to later bindings), then the binding itself is rewritten.
rewriteTop :: InferStgTopBinding -> RM TgStgTopBinding
rewriteTop (StgTopStringLit v s) = return $! (StgTopStringLit v s)
rewriteTop (StgTopLifted bind) = do
    addTopBind bind
    (StgTopLifted) <$!> (rewriteBinds TopLevel bind)
-- | Rewrite the right-hand side(s) of a binding.
--
-- For a non-recursive binding the right-hand side is rewritten without the
-- binder itself being brought into scope. For a recursive group all binders
-- are brought into scope (via 'withBind') while the right-hand sides are
-- rewritten. In both cases the 'TagSig' half of each binder is dropped from
-- the result.
rewriteBinds :: TopLevelFlag -> InferStgBinding -> RM (TgStgBinding)
rewriteBinds _top_flag (StgNonRec v rhs) = do
        (!rhs) <-  rewriteRhs v rhs
        return $! (StgNonRec (fst v) rhs)
rewriteBinds top_flag b@(StgRec binds) =
    withBind top_flag b $ do
        (rhss) <- mapM (uncurry rewriteRhs) binds
        return $! (mkRec rhss)
  where
    -- Pair the original binders (without their signatures) with the
    -- rewritten right-hand sides.
    mkRec :: [TgStgRhs] -> TgStgBinding
    mkRec rhss = StgRec (zip (map (fst . fst) binds) rhss)
-- | Rewrite the right-hand side of a binding.
--
-- Constructor right-hand sides: the arguments selected by 'getStrictConArgs'
-- are checked with 'isArgTagged'. If every such variable argument is tagged
-- the constructor is kept as is. Otherwise the right-hand side is turned into
-- an updatable closure with no arguments whose body is built by 'mkSeqs' from
-- the untagged variables and the constructor application.
--
-- Closure right-hand sides: the body is rewritten with the closure's
-- arguments in scope and its free variables added to the local id set.
rewriteRhs :: (Id,TagSig) -> InferStgRhs
           -> RM (TgStgRhs)
rewriteRhs (_id, _tagSig) (StgRhsCon ccs con cn ticks args typ) = {-# SCC rewriteRhs_ #-} do
    -- Tag information for every constructor argument, in argument order.
    fieldInfos <- mapM isArgTagged args

    -- Keep only the arguments 'getStrictConArgs' selects, paired with their
    -- tag information.
    let strictFields =
            getStrictConArgs con (zip args fieldInfos) :: [(StgArg,Bool)]
    -- Of those, keep the arguments that are not known to be tagged.
    let needsEval = map fst .
                        filter (not . snd) $
                        strictFields :: [StgArg]
    let evalArgs = [v | StgVarArg v <- needsEval] :: [Id]

    if (null evalArgs)
        then return $! (StgRhsCon ccs con cn ticks args typ)
        else do
            -- The type argument of the constructor application is a panic
            -- thunk; per its message it is not expected to be inspected.
            let ty_stub = panic "mkSeqs shouldn't use the type arg"
            conExpr <- mkSeqs args evalArgs (\taggedArgs -> StgConApp con cn taggedArgs ty_stub)

            fvs <- fvArgs args
            return $! (StgRhsClosure fvs ccs Updatable [] $! conExpr) typ
rewriteRhs _binding (StgRhsClosure fvs ccs flag args body typ) = do
    withBinders NotTopLevel args $
        withClosureLcls fvs $
            StgRhsClosure fvs ccs flag (map fst args) <$> rewriteExpr body <*> pure typ
-- | Collect the variable arguments that are members of the current local id
-- set (see 'getFVs') into a 'DVarSet'. Literal arguments are ignored.
fvArgs :: [StgArg] -> RM DVarSet
fvArgs args = do
    fv_lcls <- getFVs
    return $ mkDVarSet [ v | StgVarArg v <- args, elemVarSet v fv_lcls]
-- | Apply 'rewriteArg' to every argument, in order.
rewriteArgs :: [StgArg] -> RM [StgArg]
rewriteArgs = mapM rewriteArg
-- | Rewrite one argument: variables go through 'rewriteId', literals are
-- returned unchanged.
rewriteArg :: StgArg -> RM StgArg
rewriteArg (StgVarArg v) = StgVarArg <$!> rewriteId v
rewriteArg  (lit@StgLitArg{}) = return lit
-- | Attach a 'TagProper' signature to the 'Id' when 'isTagged' says it is
-- tagged; otherwise return the 'Id' unchanged.
rewriteId :: Id -> RM Id
rewriteId v = do
    !is_tagged <- isTagged v
    if is_tagged then return $! setIdTagSig v (TagSig TagProper)
                 else return v
-- | Rewrite an expression by dispatching on its constructor.
--
-- Case, let, let-no-escape, constructor application and function application
-- each have a dedicated worker. Ticks are rewritten underneath the tick,
-- literals are returned as they are, and primitive/foreign operations have
-- their arguments rewritten with 'rewriteArgs'.
rewriteExpr :: GenStgExpr 'InferTaggedBinders -> RM (GenStgExpr 'CodeGen)
rewriteExpr (e@StgCase {})          = rewriteCase e
rewriteExpr (e@StgLet {})           = rewriteLet e
rewriteExpr (e@StgLetNoEscape {})   = rewriteLetNoEscape e
rewriteExpr (StgTick t e)           = StgTick t <$!> rewriteExpr e
rewriteExpr e@(StgConApp {})        = rewriteConApp e
rewriteExpr e@(StgApp {})           = rewriteApp e
rewriteExpr (StgLit lit)            = return $! (StgLit lit)
rewriteExpr (StgOpApp op args res_ty) = (StgOpApp op) <$!> rewriteArgs args <*> pure res_ty
-- | Rewrite a case expression.
--
-- The case binder is in scope (as a non-top-level binder) while the
-- scrutinee, the binder itself and all alternatives are rewritten, in that
-- order. Panics when applied to anything other than 'StgCase'.
rewriteCase :: InferStgExpr -> RM TgStgExpr
rewriteCase (StgCase scrut bndr alt_type alts) =
    withBinder NotTopLevel bndr $
        pure StgCase <*>
            rewriteExpr scrut <*>
            rewriteId (fst bndr) <*>
            pure alt_type <*>
            mapM rewriteAlt alts
rewriteCase _ = panic "Impossible: nodeCase"
-- | Rewrite a case alternative: its right-hand side is rewritten with the
-- alternative's binders in scope, and the 'TagSig' half of each binder is
-- dropped from the result. The alternative's constructor is kept as is.
rewriteAlt :: InferStgAlt -> RM TgStgAlt
rewriteAlt alt@GenStgAlt{alt_con=_, alt_bndrs=bndrs, alt_rhs=rhs} =
    withBinders NotTopLevel bndrs $ do
        !rhs' <- rewriteExpr rhs
        return $! alt {alt_bndrs = map fst bndrs, alt_rhs = rhs'}
-- | Rewrite a let expression: first the binding (via 'rewriteBinds'), then the
-- body with the binding's binders in scope. Panics when applied to anything
-- other than 'StgLet'.
rewriteLet :: InferStgExpr -> RM TgStgExpr
rewriteLet (StgLet xt bind expr) = do
    (!bind') <- rewriteBinds NotTopLevel bind
    withBind NotTopLevel bind $ do
        !expr' <- rewriteExpr expr
        return $! (StgLet xt bind' expr')
rewriteLet _ = panic "Impossible"
-- | Rewrite a let-no-escape expression: first the binding (via
-- 'rewriteBinds'), then the body with the binding's binders in scope. Panics
-- when applied to anything other than 'StgLetNoEscape'.
rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr
rewriteLetNoEscape (StgLetNoEscape xt bind expr) = do
    (!bind') <- rewriteBinds NotTopLevel bind
    withBind NotTopLevel bind $ do
        !expr' <- rewriteExpr expr
        return $! (StgLetNoEscape xt bind' expr')
rewriteLetNoEscape _ = panic "Impossible"
-- | Rewrite a saturated constructor application.
--
-- The arguments selected by 'getStrictConArgs' are checked with
-- 'isArgTagged'. If any variable among them is not known to be tagged, the
-- application is built by 'mkSeqs' from those variables; otherwise it is
-- returned unchanged. Panics when applied to anything other than 'StgConApp'.
rewriteConApp :: InferStgExpr -> RM TgStgExpr
rewriteConApp (StgConApp con cn args tys) = do
    -- Tag information for every argument, in argument order.
    fieldInfos <- mapM isArgTagged args
    let strictIndices = getStrictConArgs con (zip fieldInfos args) :: [(Bool, StgArg)]
    -- Arguments in the selected positions that are not known to be tagged.
    let needsEval = map snd . filter (not . fst) $ strictIndices :: [StgArg]
    let evalArgs = [v | StgVarArg v <- needsEval] :: [Id]
    if (not $ null evalArgs)
        then do
            mkSeqs args evalArgs (\taggedArgs -> StgConApp con cn taggedArgs tys)
        else return $! (StgConApp con cn args tys)
rewriteConApp _ = panic "Impossible"
-- | Rewrite a function application.
--
-- * With no arguments, only the identifier is rewritten (via 'rewriteId').
--
-- * If the function has call-by-value marks ('idCbvMarks_maybe') and at
--   least one of them is 'MarkedCbv', every variable argument in a marked
--   position that 'isArgTagged' reports as not tagged is evaluated before
--   the call, using 'mkSeqs'.
--
-- * Any other application is returned unchanged.
--
-- Panics when applied to any constructor other than 'StgApp'.
rewriteApp :: InferStgExpr -> RM TgStgExpr
rewriteApp (StgApp f []) = do
    f' <- rewriteId f
    return $! StgApp f' []
rewriteApp (StgApp f args)
    | Just marks <- idCbvMarks_maybe f
      -- Trailing not-marked entries carry no information; drop them.
    , let relevant_marks = dropWhileEndLE (not . isMarkedCbv) marks
    , any isMarkedCbv relevant_marks
    = assertPpr (length relevant_marks <= length args)
                (ppr f $$ ppr args $$ ppr relevant_marks) $
      unliftArg relevant_marks
  where
    -- Evaluate every untagged variable argument that sits in a
    -- call-by-value position, then rebuild the application.
    unliftArg :: [CbvMark] -> RM TgStgExpr
    unliftArg relevant_marks = do
        argTags <- mapM isArgTagged args
        let -- Pad the marks with 'NotMarkedCbv' so that every argument
            -- gets paired with a mark.
            argInfo :: [(StgArg, CbvMark, Bool)]
            argInfo = zip3 args (relevant_marks ++ repeat NotMarkedCbv) argTags

            -- Variables in marked positions whose tag query returned False.
            cbvArgIds :: [Id]
            cbvArgIds =
                [ x
                | (StgVarArg x, mark, False) <- argInfo
                , isMarkedCbv mark
                ]
        mkSeqs args cbvArgIds (\cbv_args -> StgApp f cbv_args)
rewriteApp (StgApp f args) = return $ StgApp f args
rewriteApp _ = panic "Impossible"
-- | Rewrite a primitive/foreign operation application.
--
-- Only the two data-to-tag primops ('DataToTagSmallOp' and
-- 'DataToTagLargeOp') have their arguments passed through 'rewriteArgs';
-- every other operation is returned with its arguments untouched.
--
-- Panics when applied to any constructor other than 'StgOpApp'.
rewriteOpApp :: InferStgExpr -> RM TgStgExpr
rewriteOpApp (StgOpApp op args res_ty)
    | StgPrimOp primOp <- op
    , primOp == DataToTagSmallOp || primOp == DataToTagLargeOp
    = StgOpApp op <$!> rewriteArgs args <*> pure res_ty
    | otherwise
    = pure $! StgOpApp op args res_ty
rewriteOpApp _ = panic "Impossible"
-- | @mkSeq x x' e@ builds the expression @case x of x' { DEFAULT -> e }@.
--
-- The scrutinee is the nullary application of the first identifier, the
-- second identifier is the case binder, and the continuation expression
-- (forced to WHNF on entry) becomes the right-hand side of the single
-- @DEFAULT@ alternative. The alternative type is computed from the binder
-- and that alternative by 'mkStgAltTypeFromStgAlts'.
mkSeq :: Id -> Id -> TgStgExpr -> TgStgExpr
mkSeq scrut bndr !rhs =
    StgCase (StgApp scrut []) bndr altTy alts
  where
    alts :: [GenStgAlt 'CodeGen]
    alts = [GenStgAlt { alt_con = DEFAULT, alt_bndrs = [], alt_rhs = rhs }]

    altTy :: AltType
    altTy = mkStgAltTypeFromStgAlts bndr alts
-- | @mkSeqs args vs mkExpr@ evaluates every identifier in @vs@ and then
-- builds the final expression from a modified argument list.
--
-- For each identifier in @vs@ a fresh identifier is obtained from
-- 'mkLocalArgId'. Each variable argument in @args@ that has such a
-- replacement is swapped for it (other arguments are left alone), @mkExpr@
-- is applied to the resulting list, and the result is wrapped in one
-- 'mkSeq' case per identifier, with the first identifier of @vs@ outermost.
{-# INLINE mkSeqs #-}
mkSeqs :: [StgArg]                  -- ^ Original arguments
       -> [Id]                      -- ^ Variable arguments to evaluate first
       -> ([StgArg] -> TgStgExpr)   -- ^ Rebuilds the expression from the
                                    --   replaced argument list
       -> RM TgStgExpr
mkSeqs args untaggedIds mkExpr = do
    -- Pair every identifier to be evaluated with its fresh case binder.
    argMap <- mapM (\arg -> (,) arg <$> mkLocalArgId arg) untaggedIds
    let -- Swap a variable argument for its case binder when it has one.
        retag :: StgArg -> StgArg
        retag (StgVarArg v) = StgVarArg (fromMaybe v (lookup v argMap))
        retag other         = other

        body = foldr (uncurry mkSeq) (mkExpr (map retag args)) argMap
    return $! body
-- | Of the values supplied for a constructor's arguments, keep only those
-- whose position is 'MarkedStrict' in 'dataConRuntimeRepStrictness'.
--
-- Unboxed tuple and unboxed sum constructors always yield the empty list.
-- For every other constructor the number of supplied values is asserted to
-- equal the number of strictness marks ('assertPpr', and again by
-- 'zipEqual').
getStrictConArgs :: Outputable a => DataCon -> [a] -> [a]
getStrictConArgs con args
    | isUnboxedTupleDataCon con = []
    | isUnboxedSumDataCon con = []
    | otherwise =
        assertPpr (length args == length repStrictness)
                  (text "Mismatched con arg and con rep strictness lengths:" $$
                   text "Con" <+> ppr con <+> text "is applied to" <+> ppr args $$
                   text "But seems to have arity" <+> ppr (length repStrictness)) $
        [ arg
        | (arg, MarkedStrict) <- zipEqual "getStrictConArgs" args repStrictness
        ]
  where
    repStrictness :: [StrictnessMark]
    repStrictness = dataConRuntimeRepStrictness con