{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}

module GHC.StgToJS.Sinker.Sinker (sinkPgm) where

import GHC.Prelude
import GHC.Types.Unique.Set
import GHC.Types.Unique.FM
import GHC.Types.Var.Set
import GHC.Stg.Syntax
import GHC.Types.Id
import GHC.Types.Name
import GHC.Unit.Module
import GHC.Types.Literal
import GHC.Data.Graph.Directed
import GHC.StgToJS.Sinker.Collect
import GHC.StgToJS.Sinker.StringsUnfloat

import GHC.Utils.Misc (partitionWith)
import GHC.StgToJS.Utils

import Data.Char
import Data.List (partition)
import Data.Maybe
import Data.ByteString (ByteString)

-- | Unfloat some top-level unexported things
--
-- GHC floats constants to the top level. This is fine in native code, but with
-- JS every top-level binding occupies a global variable name. We can unfloat
-- (sink) some unexported things back towards their use sites:
--
-- - global constructors, as long as they're referenced only once by another
--   global constructor and are not in a recursive binding group
-- - literals (small literals may also be sunk if they are used more than once)
--
-- Top-level string literals whose ids are in the used-once set are also handed
-- to 'unfloatStringLits' together with the lifted bindings (presumably so they
-- can be inlined at their single use site — see that module).
--
-- Returns the map of sunk replacements (binder -> expression to use instead)
-- and the remaining top-level bindings. Note that every top-level string
-- literal of the input is still emitted in the result, including ones that
-- 'unfloatStringLits' reports as unfloated.
sinkPgm :: Module
        -> [CgStgTopBinding]
        -> (UniqFM Id CgStgExpr, [CgStgTopBinding])
sinkPgm m pgm = (sunk, map StgTopLifted keptBinds ++ stringLitBinds)
  where
    -- Separate ordinary lifted bindings from top-level string literals.
    splitTop :: CgStgTopBinding -> Either CgStgBinding (Id, ByteString)
    splitTop = \case
      StgTopLifted b      -> Left b
      StgTopStringLit i s -> Right (i, s)

    (liftedBinds, strLits) = partitionWith splitTop pgm

    -- Ids occurring exactly once among the arguments of the lifted bindings.
    usedOnceIds :: IdSet
    usedOnceIds = selectUsedOnce (concatMap collectArgs liftedBinds)

    -- String literal payloads keyed by the literal binder's 'Name'.
    strLitsByName :: UniqFM Name ByteString
    strLitsByName = listToUFM [ (idName i, s) | (i, s) <- strLits ]

    -- The set of names actually unfloated is returned as well, but is not
    -- needed here.
    (unfloatedBinds, _actuallyUnfloatedStringLitNames) =
      unfloatStringLits (mapUniqSet idName usedOnceIds) strLitsByName liftedBinds

    -- All string literals are re-emitted as top-level bindings.
    stringLitBinds :: [CgStgTopBinding]
    stringLitBinds = [ StgTopStringLit i s | (i, s) <- strLits ]

    (sunk, keptBinds) = sinkPgm' m usedOnceIds unfloatedBinds

sinkPgm'
  :: Module
       -- ^ the module, since we treat definitions from the current module
       -- differently
  -> IdSet
       -- ^ the set of used once ids
  -> [CgStgBinding]
       -- ^ the bindings
  -> (UniqFM Id CgStgExpr, [CgStgBinding])
       -- ^ a map with sunken replacements for nodes, for where the replacement
       -- does not fit in the 'StgBinding' AST and the new bindings
sinkPgm' m usedOnceIds pgm = (sinkables, remaining)
  where
    -- used-once ids that are used only in a top-level argument
    usedOnce :: IdSet
    usedOnce = collectTopLevelUsedOnce usedOnceIds pgm

    -- candidates that may be duplicated at every use site
    always :: [(Id, CgStgExpr)]
    always = concatMap alwaysSinkable pgm

    -- candidates that may be sunk only because they have a single use
    once :: [(Id, CgStgExpr)]
    once = [ cand
           | cand@(i, _) <- concatMap (onceSinkable m) pgm
           , i `elementOfUniqSet` usedOnce
           ]

    -- later entries win on duplicate keys, as with 'listToUFM'
    sinkables :: UniqFM Id CgStgExpr
    sinkables = listToUFM (always ++ once)

    -- a binding is dropped from the output when its binder has been sunk
    isSunk :: CgStgBinding -> Bool
    isSunk (StgNonRec b _) = b `elemUFM` sinkables
    isSunk StgRec{}        = False

    remaining :: [CgStgBinding]
    remaining = filter (not . isSunk) (topSortDecls m pgm)

-- | always sinkable, values that may be duplicated in the generated code (e.g.
-- small literals)
--
-- Only non-recursive bindings of local binders (see 'isLocal') qualify:
--
-- * a closure that takes no arguments and whose body is a small literal
--   (see 'isSmallSinkableLit'); the literal expression itself is the
--   replacement;
-- * a constructor applied to a single small literal argument, provided
--   'isUnboxableCon' accepts the constructor; the replacement is the
--   equivalent 'StgConApp'.
--
-- A closure with arguments is a function. Replacing its binder by the literal
-- body would turn a function reference into a bare value, so such closures are
-- never sinkable.
alwaysSinkable :: CgStgBinding -> [(Id, CgStgExpr)]
alwaysSinkable (StgRec {})       = []
alwaysSinkable (StgNonRec b rhs) = case rhs of
  -- '[]': only argument-less closures; functions must not be replaced by
  -- their body.
  StgRhsClosure _ _ _ [] e@(StgLit l) _
    | isSmallSinkableLit l
    , isLocal b
    -> [(b, e)]
  StgRhsCon _ccs dc cnum _ticks as@[StgLitArg l] _typ
    | isSmallSinkableLit l
    , isLocal b
    , isUnboxableCon dc
    -> [(b, StgConApp dc cnum as [])]
  _ -> []

-- | Is this literal small enough to be duplicated at every use site?
--
-- Characters with a code point below 100000 and numeric literals whose
-- absolute value is below 100000 qualify; no other literal does.
isSmallSinkableLit :: Literal -> Bool
isSmallSinkableLit = \case
  LitChar c     -> ord c < 100000
  LitNumber _ n -> abs n < 100000
  _             -> False


-- | once sinkable: may be sunk, but duplication is not ok
--
-- A non-recursive binding of a local binder (see 'isLocal') is a candidate
-- when its right-hand side is either:
--
-- * a constructor application, replaced by the equivalent 'StgConApp'; or
-- * an argument-less closure whose body is a literal, replaced by that
--   literal.
--
-- The caller is responsible for keeping only candidates whose binder is
-- actually used once. Closures with arguments are functions and are never
-- candidates: substituting their body for the binder would be wrong.
onceSinkable :: Module -> CgStgBinding -> [(Id, CgStgExpr)]
onceSinkable _m (StgNonRec b rhs)
  | Just e <- getSinkable rhs
  , isLocal b = [(b, e)]
  where
    getSinkable :: CgStgRhs -> Maybe CgStgExpr
    getSinkable = \case
      StgRhsCon _ccs dc cnum _ticks args _typ -> Just (StgConApp dc cnum args [])
      -- '[]': only argument-less closures; a function is not its body.
      StgRhsClosure _ _ _ [] e@(StgLit{}) _typ -> Just e
      _                                        -> Nothing
onceSinkable _ _ = []

-- | Collect all idents used only once in an argument at the top level
-- and never anywhere else.
--
-- The result is the given used-once set restricted to the ids that also occur
-- exactly once among the arguments gathered by 'collectArgsTop'.
collectTopLevelUsedOnce :: IdSet -> [CgStgBinding] -> IdSet
collectTopLevelUsedOnce usedOnceIds binds =
  usedOnceIds `intersectUniqSets` usedOnceAtTop
  where
    usedOnceAtTop :: IdSet
    usedOnceAtTop = selectUsedOnce (concatMap collectArgsTop binds)

-- | Is this binder local to the module: not exported, and its 'Name' has no
-- module attached ('nameModule_maybe' returns 'Nothing')?
isLocal :: Id -> Bool
isLocal i = not (isExportedId i) && isNothing (nameModule_maybe (idName i))

-- | since we have sequential initialization, topsort the non-recursive
-- constructor bindings
--
-- Recursive groups come first, in their original order, followed by the
-- non-recursive bindings in dependency order. A dependency from @x@ to @b@ is
-- recorded when @b@'s right-hand side is a constructor application with @x@
-- (itself a non-recursive binder of this list) as a variable argument; other
-- right-hand sides contribute no dependencies. A cycle among the
-- non-recursive bindings is reported with 'error'.
topSortDecls :: Module -> [CgStgBinding] -> [CgStgBinding]
topSortDecls _m binds = recs ++ sortedNonRecs
  where
    (nonRecs, recs) = partition isNonRec binds

    isNonRec :: CgStgBinding -> Bool
    isNonRec = \case
      StgNonRec{} -> True
      StgRec{}    -> False

    -- one graph node per non-recursive binding, keyed by its binder
    nodes :: [Node Id CgStgBinding]
    nodes = map toNode nonRecs

    -- only ever applied to elements of 'nonRecs'
    toNode :: CgStgBinding -> Node Id CgStgBinding
    toNode bind = case bind of
      StgNonRec b _ -> DigraphNode bind b []
      _             -> error "topSortDecls: getV, unexpected binding"

    nonRecBinders :: IdSet
    nonRecBinders = mkUniqSet (map node_key nodes)

    -- (dependency, dependent) edges
    depsOf :: CgStgBinding -> [(Id, Id)]
    depsOf = \case
      StgNonRec b (StgRhsCon _cc _dc _cnum _ticks args _typ) ->
        [ (i, b) | StgVarArg i <- args, i `elementOfUniqSet` nonRecBinders ]
      _ -> []

    graph :: Graph (Node Id CgStgBinding)
    graph = graphFromVerticesAndAdjacency nodes (concatMap depsOf nonRecs)

    hasCycle :: Bool
    hasCycle = not (null [ () | CyclicSCC _ <- stronglyConnCompG graph ])

    sortedNonRecs :: [CgStgBinding]
    sortedNonRecs
      | hasCycle  = error "topSortDecls: unexpected cycle"
      | otherwise = map node_payload (topologicalSortG graph)