{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE LambdaCase #-}

-- | JS codegen state monad
module GHC.StgToJS.Monad
  ( runG
  , emitGlobal
  , addDependency
  , emitToplevel
  , emitStatic
  , emitClosureInfo
  , emitForeign
  , assertRtsStat
  , getSettings
  , globalOccs
  , setGlobalIdCache
  , getGlobalIdCache
  , GlobalOcc(..)
  -- * Group
  , modifyGroup
  , resetGroup
  )
where

import GHC.Prelude

import GHC.JS.JStg.Syntax
import GHC.JS.Ident
import GHC.JS.Transform

import GHC.StgToJS.Types

import GHC.Unit.Module
import GHC.Utils.Outputable
import GHC.Stg.Syntax

import GHC.Types.SrcLoc
import GHC.Types.Id
import GHC.Types.Unique.FM
import GHC.Types.ForeignCall

import qualified Control.Monad.Trans.State.Strict as State
import GHC.Data.FastString
import GHC.Data.FastMutInt

import qualified Data.Map  as M
import qualified Data.Set  as S

-- | Run a code generator action for module @m@ and return its result.
--
-- A fresh 'GenState' is built with 'initState' from the config, the module
-- and the given 'UniqFM' (stored as 'gsUnfloated'), then @action@ is
-- evaluated against it. The final state is discarded.
runG :: StgToJSConfig -> Module -> UniqFM Id CgStgExpr -> G a -> IO a
runG config m unfloat action = do
  st0 <- initState config m unfloat
  State.evalStateT action st0

-- | Build the initial code generator state.
--
-- The mutable counter stored in 'gsId' starts at 1. The ident cache and the
-- global statement list start empty, and the group state starts as
-- 'defaultGenGroupState'.
initState :: StgToJSConfig -> Module -> UniqFM Id CgStgExpr -> IO GenState
initState config m unfloat = mkState <$> newFastMutInt 1
  where
    mkState idGen = GenState
      { gsSettings  = config
      , gsModule    = m
      , gsId        = idGen
      , gsIdents    = emptyIdCache
      , gsUnfloated = unfloat
      , gsGroup     = defaultGenGroupState
      , gsGlobal    = []
      }


-- | Apply a function to the state of the current binding group
-- ('gsGroup'), leaving the rest of the 'GenState' unchanged.
modifyGroup :: (GenGroupState -> GenGroupState) -> G ()
modifyGroup f = State.modify $ \st -> st { gsGroup = f (gsGroup st) }

-- | Emit a global (for the current module) toplevel statement.
--
-- Statements are prepended to 'gsGlobal', so that list holds them in
-- reverse emission order.
emitGlobal :: JStgStat -> G ()
emitGlobal stat = State.modify pushStat
  where
    pushStat st = st { gsGlobal = stat : gsGlobal st }

-- | Add a dependency on a particular symbol to the current group
-- (inserted into the group's 'ggsExtraDeps' set).
addDependency :: OtherSymb -> G ()
addDependency symbol = modifyGroup $ \grp ->
  grp { ggsExtraDeps = S.insert symbol (ggsExtraDeps grp) }

-- | Emit a top-level statement for the current binding group.
--
-- The statement is prepended to 'ggsToplevelStats'.
emitToplevel :: JStgStat -> G ()
emitToplevel stat = modifyGroup $ \grp ->
  grp { ggsToplevelStats = stat : ggsToplevelStats grp }

-- | Emit static data for the binding group.
--
-- A 'StaticInfo' built from the name, value and optional ident is
-- prepended to the group's 'ggsStatic' list.
emitStatic :: FastString -> StaticVal -> Maybe Ident -> G ()
emitStatic ident val cc = modifyGroup addStatic
  where
    addStatic grp = grp { ggsStatic = StaticInfo ident val cc : ggsStatic grp }

-- | Add closure info to the current binding group. All heap objects must
-- have closure info.
--
-- The info is prepended to the group's 'ggsClosureInfo' list.
emitClosureInfo :: ClosureInfo -> G ()
emitClosureInfo info = modifyGroup $ \grp ->
  grp { ggsClosureInfo = info : ggsClosureInfo grp }

-- | Record a foreign reference in the current binding group.
--
-- A 'ForeignJSRef' is prepended to 'ggsForeignRefs'. Its location text is
-- rendered from the optional source span as
-- @file (startLine,startCol)-(endLine,endCol)@, or @<unknown>@ when no
-- span is given.
emitForeign :: Maybe RealSrcSpan
            -> FastString
            -> Safety
            -> CCallConv
            -> [FastString]
            -> FastString
            -> G ()
emitForeign mbSpan pat safety cconv arg_tys res_ty =
  modifyGroup $ \grp ->
    grp { ggsForeignRefs = ForeignJSRef spanTxt pat safety cconv arg_tys res_ty
                           : ggsForeignRefs grp }
  where
    spanTxt :: FastString
    spanTxt = maybe "<unknown>" renderSpan mbSpan

    -- TODO: Is there a better way to concatenate FastStrings?
    renderSpan :: RealSrcSpan -> FastString
    renderSpan sp = mkFastString $ concat
      [ unpackFS (srcSpanFile sp)
      , " "
      , show (srcSpanStartLine sp, srcSpanStartCol sp)
      , "-"
      , show (srcSpanEndLine sp, srcSpanEndCol sp)
      ]






-- | Start with a new binding group: the group state is replaced by
-- 'defaultGenGroupState'.
resetGroup :: G ()
resetGroup = modifyGroup (const defaultGenGroupState)

-- | The state of a fresh binding group: every list field empty, the 'Int'
-- field 0, an empty dependency set and an empty 'GlobalIdCache'.
defaultGenGroupState :: GenGroupState
defaultGenGroupState =
  GenGroupState
    []                  -- [JStgStat]
    []                  -- [ClosureInfo]
    []                  -- [StaticInfo]
    []                  -- [StackSlot]
    0                   -- Int
    S.empty             -- Set OtherSymb
    emptyGlobalIdCache  -- GlobalIdCache
    []                  -- [ForeignJSRef]

-- | A 'GlobalIdCache' with no entries.
emptyGlobalIdCache :: GlobalIdCache
emptyGlobalIdCache = GlobalIdCache mempty

-- | An 'IdCache' with no entries.
emptyIdCache :: IdCache
emptyIdCache = IdCache mempty



-- | Run the given statement generator only when 'csAssertRts' is set in the
-- settings; otherwise produce 'mempty' without running it.
assertRtsStat :: G JStgStat -> G JStgStat
assertRtsStat stat = do
  enabled <- State.gets (csAssertRts . gsSettings)
  if enabled then stat else pure mempty

-- | Read the code generator configuration from the state.
getSettings :: G StgToJSConfig
getSettings = gsSettings <$> State.get

-- | Read the 'GlobalIdCache' of the current binding group.
getGlobalIdCache :: G GlobalIdCache
getGlobalIdCache = do
  grp <- State.gets gsGroup
  pure (ggsGlobalIdCache grp)

-- | Replace the 'GlobalIdCache' of the current binding group.
setGlobalIdCache :: GlobalIdCache -> G ()
setGlobalIdCache v = modifyGroup (\grp -> grp { ggsGlobalIdCache = v })

-- | A global 'Id' paired with how many times it was found (see 'globalOccs').
data GlobalOcc
  = GlobalOcc
      { global_id    :: !Id    -- ^ the global identifier
      , global_count :: !Word  -- ^ number of occurrences counted
      }

-- | Debug rendering of a 'GlobalOcc': a "GlobalOcc" header followed by the
-- id and its count, each indented by 2.
instance Outputable GlobalOcc where
  -- '<+>' (rather than 'hcat') keeps a space between each label and its
  -- value, e.g. "Id: x" instead of "Id:x".
  ppr g = hang (text "GlobalOcc") 2 $ vcat
            [ text "Id:"    <+> ppr (global_id g)
            , text "Count:" <+> ppr (global_count g)
            ]

-- | Count the occurrences of every global 'Id' used in the given 'JStgStat'.
--
-- An 'Ident' (as returned by 'identsS') counts as global when it is a key of
-- the current group's 'GlobalIdCache'. The result maps each such 'Id' to a
-- 'GlobalOcc' holding its total occurrence count.
--
-- The result is a 'UniqFM' and carries no ordering; callers that need the
-- ids ordered by occurrence count must sort it themselves.
globalOccs :: JStgStat -> G (UniqFM Id GlobalOcc)
globalOccs jst = do
  GlobalIdCache gidc <- getGlobalIdCache
  let
    -- merge two occurrence records for the same Id by summing their counts
    inc g1 g2 = g1 { global_count = global_count g1 + global_count g2 }

    -- Build a map from Id to its occurrence record.
    -- Note that different Idents can map to the same Id (e.g. string payload
    -- and string offset idents), so counts are accumulated per Id, not per
    -- Ident.
    go :: UniqFM Id GlobalOcc -> [Ident] -> UniqFM Id GlobalOcc
    go gids = \case
        []     -> gids
        (i:is) ->
          -- check if the Ident refers to a global Id
          case lookupUFM gidc i of
            Nothing       -> go gids is
            Just (_k,gid) ->
              -- record one more occurrence of this global Id
              let g :: GlobalOcc
                  g = GlobalOcc gid 1
              in go (addToUFM_C inc gids gid g) is

  pure $ go emptyUFM $ identsS jst