{-# LANGUAGE LambdaCase #-}

module GHC.StgToJS.Sinker.Collect
  ( collectArgsTop
  , collectArgs
  , selectUsedOnce
  )
  where

import GHC.Prelude
import GHC.Types.Unique.Set
import GHC.Stg.Syntax
import GHC.Types.Id
import GHC.Types.Unique

-- | Collect every 'Id' passed as a variable argument to a constructor that
-- forms a right-hand side of the given binding.
--
-- Only 'StgRhsCon' right-hand sides are inspected, and only their immediate
-- arguments. 'StgRhsClosure' right-hand sides contribute nothing; their bodies
-- are not traversed. For a recursive group, every right-hand side is inspected
-- in binding order. Literal arguments are skipped. Every occurrence is
-- reported, so the same 'Id' may appear several times in the result.
collectArgsTop :: CgStgBinding -> [Id]
collectArgsTop = \case
  StgNonRec _b r -> collectArgsTopRhs r
  StgRec bs      -> concatMap (collectArgsTopRhs . snd) bs
  where
    -- Variable arguments of a constructor RHS; closures yield nothing.
    collectArgsTopRhs :: CgStgRhs -> [Id]
    collectArgsTopRhs = \case
      StgRhsCon _ccs _dc _mu _ticks args _typ -> concatMap collectArgsA args
      StgRhsClosure {}                        -> []

-- | Collect every 'Id' occurring as an argument anywhere inside a binding.
--
-- The traversal descends into closure bodies, case scrutinees and
-- alternatives, and nested @let@ / @let-no-escape@ bindings. It reports:
--
--   * the variable arguments ('StgVarArg') of function, constructor and
--     primop applications, and of constructor right-hand sides;
--   * the function 'Id' of every 'StgApp', which is also counted as an
--     occurrence.
--
-- Binders are never reported. This covers the binding's own binders, case
-- binders, alternative binders and closure parameters. Literal arguments are
-- skipped. Every occurrence is listed, so an 'Id' may appear more than once.
collectArgs :: CgStgBinding -> [Id]
collectArgs = \case
  StgNonRec _b r -> collectArgsR r
  StgRec bs      -> concatMap (collectArgsR . snd) bs
  where
    -- Right-hand sides: traverse a closure's body, or take a constructor's
    -- variable arguments.
    collectArgsR :: CgStgRhs -> [Id]
    collectArgsR = \case
      StgRhsClosure _x0 _x1 _x2 _x3 e _typ     -> collectArgsE e
      StgRhsCon _ccs _con _mu _ticks args _typ -> concatMap collectArgsA args

    -- Case alternatives: only the alternative's right-hand side is traversed,
    -- never its binders.
    collectArgsAlt :: CgStgAlt -> [Id]
    collectArgsAlt alt = collectArgsE (alt_rhs alt)

    collectArgsE :: CgStgExpr -> [Id]
    collectArgsE = \case
      -- The applied function itself counts as an occurrence.
      StgApp x args
        -> x : concatMap collectArgsA args
      StgConApp _con _mn args _ts
        -> concatMap collectArgsA args
      StgOpApp _op args _t
        -> concatMap collectArgsA args
      -- Scrutinee first, then each alternative; the case binder is skipped.
      StgCase e _b _a alts
        -> collectArgsE e ++ concatMap collectArgsAlt alts
      -- Nested bindings are collected with the same top-level traversal.
      StgLet _x b e
        -> collectArgs b ++ collectArgsE e
      StgLetNoEscape _x b e
        -> collectArgs b ++ collectArgsE e
      StgTick _i e
        -> collectArgsE e
      StgLit _
        -> []

-- | The 'Id' referenced by a single argument, if any.
--
-- A variable argument yields a singleton list with its 'Id'. A literal
-- argument yields the empty list.
collectArgsA :: StgArg -> [Id]
collectArgsA = \case
  StgVarArg i -> [i]
  StgLitArg _ -> []

-- | Select the elements that occur exactly once in a container.
--
-- Membership is decided with 'elementOfUniqSet', i.e. through the elements'
-- 'Uniquable' instance. The fold carries two sets: elements seen once so far
-- and elements seen more than once. On its second occurrence an element is
-- moved from the first set to the second, and it stays there for any further
-- occurrences. Only the first set is returned.
selectUsedOnce :: (Foldable t, Uniquable a) => t a -> UniqSet a
selectUsedOnce = fst . foldr step (emptyUniqSet, emptyUniqSet)
  where
    step :: Uniquable b => b -> (UniqSet b, UniqSet b) -> (UniqSet b, UniqSet b)
    step x acc@(seenOnce, seenMany)
      -- Already known to occur several times: nothing changes.
      | x `elementOfUniqSet` seenMany = acc
      -- Second occurrence: promote from "once" to "many".
      | x `elementOfUniqSet` seenOnce
      = (delOneFromUniqSet seenOnce x, addOneToUniqSet seenMany x)
      -- First occurrence.
      | otherwise = (addOneToUniqSet seenOnce x, seenMany)