{-# LANGUAGE CPP #-}
module GHC.StgToCmm.TagCheck
( emitTagAssertion, emitArgTagCheck, checkArg, whenCheckTags,
checkArgStatic, checkFunctionArgTags,checkConArgsStatic,checkConArgsDyn) where
#include "ClosureTypes.h"
import GHC.Prelude
import GHC.StgToCmm.Env
import GHC.StgToCmm.Monad
import GHC.StgToCmm.Utils
import GHC.Cmm
import GHC.Cmm.BlockId
import GHC.Cmm.Graph as CmmGraph
import GHC.Core.Type
import GHC.Types.Id
import GHC.Utils.Misc
import GHC.Utils.Outputable
import GHC.Core.DataCon
import Control.Monad
import GHC.StgToCmm.Types
import GHC.Utils.Panic (pprPanic, panic)
import GHC.Stg.Syntax
import GHC.StgToCmm.Closure
import GHC.Cmm.Switch (mkSwitchTargets)
import GHC.Cmm.Info (cmmGetClosureType)
import GHC.Types.RepType (dataConRuntimeRepStrictness)
import GHC.Types.Basic
import GHC.Data.FastString (mkFastString)
import qualified Data.Map as M
-- | Emit a /runtime/ tag check for every argument of @f@ that is marked as
-- call-by-value (according to 'idCbvMarks_maybe') and has a boxed type.
--
-- Does nothing when tag checking is disabled (see 'whenCheckTags') or when
-- @f@ carries no CBV marks.
--
--   * @msg@  - context prepended to the failure message of each check.
--   * @f@    - the function whose CBV marks select the arguments to check.
--   * @args@ - the arguments, selected against the marks via 'filterByList'.
checkFunctionArgTags :: SDoc -> Id -> [Id] -> FCode ()
checkFunctionArgTags msg f args = whenCheckTags $
  onJust (return ()) (idCbvMarks_maybe f) $ \marks -> do
    -- Only arguments that are both marked CBV and of boxed type are checked.
    let cbv_args = filter (isBoxedType . idType)
                 $ filterByList (map isMarkedCbv marks) args
    -- The Cmm expression for each selected argument.
    arg_cmms <- map idInfoToAmode <$> mapM getCgIdInfo cbv_args
    zipWithM_
      (\arg cmm -> emitTagAssertion (showPprUnsafe $ msg <+> ppr arg) cmm)
      cbv_args
      arg_cmms
-- | Check at /compile time/ that every argument of constructor @con@ that sits
-- in a strict field position is known to be tagged. Each argument is handed
-- to 'checkArgStatic' together with the strictness mark at the same position
-- of 'dataConRuntimeRepStrictness'. A no-op unless tag checking is enabled.
checkConArgsStatic :: SDoc -> DataCon -> [StgArg] -> FCode ()
checkConArgsStatic msg con args = whenCheckTags $ do
  let marks :: [StrictnessMark]
      marks = dataConRuntimeRepStrictness con
  zipWithM_ (checkArgStatic msg) marks args
-- | Emit /runtime/ tag checks for the arguments of constructor @con@.
--
-- The strictness marks of the constructor are converted to 'CbvMark's with
-- 'cbvFromStrictMark'. Each one is then handed to 'checkArg' together with
-- the argument at the same position. A no-op unless tag checking is enabled.
checkConArgsDyn :: SDoc -> DataCon -> [StgArg] -> FCode ()
checkConArgsDyn msg con args = whenCheckTags $ do
  let marks :: [StrictnessMark]
      marks = dataConRuntimeRepStrictness con
  zipWithM_ (checkArg msg) (map cbvFromStrictMark marks) args
-- | Run the given code-generation action only when tag checking is enabled,
-- as reported by 'stgToCmmDoTagCheck' of the current 'StgToCmmConfig'.
whenCheckTags :: FCode () -> FCode ()
whenCheckTags act = do
  check_tags <- stgToCmmDoTagCheck <$> getStgToCmmConfig
  when check_tags act
-- | Emit Cmm that aborts the program via 'emitBarf' when the closure pointer
-- @fun@ fails the tag check.
--
-- The generated code passes if either
--
--   * the pointer carries a tag ('cmmIsTagged'), or
--   * the closure type is one of those accepted by 'needsArgTag'
--     (PAP, BCO or one of the FUN closure types).
--
-- @onWhat@ is appended to the barf message to identify the failing check.
emitTagAssertion :: String -> CmmExpr -> FCode ()
emitTagAssertion onWhat fun = do
  { platform <- getPlatform
  ; lret <- newBlockId
  ; lno_tag <- newBlockId
  ; lbarf <- newBlockId
    -- Any tag present: the check succeeds immediately.
    -- (Just True is presumably a likelihood hint for the tagged branch.)
  ; emit $ mkCbranch (cmmIsTagged platform fun)
                     lret lno_tag (Just True)
    -- No tag: fall back to inspecting the closure type.
  ; emitLabel lno_tag
  ; emitComment (mkFastString "closereTypeCheck")
  ; needsArgTag fun lbarf lret
    -- Failure path: abort with a diagnostic.
  ; emitLabel lbarf
  ; emitBarf ("Tag inference failed on:" ++ onWhat)
  ; emitLabel lret
  }
-- | Switch on the closure type of @closure@.
--
-- PAP, BCO and FUN-family closures jump to @lpass@. Every other closure type
-- jumps to @lfail@, i.e. the first label is taken for closures that were
-- required to carry a tag.
needsArgTag :: CmmExpr -> BlockId -> BlockId -> FCode ()
needsArgTag closure lfail lpass = do
  profile <- getProfile
  align_check <- stgToCmmAlignCheck <$> getStgToCmmConfig
  let clo_ty_e = cmmGetClosureType profile align_check closure

      -- Closure types that are acceptable without a pointer tag.
      untagged_ok :: [Integer]
      untagged_ok = [ PAP, BCO
                    , FUN, FUN_1_0, FUN_0_1, FUN_2_0, FUN_1_1, FUN_0_2
                    , FUN_STATIC
                    ]

      targets = mkSwitchTargets
        False
        (INVALID_OBJECT, N_CLOSURE_TYPES)
        (Just lfail)  -- any closure type not listed fails the check
        (M.fromList [ (ty, lpass) | ty <- untagged_ok ])

  emit $ mkSwitch clo_ty_e targets
  -- NOTE(review): the switch has a default target, so this branch appears
  -- unreachable; it is kept so behaviour stays unchanged.
  emit $ mkBranch lpass
-- | Emit a /runtime/ tag check for every argument in @args@ that is marked
-- CBV by the corresponding entry of @marks@ and has a boxed type.
--
-- Each failure message names the current module, the caller-supplied
-- context @info@ and the offending argument. A no-op unless tag checking is
-- enabled.
emitArgTagCheck :: SDoc -> [CbvMark] -> [Id] -> FCode ()
emitArgTagCheck info marks args = whenCheckTags $ do
  this_mod <- getModuleName
  let cbv_args = filter (isBoxedType . idType)
               $ filterByList (map isMarkedCbv marks) args
      mk_msg arg = showPprUnsafe
        (text "Untagged arg:" <> (ppr this_mod) <> char ':' <> info <+> ppr arg)
  forM_ cbv_args $ \arg -> do
    cginfo <- getCgIdInfo arg
    emitTagAssertion (mk_msg arg) (idInfoToAmode cginfo)
-- | Whether a binding is statically known to be tagged, judged only by its
-- 'LambdaFormInfo'.
--
-- Constructors, re-entrant closures and unlifted values count as tagged.
-- Thunks and unknown closures do not. Panics on let-no-escape bindings.
taggedCgInfo :: CgIdInfo -> Bool
taggedCgInfo cg_info = case cg_lf cg_info of
  LFCon {}       -> True
  LFReEntrant {} -> True
  LFUnlifted {}  -> True
  LFThunk {}     -> False
  LFUnknown {}   -> False
  LFLetNoEscape  -> panic "Let no escape binding passed to top level con"
-- | Emit a /runtime/ tag check for a single argument if it is marked CBV.
--
-- Literals need no check. Variables already known to be tagged (see
-- 'taggedCgInfo') need no check either. A variable bound to a let-no-escape
-- location causes a panic. A no-op unless tag checking is enabled.
checkArg :: SDoc -> CbvMark -> StgArg -> FCode ()
checkArg _   NotMarkedCbv _   = return ()
checkArg msg MarkedCbv    arg = whenCheckTags $
  case arg of
    StgLitArg _ -> return ()
    StgVarArg v -> do
      info <- getCgIdInfo v
      unless (taggedCgInfo info) $
        case cg_loc info of
          CmmLoc loc ->
            emitTagAssertion
              (showPprUnsafe $ msg <+> text "arg:" <> ppr arg)
              loc
          LneLoc {} -> panic "LNE-arg"
-- | Check at /compile time/ that an argument in a strict position is known to
-- be tagged: either it is a literal or it satisfies 'taggedCgInfo'.
-- Otherwise, panic with @msg@ and the argument.
--
-- Arguments in non-strict positions are not checked. A no-op unless tag
-- checking is enabled.
checkArgStatic :: SDoc -> StrictnessMark -> StgArg -> FCode ()
checkArgStatic _   NotMarkedStrict _   = return ()
checkArgStatic msg MarkedStrict    arg = whenCheckTags $
  case arg of
    StgLitArg _ -> return ()
    StgVarArg v -> do
      info <- getCgIdInfo v
      unless (taggedCgInfo info) $
        pprPanic "Arg not tagged as expected" (ppr msg <+> ppr arg)