-- Copyright (c) 2019 Andreas Klebinger
--

{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE TypeFamilies               #-}

module GHC.Stg.EnforceEpt.Rewrite (rewriteTopBinds, rewriteOpApp)
where

import GHC.Prelude

import GHC.Builtin.PrimOps ( PrimOp(..) )
import GHC.Types.Basic     ( CbvMark (..), isMarkedCbv
                           , TopLevelFlag(..), isTopLevel )
import GHC.Types.Id
import GHC.Types.Name
import GHC.Types.Unique.Supply
import GHC.Types.Unique.FM
import GHC.Types.RepType
import GHC.Types.Var.Set
import GHC.Unit.Types

import GHC.Core.DataCon
import GHC.Core            ( AltCon(..) )
import GHC.Core.Type

import GHC.StgToCmm.Types
import GHC.StgToCmm.Closure (importedIdLFInfo)

import GHC.Stg.Utils
import GHC.Stg.Syntax as StgSyn

import GHC.Data.Maybe
import GHC.Utils.Panic

import GHC.Utils.Outputable
import GHC.Utils.Monad.State.Strict
import GHC.Utils.Misc

import GHC.Stg.EnforceEpt.Types

import Control.Monad

-- | The rewrite monad: a strict 'State' computation over a 4-tuple holding,
-- in order:
--
--   * the map from binders to their tag signatures (see 'getMap' / 'setMap'),
--   * the unique supply used by the 'MonadUnique' instance,
--   * the module being rewritten (see 'getMod'),
--   * the set of ids currently tracked as locals (see 'getFVs' / 'setFVs').
newtype RM a = RM
    { unRM :: State (UniqFM Id TagSig, UniqSupply, Module, IdSet) a
    }
    deriving (Functor, Applicative, Monad)

------------------------------------------------------------
-- Add cases around strict fields where required.
------------------------------------------------------------
{-
The work of this pass is simple:
* We traverse the STG AST looking for constructor allocations.
* For all allocations we check if there are strict fields in the constructor.
* For any strict field we check if the argument is known to be properly tagged.
* If it's not known to be properly tagged, we wrap the whole thing in a case,
  which will force the argument before allocation.
This is described in detail in Note [Evaluated and Properly Tagged].

The only slight complication is that we have to make sure not to invalidate free
variable analysis in the process.

Note [Partially applied workers]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Sometimes we will get a function f of the form
    -- Arity 1
    f :: Dict a -> a -> b -> (c -> d)
    f dict a b = case dict of
        C m1 m2 -> m1 a b

Which will result in a W/W split along the lines of
    -- Arity 1
    f :: Dict a -> a -> b -> (c -> d)
    f dict a = case dict of
        C m1 m2 -> $wf m1 a b

    -- Arity 4
    $wf :: (a -> b -> d -> c) -> a -> b -> c -> d
    $wf m1 a b c = m1 a b c

It's notable that the worker is called *undersaturated* in the wrapper.
At runtime what happens is that the wrapper will allocate a PAP which
once fully applied will call the worker. And all is fine.

But what about a call by value function! Well the function returned by `f` would
be an unknown call, so we lose the ability to enforce the invariant that
cbv marked arguments from StrictWorkerIds are actually properly tagged
as the annotations would be unavailable at the (unknown) call site.

The fix is easy. We eta-expand all calls to functions taking call-by-value
arguments during CorePrep just like we do with constructor allocations.

Note [Upholding free variable annotations]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The code generator requires us to maintain exact information
about the free variables of closures. Since we convert some
RHSs from constructor allocations to closures we have to provide
fvs of these closures. Not all constructor arguments will become
free variables. Only those which are not bound at the top level
have to be captured.
To facilitate this we keep track of a set of locally bound variables in
the current context which we then use to filter constructor arguments
when building the free variable list.
-}

--------------------------------
-- Utilities
--------------------------------

-- | Unique supplies are carved off the supply stored in the monad's state:
-- the stored supply is split, one half is handed out and the other half is
-- written back for later requests.
instance MonadUnique RM where
    getUniqueSupplyM = RM $ do
        (tag_map, supply, this_mod, lcls) <- get
        let (handed_out, kept) = splitUniqSupply supply
        put (tag_map, kept, this_mod, lcls)
        return handed_out

-- | Read the binder -> tag signature map (first component of the state).
getMap :: RM (UniqFM Id TagSig)
getMap = RM (tagMapOf <$> get)
  where
    tagMapOf (tag_map, _, _, _) = tag_map

-- | Replace the binder -> tag signature map, leaving the other state
-- components untouched. The new map is forced (bang pattern).
setMap :: (UniqFM Id TagSig) -> RM ()
setMap !new_map = RM $
    get >>= \(_, supply, this_mod, lcls) ->
        put (new_map, supply, this_mod, lcls)

-- | Read the module stored in the state (third component).
getMod :: RM Module
getMod = RM (moduleOf <$> get)
  where
    moduleOf (_, _, this_mod, _) = this_mod

-- | Read the set of ids currently tracked as locals (fourth component of the
-- state). See Note [Upholding free variable annotations].
getFVs :: RM IdSet
getFVs = RM (localsOf <$> get)
  where
    localsOf (_, _, _, lcls) = lcls

-- | Replace the set of tracked local ids, leaving the other state components
-- untouched. The new set is forced (bang pattern).
setFVs :: IdSet -> RM ()
setFVs !new_lcls = RM $
    get >>= \(tag_map, supply, this_mod, _old_lcls) ->
        put (tag_map, supply, this_mod, new_lcls)

-- Run the continuation with the binder(s) of the given binding and their
-- signatures in scope, so that we can determine whether things are
-- tagged/need to be captured as FV.
withBind :: TopLevelFlag -> GenStgBinding 'InferTaggedBinders -> RM a -> RM a
withBind top_flag binding cont =
    case binding of
        StgNonRec bndr _rhs -> withBinder top_flag bndr cont
        -- For a recursive group only the binders matter; the RHSs are ignored.
        StgRec pairs        -> withBinders top_flag (map fst pairs) cont

-- | Permanently add the binder(s) of a top level binding, together with
-- their tag signatures, to the signature map. Unlike 'withBinder' the map is
-- not restored afterwards.
addTopBind :: GenStgBinding 'InferTaggedBinders -> RM ()
addTopBind (StgNonRec (bndr, sig) _rhs) = do
    tag_map <- getMap
    -- pprTraceM "AddBind" (ppr bndr)
    setMap (addToUFM tag_map bndr sig)
addTopBind (StgRec pairs) = do
    let bndrs = map fst pairs
    !tag_map <- getMap
    -- pprTraceM "AddBinds" (ppr $ map fst bndrs)
    setMap $! addListToUFM tag_map bndrs

-- | Run an action with one extra binder in the signature map, restoring the
-- previous map afterwards. A binder that is not top level is additionally
-- tracked as a local (via 'withLcl') for the duration of the action.
withBinder :: TopLevelFlag ->  (Id, TagSig) -> RM a -> RM a
withBinder top_flag (bndr, sig) cont = do
    saved_map <- getMap
    setMap (addToUFM saved_map bndr sig)
    let scoped_cont
            | isTopLevel top_flag = cont
            | otherwise           = withLcl bndr cont
    result <- scoped_cont
    setMap saved_map
    return result

-- | Like 'withBinder', but for a list of binders at once. For binders that
-- are not top level the set of tracked locals is extended as well; both the
-- map and the set are restored once the action has run.
withBinders :: TopLevelFlag -> [(Id, TagSig)] -> RM a -> RM a
withBinders TopLevel sigs cont = do
    saved_map <- getMap
    setMap (addListToUFM saved_map sigs)
    result <- cont
    setMap saved_map
    return result
withBinders NotTopLevel sigs cont = do
    saved_map <- getMap
    saved_fvs <- getFVs
    let bndrs = map fst sigs
    setMap (addListToUFM saved_map sigs)
    setFVs (extendVarSetList saved_fvs bndrs)
    result <- cont
    setMap saved_map
    setFVs saved_fvs
    return result

-- | Run the action with every id of the given set added to the set of
-- tracked locals (ids which require capture as free variables). The previous
-- set is restored afterwards.
withClosureLcls :: DIdSet -> RM a -> RM a
withClosureLcls fvs act = do
    saved_fvs <- getFVs
    let insertId var acc = extendVarSet acc var
        !extended = nonDetStrictFoldDVarSet insertId saved_fvs fvs
    setFVs extended
    !result <- act
    setFVs saved_fvs
    return result

-- | Run the action with a single id added to the set of tracked locals (ids
-- which require capture as free variables in closures). The previous set is
-- restored afterwards.
withLcl :: Id -> RM a -> RM a
withLcl fv act = do
    saved_fvs <- getFVs
    let !extended = saved_fvs `extendVarSet` fv
    setFVs extended
    !result <- act
    setFVs saved_fvs
    return result

{- Note [Tag inference for interactive contexts]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When compiling bytecode we call myCoreToStg to get STG code first.
myCoreToStg in turn calls out to stg2stg which runs the STG to STG
passes followed by free variables analysis and the tag inference pass including
its rewriting phase at the end.
Running tag inference is important as it upholds Note [Evaluated and Properly Tagged].
While code executed by GHCi doesn't take advantage of the SFI it can call into
compiled code which does. So it must still make sure that the SFI is upheld.
See also #21083 and #22042.

However there is one important difference in code generation for GHCi and regular
compilation. When compiling an entire module (not a GHCi expression), we call
`stg2stg` on the entire module which allows us to build up a map which is guaranteed
to have an entry for every binder in the current module.
For non-interactive compilation the tag inference rewrite pass takes advantage
of this by building up a map from binders to their tag signatures.

When compiling a GHCi expression on the other hand we invoke stg2stg separately
for each expression on the prompt. This means in GHCi for a sequence of:
    > let x = True
    > let y = StrictJust x
We first run stg2stg for `[x = True]`. And then again for `[y = StrictJust x]`.

While computing the tag signature for `y` during tag inference inferConTag will check
if `x` is already tagged by looking up the tagsig of `x` in the binder->signature mapping.
However since this mapping isn't persistent between stg2stg
invocations the lookup will fail. This isn't a correctness issue since it's always
safe to assume a binding isn't tagged and that's what we do in such cases.

However for non-interactive mode we *don't* want to do this. Since in non-interactive mode
we have all binders of the module available for each invocation we can expect the binder->signature
mapping to be complete and all lookups to succeed. This means in non-interactive contexts a failed lookup
indicates a bug in the tag inference implementation.
For this reason we assert that we are running in interactive mode if a lookup fails.
-}
-- | Is the given id known to be properly tagged?
--
-- * Dead-end (bottoming) ids are treated as tagged,
--   see Note [Bottom functions are TagTagged].
-- * Ids local to (or from) the current module are tagged if their type is
--   definitely unlifted; otherwise the answer comes from their signature in
--   the binder -> signature map. A missing entry is only legitimate for
--   interactive modules (this is asserted) and is treated as 'TagDunno'.
-- * For imported ids the answer is derived from their 'LambdaFormInfo'.
isTagged :: Id -> RM Bool
isTagged v
    -- See Note [Bottom functions are TagTagged]
    -- This must answer True: answering False here contradicts the Note and
    -- makes rewriteRhs insert evals for bottoming ids used in strict fields.
    | isDeadEndId v = pure True
    | otherwise = do
        this_mod <- getMod
        -- See Note [Tag inference for interactive contexts]
        let lookupDefault var = assertPpr (isInteractiveModule this_mod)
                                          (text "unknown Id:" <> ppr this_mod <+> ppr var)
                                          (TagSig TagDunno)
        case nameIsLocalOrFrom this_mod (idName v) of
            True
                | definitelyUnliftedType (idType v)
                  -- NB: v might be the Id of a representation-polymorphic join point,
                  -- so we shouldn't use isUnliftedType here. See T22212.
                -> return True
                | otherwise -> do -- Local binding
                    !s <- getMap
                    let !sig = lookupWithDefaultUFM s (lookupDefault v) v
                    return $ case sig of
                        TagSig info ->
                            case info of
                                TagDunno -> False
                                TagProper -> True
                                TagTagged -> True
                                TagTuple _ -> True -- Consider unboxed tuples tagged.
            -- Imported
            False -> return $!
                    -- Determine whether it is tagged from the LFInfo of the imported id.
                    -- See Note [The LFInfo of Imported Ids]
                    case importedIdLFInfo v of
                        -- Function, applied not entered.
                        LFReEntrant {}
                            -> True
                        -- Thunks need to be entered.
                        LFThunk {}
                            -> False
                        -- LFCon means we already know the tag, and it's tagged.
                        LFCon {}
                            -> True
                        LFUnknown {}
                            -> False
                        LFUnlifted {}
                            -> True
                        LFLetNoEscape {}
                        -- Shouldn't be possible. I don't think we can export letNoEscapes
                            -> True


-- | Is the argument known to be properly tagged? Literals always count as
-- tagged; variables are looked up with 'isTagged'.
isArgTagged :: StgArg -> RM Bool
isArgTagged arg =
    case arg of
        StgLitArg _ -> return True
        StgVarArg v -> isTagged v

-- | Make a localised copy of the given id carrying a fresh unique.
mkLocalArgId :: Id -> RM Id
mkLocalArgId id = do
    !fresh <- getUniqueM
    let localised = localiseId id
    return $! setIdUnique localised fresh

---------------------------
-- Actual rewrite pass
---------------------------


-- | Entry point of the pass: rewrite all top level bindings of a module.
--
-- The monad is started with an empty signature map, the given unique supply,
-- the given module and an empty set of tracked locals.
rewriteTopBinds :: Module -> UniqSupply -> [GenStgTopBinding 'InferTaggedBinders] -> [TgStgTopBinding]
rewriteTopBinds mod us binds = evalState (unRM rewriteAll) initialState
  where
    rewriteAll :: RM [TgStgTopBinding]
    rewriteAll = mapM rewriteTop binds

    initialState = (mempty, us, mod, mempty)

-- | Rewrite a single top level binding. String literals are passed through;
-- lifted bindings are first registered in the signature map and then have
-- their right hand sides rewritten.
rewriteTop :: InferStgTopBinding -> RM TgStgTopBinding
rewriteTop top_bind =
    case top_bind of
        StgTopStringLit v s -> return $! StgTopStringLit v s
        StgTopLifted bind -> do
            -- Top level bindings can, and must remain in scope
            addTopBind bind
            bind' <- rewriteBinds TopLevel bind
            return $! StgTopLifted bind'

-- Rewrite the right hand side(s) of a binding, dropping the tag signatures
-- from the binders in the result.
-- For top level binds, the wrapper is guaranteed to be `id`
rewriteBinds :: TopLevelFlag -> InferStgBinding -> RM (TgStgBinding)
rewriteBinds _top_flag (StgNonRec bndr rhs) = do
    !rhs' <- rewriteRhs bndr rhs
    return $! StgNonRec (fst bndr) rhs'
rewriteBinds top_flag bind@(StgRec pairs) =
    -- Bring sigs of binds into scope for all rhss
    withBind top_flag bind $ do
        rhss' <- mapM (uncurry rewriteRhs) pairs
        return $! StgRec (zip bndrs rhss')
  where
    -- The plain ids of the group, without their tag signatures.
    bndrs = map (fst . fst) pairs

-- Rewrite a RHS
--
-- Constructor RHSs whose strict fields are all known to be tagged are kept
-- as they are. Otherwise the RHS is turned into a closure which evaluates
-- the offending arguments before allocating the constructor.
-- Closure RHSs just have their body rewritten with their arguments and free
-- variables in scope.
rewriteRhs :: (Id,TagSig) -> InferStgRhs
           -> RM (TgStgRhs)
rewriteRhs (_id, _tagSig) (StgRhsCon ccs con cn ticks args typ) = {-# SCC rewriteRhs_ #-} do
    -- pprTraceM "rewriteRhs" (ppr _id)

    -- Determine for each constructor argument whether it is known to be tagged.
    argTagged <- mapM isArgTagged args

    let -- Only strict fields matter: (argument, is it tagged?)
        strictFields :: [(StgArg, Bool)]
        strictFields = getStrictConArgs con (zip args argTagged)

        -- Variables in strict fields which are not known to be tagged.
        evalArgs :: [Id]
        evalArgs = [ v | (StgVarArg v, False) <- strictFields ]

    case evalArgs of
        [] -> return $! StgRhsCon ccs con cn ticks args typ
        _  -> do
            --assert not (isTaggedSig tagSig)
            -- pprTraceM "CreatingSeqs for " $ ppr _id <+> ppr node_id

            -- We have possibly untagged arguments to strict fields, so we
            -- convert the RHS into a RhsClosure which will evaluate the
            -- arguments before allocating the constructor.
            let ty_stub = panic "mkSeqs shouldn't use the type arg"
                mkCon taggedArgs = StgConApp con cn taggedArgs ty_stub
            conExpr <- mkSeqs args evalArgs mkCon

            fvs <- fvArgs args
            -- lcls <- getFVs
            -- pprTraceM "RhsClosureConversion" (ppr (StgRhsClosure fvs ccs ReEntrant [] $! conExpr) $$ text "lcls:" <> ppr lcls)

            -- We mark the closure updatable to retain sharing in the case that
            -- conExpr is an infinite recursive data type. See #23783.
            return $! (StgRhsClosure fvs ccs Updatable [] $! conExpr) typ
rewriteRhs _binding (StgRhsClosure fvs ccs flag args body typ) =
    withBinders NotTopLevel args $
        withClosureLcls fvs $ do
            body' <- rewriteExpr body
            return (StgRhsClosure fvs ccs flag (map fst args) body' typ)

-- | The variable arguments which are members of the set of tracked locals,
-- i.e. the ones a closure built from these arguments has to capture.
-- See Note [Upholding free variable annotations].
fvArgs :: [StgArg] -> RM DVarSet
fvArgs args = do
    lcls <- getFVs
    -- pprTraceM "fvArgs" (text "args:" <> ppr args $$ text "lcls:" <> pprVarSet (lcls) (braces . fsep . map ppr) )
    let isTrackedLocal v = v `elemVarSet` lcls
        captured = [ v | StgVarArg v <- args, isTrackedLocal v ]
    return (mkDVarSet captured)

-- | Rewrite every argument of a list with 'rewriteArg', left to right.
rewriteArgs :: [StgArg] -> RM [StgArg]
rewriteArgs args = mapM rewriteArg args
-- | Rewrite a single argument: variables go through 'rewriteId', literals
-- are returned unchanged.
rewriteArg :: StgArg -> RM StgArg
rewriteArg arg =
    case arg of
        StgVarArg v -> do
            v' <- rewriteId v
            return $! StgVarArg v'
        StgLitArg{} -> return arg

-- | If 'isTagged' reports the id as tagged, attach a @TagSig TagProper@
-- signature to this occurrence; otherwise return the id unchanged.
rewriteId :: Id -> RM Id
rewriteId v = do
    !known_tagged <- isTagged v
    case known_tagged of
        True  -> return $! setIdTagSig v (TagSig TagProper)
        False -> return v

-- | Rewrite an expression from the 'InferTaggedBinders pass to the 'CodeGen
-- pass by dispatching on its constructor. Ticks are traversed, literals are
-- rebuilt as-is and primop applications get their arguments rewritten with
-- 'rewriteArgs'; everything else is delegated to a dedicated worker.
rewriteExpr :: GenStgExpr 'InferTaggedBinders -> RM (GenStgExpr 'CodeGen)
rewriteExpr expr = case expr of
    StgCase {}              -> rewriteCase expr
    StgLet {}               -> rewriteLet expr
    StgLetNoEscape {}       -> rewriteLetNoEscape expr
    StgTick t e             -> StgTick t <$!> rewriteExpr e
    StgConApp {}            -> rewriteConApp expr
    StgApp {}               -> rewriteApp expr
    StgLit lit              -> return $! StgLit lit
    StgOpApp op args res_ty -> StgOpApp op <$!> rewriteArgs args <*> pure res_ty


-- | Rewrite a case expression. Under @withBinder NotTopLevel bndr@ the
-- scrutinee is rewritten first, then the case binder is passed through
-- 'rewriteId', and finally every alternative is rewritten with 'rewriteAlt'.
-- Panics when given anything other than an 'StgCase'.
rewriteCase :: InferStgExpr -> RM TgStgExpr
rewriteCase expr = case expr of
    StgCase scrut bndr alt_type alts ->
        withBinder NotTopLevel bndr $ do
            scrut' <- rewriteExpr scrut
            bndr'  <- rewriteId (fst bndr)
            alts'  <- mapM rewriteAlt alts
            pure (StgCase scrut' bndr' alt_type alts')
    _ -> panic "Impossible: nodeCase"

-- | Rewrite one case alternative. The right-hand side is rewritten under
-- @withBinders NotTopLevel@ for the alternative's binders; the result keeps
-- the constructor and drops the tag signatures paired with the binders.
rewriteAlt :: InferStgAlt -> RM TgStgAlt
rewriteAlt GenStgAlt{alt_con = con, alt_bndrs = bndrs, alt_rhs = rhs} =
    withBinders NotTopLevel bndrs $ do
        !rhs' <- rewriteExpr rhs
        return $! GenStgAlt { alt_con   = con
                            , alt_bndrs = map fst bndrs
                            , alt_rhs   = rhs'
                            }

-- | Rewrite a let expression: the binding is rewritten with 'rewriteBinds'
-- first, then the body is rewritten under @withBind NotTopLevel bind@.
-- Panics when given anything other than an 'StgLet'.
rewriteLet :: InferStgExpr -> RM TgStgExpr
rewriteLet e = case e of
    StgLet xt bind body -> do
        !bind' <- rewriteBinds NotTopLevel bind
        withBind NotTopLevel bind $ do
            -- pprTraceM "withBindLet" (ppr $ bindersOfX bind)
            !body' <- rewriteExpr body
            return $! StgLet xt bind' body'
    _ -> panic "Impossible"

-- | Rewrite a let-no-escape expression: the binding is rewritten with
-- 'rewriteBinds' first, then the body is rewritten under
-- @withBind NotTopLevel bind@.
-- Panics when given anything other than an 'StgLetNoEscape'.
rewriteLetNoEscape :: InferStgExpr -> RM TgStgExpr
rewriteLetNoEscape e = case e of
    StgLetNoEscape xt bind body -> do
        !bind' <- rewriteBinds NotTopLevel bind
        withBind NotTopLevel bind $ do
            !body' <- rewriteExpr body
            return $! StgLetNoEscape xt bind' body'
    _ -> panic "Impossible"

-- | Rewrite a saturated constructor application.
--
-- We check if the strict field arguments are already known to be tagged
-- (per 'isArgTagged'). Variable arguments in strict fields that are not
-- known to be tagged are evaluated first via 'mkSeqs'; if there are none the
-- application is rebuilt unchanged.
-- Panics when given anything other than an 'StgConApp'.
rewriteConApp :: InferStgExpr -> RM TgStgExpr
rewriteConApp (StgConApp con cn args tys) = do
    tagged_flags <- mapM isArgTagged args
    let -- (is tagged, argument) pairs for the strict fields only
        strict_fields = getStrictConArgs con (zip tagged_flags args) :: [(Bool, StgArg)]
        -- untagged variables that end up in a strict field
        evalArgs      = [ v | (False, StgVarArg v) <- strict_fields ] :: [Id]
    case evalArgs of
        [] -> return $! StgConApp con cn args tys
        _  ->
            -- pprTraceM "Creating conAppSeqs for " $ ppr nodeId <+> parens ( ppr evalArgs ) -- <+> parens ( ppr fieldInfos )
            mkSeqs args evalArgs (\taggedArgs -> StgConApp con cn taggedArgs tys)
rewriteConApp _ = panic "Impossible"

-- | Rewrite a function application.
--
-- * Special case: atomic binders (no arguments), usually in a case context
--   like @case f of ...@. Only the head is passed through 'rewriteId'.
-- * If 'idCbvMarks_maybe' yields marks for the function and, after dropping
--   trailing unmarked positions, at least one is 'MarkedCbv', the variable
--   arguments in marked positions that 'isArgTagged' does not report as tagged
--   are evaluated first via 'mkSeqs'.
-- * Any other application is rebuilt unchanged.
--
-- Panics when given anything other than an 'StgApp'.
rewriteApp :: InferStgExpr -> RM TgStgExpr
rewriteApp (StgApp f []) = (\f' -> StgApp f' []) <$!> rewriteId f
rewriteApp (StgApp f args)
    | Just marks <- idCbvMarks_maybe f
    , let cbv_marks = dropWhileEndLE (not . isMarkedCbv) marks
    , any isMarkedCbv cbv_marks
    = assertPpr (length cbv_marks <= length args)
                (ppr f $$ ppr args $$ ppr cbv_marks)
                (evalCbvArgs cbv_marks)
    where
      -- If the function expects any argument to be call-by-value ensure the
      -- argument is already evaluated.
      evalCbvArgs :: [CbvMark] -> RM TgStgExpr
      evalCbvArgs cbv_marks = do
        arg_tags <- mapM isArgTagged args
        let -- arguments beyond the last relevant mark count as not marked
            padded_marks = cbv_marks ++ repeat NotMarkedCbv
            arg_info     = zipWith3 (,,) args padded_marks arg_tags :: [(StgArg, CbvMark, Bool)]
            -- variables in untagged cbv argument positions
            cbv_arg_ids  = [ v | (StgVarArg v, MarkedCbv, False) <- arg_info ] :: [Id]
        mkSeqs args cbv_arg_ids (StgApp f)

rewriteApp (StgApp f args) = return $ StgApp f args
rewriteApp _ = panic "Impossible"

{-
Note [Rewriting primop arguments]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Given an application `op# x y`, is it worth applying `rewriteArg` to
`x` and `y`?  All that does is set the `tagSig` for that
occurrence of `x` and `y` to record whether it is evaluated and
properly tagged. For the vast majority of primops that's a waste of
time: the argument is an `Int#` or something.

But code generation for `seq#` and the `dataToTag#` ops /does/ consult that
tag, to statically avoid generating an eval.  All three do so via `cgIdApp`,
which in turn uses `getCallMethod` which looks at the `tagSig`.

So for these we should call `rewriteArgs`.

-}

-- | Rewrite a primop or foreign-call application. The arguments are passed
-- through 'rewriteArgs' only for the two @DataToTag@ primops
-- (see Note [Rewriting primop arguments]); every other application is rebuilt
-- with its arguments untouched.
-- Panics when given anything other than an 'StgOpApp'.
rewriteOpApp :: InferStgExpr -> RM TgStgExpr
rewriteOpApp (StgOpApp op args res_ty)
  | StgPrimOp primOp <- op
  , primOp == DataToTagSmallOp || primOp == DataToTagLargeOp
  -- see Note [Rewriting primop arguments]
  = StgOpApp op <$!> rewriteArgs args <*> pure res_ty
  | otherwise
  = pure $! StgOpApp op args res_ty
rewriteOpApp _ = panic "Impossible"

-- | @mkSeq x x' e@ generates @case x of x' { DEFAULT -> e }@.
--
-- We could also substitute x' for x in e but that's so rarely beneficial
-- that we don't bother. The body expression is forced on entry.
mkSeq :: Id -> Id -> TgStgExpr -> TgStgExpr
mkSeq var bndr !body =
    -- pprTrace "mkSeq" (ppr (id,bndr)) $
    StgCase scrut bndr alt_ty alts
  where
    scrut  = StgApp var []
    alts   = [GenStgAlt { alt_con = DEFAULT, alt_bndrs = [], alt_rhs = body }]
    alt_ty = mkStgAltTypeFromStgAlts bndr alts

-- `mkSeqs args vs mkExpr` will force all vs, and construct
-- an argument list args' where each v is replaced by it's evaluated
-- counterpart v'.
-- That is if we call `mkSeqs [StgVar x, StgLit l] [x] mkExpr` then
-- the result will be (case x of x' { _DEFAULT -> <mkExpr [StgVar x', StgLit l]>}
{-# INLINE mkSeqs #-} -- We inline to avoid allocating mkExpr
mkSeqs  :: [StgArg] -- ^ Original arguments
        -> [Id]     -- ^ var args to be evaluated ahead of time
        -> ([StgArg] -> TgStgExpr)
                    -- ^ Function that reconstructs the expressions when passed
                    -- the list of evaluated arguments.
        -> RM TgStgExpr
mkSeqs :: [StgArg] -> [Id] -> ([StgArg] -> TgStgExpr) -> RM TgStgExpr
mkSeqs [StgArg]
args [Id]
untaggedIds [StgArg] -> TgStgExpr
mkExpr = do
    argMap <- (Id -> RM (Id, Id)) -> [Id] -> RM [(Id, Id)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (\Id
arg -> (Id
arg,) (Id -> (Id, Id)) -> RM Id -> RM (Id, Id)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Id -> RM Id
mkLocalArgId Id
arg ) [Id]
untaggedIds :: RM [(InId, OutId)]
    -- mapM_ (pprTraceM "Forcing strict args before allocation:" . ppr) argMap
    let taggedArgs :: [StgArg]
            = map   (\StgArg
v -> case StgArg
v of
                        StgVarArg Id
v' -> Id -> StgArg
StgVarArg (Id -> StgArg) -> Id -> StgArg
forall a b. (a -> b) -> a -> b
$ Id -> Maybe Id -> Id
forall a. a -> Maybe a -> a
fromMaybe Id
v' (Maybe Id -> Id) -> Maybe Id -> Id
forall a b. (a -> b) -> a -> b
$ Id -> [(Id, Id)] -> Maybe Id
forall a b. Eq a => a -> [(a, b)] -> Maybe b
lookup Id
v' [(Id, Id)]
argMap
                        StgArg
lit -> StgArg
lit)
                    args

    let conBody = [StgArg] -> TgStgExpr
mkExpr [StgArg]
taggedArgs
    let body = ((Id, Id) -> TgStgExpr -> TgStgExpr)
-> TgStgExpr -> [(Id, Id)] -> TgStgExpr
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (\(Id
v,Id
bndr) TgStgExpr
expr -> Id -> Id -> TgStgExpr -> TgStgExpr
mkSeq Id
v Id
bndr TgStgExpr
expr) TgStgExpr
conBody [(Id, Id)]
argMap
    return $! body

-- | Out of all arguments passed at runtime only return those ending up in a
-- strict field.
--
-- Unboxed tuple and unboxed sum constructors always yield the empty list.
-- For every other constructor the arguments are matched positionally against
-- 'dataConRuntimeRepStrictness' and those in 'MarkedStrict' positions are
-- kept. In assertion-enabled builds a length mismatch between the arguments
-- and the strictness marks is reported via 'assertPpr'.
getStrictConArgs :: Outputable a => DataCon -> [a] -> [a]
getStrictConArgs con args
    -- These are always lazy in their arguments.
    | isUnboxedTupleDataCon con = []
    | isUnboxedSumDataCon con = []
    -- For proper data cons we have to check.
    | otherwise =
        assertPpr   (length args == length repStrictness)
                    (text "Mismatched con arg and con rep strictness lengths:" $$
                     text "Con" <+> ppr con <+> text "is applied to" <+> ppr args $$
                     text "But seems to have arity" <+> ppr (length repStrictness)) $
        [ arg | (arg, MarkedStrict) <- zipEqual "getStrictConArgs" args repStrictness ]
        where
            -- One strictness mark per runtime argument of the constructor.
            repStrictness = dataConRuntimeRepStrictness con