{-# LANGUAGE RecordWildCards #-}

-- | Adds cost centres after the Core pipeline has run.
module GHC.Core.LateCC
    ( -- * Inserting cost centres
      addLateCostCenters
    ) where

import GHC.Prelude

import GHC.Core
import GHC.Core.LateCC.OverloadedCalls
import GHC.Core.LateCC.TopLevelBinds
import GHC.Core.LateCC.Types
import GHC.Core.LateCC.Utils
import GHC.Core.Seq
import qualified GHC.Data.Strict as Strict
import GHC.Core.Utils
import GHC.Tc.Utils.TcType
import GHC.Types.SrcLoc
import GHC.Utils.Error
import GHC.Utils.Logger
import GHC.Utils.Outputable
import GHC.Types.RepType (mightBeFunTy)

-- | Late cost centre insertion logic used by the driver.
--
-- Runs up to two passes over the program:
--
--   1. Unless 'lateCCConfig_whichBinds' is 'LateCCNone', a pass that adds
--      cost centres to top-level bindings. A binding is selected when its
--      right-hand side might be a function or is not work-free
--      ('LateCCBinds'), or when its type is overloaded
--      ('LateCCOverloadedBinds').
--
--   2. If 'lateCCConfig_overloadedCalls' is set, a pass that adds cost
--      centres via 'overloadedCallsCC', starting from the state left by the
--      first pass.
--
-- Each pass that runs is wrapped in 'withTiming'. When the second pass is
-- disabled, the returned state is the one from the first pass with its extra
-- field set to 'Strict.Nothing'.
addLateCostCenters ::
     Logger
  -- ^ Logger used for pass timing
  -> LateCCConfig
  -- ^ Late cost centre configuration
  -> CoreProgram
  -- ^ The program to annotate
  -> IO (CoreProgram, LateCCState (Strict.Maybe SrcSpan))
  -- ^ The annotated program and the final cost centre state
addLateCostCenters logger LateCCConfig{..} core_binds = do

    -- If top-level late CCs are enabled via either -fprof-late or
    -- -fprof-late-overloaded, add them
    (top_level_cc_binds, top_level_late_cc_state) <-
      case lateCCConfig_whichBinds of
        LateCCNone ->
          pure (core_binds, initLateCCState ())
        _ ->
          withTiming
            logger
            (text "LateTopLevelCCs" <+> brackets (ppr this_mod))
            force_result
            $ {-# SCC lateTopLevelCCs #-} do
              pure $
                doLateCostCenters
                  lateCCConfig_env
                  (initLateCCState ())
                  (topLevelBindsCC top_level_cc_pred)
                  core_binds

    -- The top-level pass carries '()' as its extra state, whereas the
    -- overloaded-calls pass (and our result type) uses a
    -- 'Strict.Maybe SrcSpan'. Start that field out empty; every other field
    -- of the first pass's state is kept as-is.
    let overloaded_calls_init_state =
          top_level_late_cc_state { lateCCState_extra = Strict.Nothing }

    -- If overloaded call CCs are enabled via -fprof-late-overloaded-calls, add
    -- them
    if lateCCConfig_overloadedCalls
      then
        withTiming
          logger
          (text "LateOverloadedCallsCCs" <+> brackets (ppr this_mod))
          force_result
          $ {-# SCC lateoverloadedCallsCCs #-} do
            pure $
              doLateCostCenters
                lateCCConfig_env
                overloaded_calls_init_state
                overloadedCallsCC
                top_level_cc_binds
      else
        pure (top_level_cc_binds, overloaded_calls_init_state)
  where
    -- Forcing function handed to 'withTiming' for both passes. It forces the
    -- bindings with 'seqBinds' and the state only to WHNF with 'seq'.
    force_result :: (CoreProgram, LateCCState s) -> ()
    force_result (binds, st) = seqBinds binds `seq` st `seq` ()

    -- Selects which top-level right-hand sides get a cost centre, according
    -- to 'lateCCConfig_whichBinds'.
    top_level_cc_pred :: CoreExpr -> Bool
    top_level_cc_pred =
        case lateCCConfig_whichBinds of
          LateCCBinds -> \rhs ->
            -- Make sure we record any functions. Even if it's something like `f = g`.
            mightBeFunTy (exprType rhs) ||
            -- If the RHS is a CAF doing work also insert a CC.
            not (exprIsWorkFree rhs)
          LateCCOverloadedBinds ->
            isOverloadedTy . exprType
          LateCCNone ->
            -- This is here for completeness, we won't actually use this
            -- predicate in this case since we'll shortcut.
            const False

    -- The module being compiled, used only to label the timing output.
    this_mod = lateCCEnv_module lateCCConfig_env