{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998

\section[FloatOut]{Float bindings outwards (towards the top level)}

``Long-distance'' floating of bindings towards the top level.
-}



module GHC.Core.Opt.FloatOut ( floatOutwards ) where

import GHC.Prelude

import GHC.Core
import GHC.Core.Utils
import GHC.Core.Make
-- import GHC.Core.Opt.Arity ( exprArity, etaExpand )
import GHC.Core.Opt.Monad ( FloatOutSwitches(..) )

import GHC.Driver.Flags  ( DumpFlag (..) )
import GHC.Utils.Logger
import GHC.Types.Id      ( Id, idType,
--                           idArity, isDeadEndId,
                           isJoinId, idJoinPointHood )
import GHC.Types.Tickish
import GHC.Core.Opt.SetLevels
import GHC.Types.Unique.Supply ( UniqSupply )
import GHC.Data.Bag
import GHC.Utils.Misc
import GHC.Data.Maybe
import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Core.Type
import qualified Data.IntMap as M

import Data.List        ( partition )

{-
        -----------------
        Overall game plan
        -----------------

The Big Main Idea is:

        To float out sub-expressions that can thereby get outside
        a non-one-shot value lambda, and hence may be shared.


To achieve this we may need to do two things:

   a) Let-bind the sub-expression:

        f (g x)  ==>  let lvl = f (g x) in lvl

      Now we can float the binding for 'lvl'.

   b) More than that, we may need to abstract wrt a type variable

        \x -> ... /\a -> let v = ...a... in ....

      Here the binding for v mentions 'a' but not 'x'.  So we
      abstract wrt 'a', to give this binding for 'v':

            vp = /\a -> ...a...
            v  = vp a

      Now the binding for vp can float out unimpeded.
      I can't remember why this case seemed important enough to
      deal with, but I certainly found cases where important floats
      didn't happen if we did not abstract wrt tyvars.

With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level.  Whether or not this
happens is controlled by a flag.


Random comments
~~~~~~~~~~~~~~~

At the moment we never float a binding out to between two adjacent
lambdas.  For example:

@
        \x y -> let t = x+x in ...
===>
        \x -> let t = x+x in \y -> ...
@
Reason: this is less efficient in the case where the original lambda
is never partially applied.

But there's a case I've seen where this might not be true.  Consider:
@
elEm2 x ys
  = elem' x ys
  where
    elem' _ []  = False
    elem' x (y:ys)      = x==y || elem' x ys
@
It turns out that this generates a subexpression of the form
@
        \deq x ys -> let eq = eqFromEqDict deq in ...
@
which might usefully be separated to
@
        \deq -> let eq = eqFromEqDict deq in \x ys -> ...
@
Well, maybe.  We don't do this at the moment.


************************************************************************
*                                                                      *
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
*                                                                      *
************************************************************************
-}

-- | Entry point of the float-out pass.
--
-- * Annotates every binding of the program with its destination level using
--   'setLevels'.
-- * Floats each top-level binding with 'floatTopBind'.
-- * Emits the levelled program (under 'Opt_D_verbose_core2core') and a
--   one-line summary of the float statistics (under 'Opt_D_dump_simpl_stats').
-- * Returns the floated bindings, flattened into a single list.
floatOutwards :: Logger
              -> FloatOutSwitches
              -> UniqSupply
              -> CoreProgram -> IO CoreProgram
floatOutwards logger float_sws us pgm = do
  let annotated_w_levels = setLevels float_sws pgm us
      -- One (stats, bindings) result per top-level binding
      (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)

  putDumpFileMaybe logger Opt_D_verbose_core2core "Levels added:"
                   FormatCore
                   (vcat (map ppr annotated_w_levels))

  let (tlets, ntlets, lams) = get_stats (sum_stats fss)

  putDumpFileMaybe logger Opt_D_dump_simpl_stats "FloatOut stats:"
                   FormatText
                   (hcat [ int tlets,  text " Lets floated to top level; "
                         , int ntlets, text " Lets floated elsewhere; from "
                         , int lams,   text " Lambda groups" ])

  return (bagToList (unionManyBags binds_s'))

-- | Float a single top-level binding.
--
-- Everything that floats out of the binding is flattened with
-- 'flattenTopFloats'.
--
-- * For a recursive binding, the flattened floats are merged into the
--   group's pairs with 'addTopFloatPairs', giving a single 'Rec'.
-- * For a non-recursive binding, the floats come first in the resulting bag
--   and the binding itself is appended after them.
floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind)
floatTopBind bind
  = case floatBind bind of
      (fs, floats, bind') ->
        let float_bag = flattenTopFloats floats
        in case bind' of
             -- bind' can't have unlifted values or join points, so can only be one
             -- value bind, rec or non-rec (see comment on floatBind)
             [Rec prs]    -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
             [NonRec b e] -> (fs, float_bag `snocBag` NonRec b e)
             _            -> pprPanic "floatTopBind" (ppr bind')

{-
************************************************************************
*                                                                      *
\subsection[FloatOut-Bind]{Floating in a binding (the business end)}
*                                                                      *
************************************************************************
-}

floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
  -- Float out of the right-hand side(s) of a single binding group.
  -- Returns the float statistics, the floats that continue outwards, and
  -- a list with either
  --   * A single non-recursive binding (value or join point), or
  --   * The following, in order:
  --     * Zero or more non-rec unlifted bindings
  --     * One or both of:
  --       * A recursive group of join binds
  --       * A recursive group of value binds
  -- See Note [Floating out of Rec rhss] for why things get arranged this way.
floatBind (NonRec (TB var _) rhs)
  = case floatRhs var rhs of { (fs, rhs_floats, rhs') ->
      (fs, rhs_floats, [NonRec var rhs']) }

floatBind (Rec pairs)
  = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
    let (new_ul_pairss, new_other_pairss) = unzip new_pairs
        (new_join_pairs, new_l_pairs)     = partition (isJoinId . fst)
                                                      (concat new_other_pairss)
        -- Can't put the join points and the values in the same rec group
        new_rec_binds | null new_join_pairs = [ Rec new_l_pairs    ]
                      | null new_l_pairs    = [ Rec new_join_pairs ]
                      | otherwise           = [ Rec new_l_pairs
                                              , Rec new_join_pairs ]
        -- Unlifted non-join floats are kept non-recursive and placed
        -- before the rec groups (see Note [Floating out of Rec rhss])
        new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
    in
    (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds) }
  where
    -- Float out of one (binder, rhs) pair of the recursive group.
    do_pair :: (LevelledBndr, LevelledExpr)
            -> (FloatStats, FloatBinds,
                ([(Id,CoreExpr)],  -- Non-recursive unlifted value bindings
                 [(Id,CoreExpr)])) -- Join points and lifted value bindings
    do_pair (TB name spec, rhs)
      | isTopLvl dest_lvl  -- See Note [floatBind for top level]
        -- Every float from the rhs joins this group; nothing floats further
      = case floatRhs name rhs of { (fs, rhs_floats, rhs') ->
        (fs, emptyFloats, ([], addTopFloatPairs (flattenTopFloats rhs_floats)
                                                [(name, rhs')]))}
      | otherwise         -- Note [Floating out of Rec rhss]
        -- Floats destined for this group's own level stay here: leading
        -- lets join the group, and any tail starting with a case is pushed
        -- under the rhs's lambdas
      = case floatRhs name rhs of { (fs, rhs_floats, rhs') ->
        case partitionByLevel dest_lvl rhs_floats of { (rhs_floats', heres) ->
        case splitRecFloats heres of { (ul_pairs, pairs, case_heres) ->
        let pairs' = (name, installUnderLambdas case_heres rhs') : pairs in
        (fs, rhs_floats', (ul_pairs, pairs')) }}}
      where
        dest_lvl = floatSpecLevel spec

splitRecFloats :: Bag FloatBind
               -> ([(Id,CoreExpr)], -- Non-recursive unlifted value bindings
                   [(Id,CoreExpr)], -- Join points and lifted value bindings
                   Bag FloatBind)   -- A tail of further bindings
-- The "tail" begins with a case
-- See Note [Floating out of Rec rhss]
splitRecFloats fs
  = go [] [] (bagToList fs)
  where
    -- Walk the leading FloatLets, sorting each binder into the unlifted
    -- non-join accumulator or the rec-group accumulator. Stop at the first
    -- float that is not a FloatLet; it and everything after it form the tail.
    go ul_prs prs (FloatLet (NonRec b r) : rest)
      | isUnliftedType (idType b)
        -- NB: isUnliftedType is OK here as binders always
        -- have a fixed RuntimeRep.
      , not (isJoinId b)
      = go ((b,r):ul_prs) prs rest
      | otherwise
      = go ul_prs ((b,r):prs) rest
    go ul_prs prs (FloatLet (Rec prs') : rest)
      = go ul_prs (prs' ++ prs) rest
    go ul_prs prs rest
      = (reverse ul_prs, prs, listToBag rest)
        -- Order only matters for non-rec, so only ul_prs is reversed
        -- back to its original order

-- | Install the given floats around @e@, but underneath all of @e@'s leading
-- lambdas, so the lambda prefix stays at the top of the expression.
-- If there are no floats, @e@ is returned unchanged.
installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
-- See Note [Floating out of Rec rhss]
installUnderLambdas floats e
  | isEmptyBag floats = e
  | otherwise         = go e
  where
    go (Lam b body) = Lam b (go body)
    go body         = install floats body

---------------
---------------
-- | Apply a floating function to every element of a list.
--
-- The per-element statistics are combined with 'add_stats' and the floated
-- bindings with 'plusFloats', in list order. The results are collected in
-- the same order as the input.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList _ []     = (zeroStats, emptyFloats, [])
floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
                     case floatList f as of { (fs_as, binds_as, bs) ->
                     (fs_a `add_stats` fs_as, binds_a `plusFloats` binds_as, b:bs) }}

{-
Note [Floating out of Rec rhss]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider   Rec { f<1,0> = \xy. body }
From the body we may get some floats. The ones with level <1,0> must
stay here, since they may mention f.  Ideally we'd like to make them
part of the Rec block pairs -- but we can't if there are any
FloatCases involved.

Nor is it a good idea to dump them in the rhs, but outside the lambda
    f = case x of I# y -> \xy. body
because now f's arity might get worse, which is Not Good. (And if
there's an SCC around the RHS it might not get better again.
See #5342.)

So, gruesomely, we split the floats into
 * the outer FloatLets, which can join the Rec, and
 * an inner batch starting in a FloatCase, which are then
   pushed *inside* the lambdas.
This loses full laziness in the rare situation where there is a
FloatCase and a Rec interacting.

If there are unlifted FloatLets (that *aren't* join points) among the floats,
we can't add them to the recursive group without angering Core Lint, but since
they must be ok-for-speculation, they can't actually be making any recursive
calls, so we can safely pull them out and keep them non-recursive.

(Why is something getting floated to <1,0> that doesn't make a recursive call?
The case that came up in testing was that f *and* the unlifted binding were
getting floated *to the same place*:

  \x<2,0> ->
    ... <3,0>
    letrec { f<F<2,0>> =
      ... let x'<F<2,0>> = x +# 1# in ...
    } in ...

Everything gets labeled "float to <2,0>" because it all depends on x, but this
makes f and x' look mutually recursive when they're not.

The test was shootout/k-nucleotide, as compiled using commit 47d5dd68 on the
wip/join-points branch.

TODO: This can probably be solved somehow in GHC.Core.Opt.SetLevels. The difference between
"this *is at* level <2,0>" and "this *depends on* level <2,0>" is very
important.)

Note [floatBind for top level]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We may have a *nested* binding whose destination level is (FloatMe tOP_LEVEL), thus
         letrec { foo <0,0> = .... (let bar<0,0> = .. in ..) .... }
The binding for bar will be in the "tops" part of the floating binds,
and thus not partitioned by floatBody.

We could perhaps get rid of the 'tops' component of the floating binds,
but this case works just as well.


************************************************************************
*                                                                      *
\subsection[FloatOut-Expr]{Floating in expressions}
*                                                                      *
************************************************************************
-}

-- | Float out of an expression that is bound at level @lvl@.
--
-- Used by floatExpr for lambda bodies, for the bodies of 'StayPut' lets and
-- for case-alternative rhss.
--
-- The floats that 'partitionByLevel' assigns to @lvl@ (@heres@) are
-- installed directly around the floated expression. The remaining floats
-- are returned so they can continue outwards.
floatBody :: Level
          -> LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatBody lvl arg
  = case floatExpr arg of { (fsa, floats, arg') ->
    case partitionByLevel lvl floats of { (floats', heres) ->
        -- Dump bindings are bound here
    (fsa, floats', install heres arg') }}

-----------------

{- Note [Floating past breakpoints]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We used to disallow floating out of breakpoint ticks (see #10052). However, I
think this is too restrictive.

Consider the case of an expression scoped over by a breakpoint tick,

  tick<...> (let x = ... in f x)

In this case it is completely legal to float out x, despite the fact that
breakpoint ticks are scoped,

  let x = ... in (tick<...>  f x)

The reason here is that we know that the breakpoint will still be hit when the
expression is entered since the tick still scopes over the RHS.

-}

floatExpr :: LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatExpr :: Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr (Var Id
v)   = (FloatStats
zeroStats, FloatBinds
emptyFloats, Id -> Expr Id
forall b. Id -> Expr b
Var Id
v)
floatExpr (Type Type
ty) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Type -> Expr Id
forall b. Type -> Expr b
Type Type
ty)
floatExpr (Coercion Coercion
co) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Coercion -> Expr Id
forall b. Coercion -> Expr b
Coercion Coercion
co)
floatExpr (Lit Literal
lit) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Literal -> Expr Id
forall b. Literal -> Expr b
Lit Literal
lit)

floatExpr (App Expr (TaggedBndr FloatSpec)
e Expr (TaggedBndr FloatSpec)
a)
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr  Expr (TaggedBndr FloatSpec)
e) of { (FloatStats
fse, FloatBinds
floats_e, Expr Id
e') ->
    case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr  Expr (TaggedBndr FloatSpec)
a) of { (FloatStats
fsa, FloatBinds
floats_a, Expr Id
a') ->
    (FloatStats
fse FloatStats -> FloatStats -> FloatStats
`add_stats` FloatStats
fsa, FloatBinds
floats_e FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
floats_a, Expr Id -> Expr Id -> Expr Id
forall b. Expr b -> Expr b -> Expr b
App Expr Id
e' Expr Id
a') }}

floatExpr lam :: Expr (TaggedBndr FloatSpec)
lam@(Lam (TB Id
_ FloatSpec
lam_spec) Expr (TaggedBndr FloatSpec)
_)
  = let ([TaggedBndr FloatSpec]
bndrs_w_lvls, Expr (TaggedBndr FloatSpec)
body) = Expr (TaggedBndr FloatSpec)
-> ([TaggedBndr FloatSpec], Expr (TaggedBndr FloatSpec))
forall b. Expr b -> ([b], Expr b)
collectBinders Expr (TaggedBndr FloatSpec)
lam
        bndrs :: [Id]
bndrs                = [Id
b | TB Id
b FloatSpec
_ <- [TaggedBndr FloatSpec]
bndrs_w_lvls]
        bndr_lvl :: Level
bndr_lvl             = FloatSpec -> Level
floatSpecLevel FloatSpec
lam_spec
        -- All the binders have the same level
        -- See GHC.Core.Opt.SetLevels.lvlLamBndrs
        -- Use asJoinCeilLvl to make this the join ceiling
    in
    case (Level
-> Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatBody Level
bndr_lvl Expr (TaggedBndr FloatSpec)
body) of { (FloatStats
fs, FloatBinds
floats, Expr Id
body') ->
    (FloatStats -> FloatBinds -> FloatStats
add_to_stats FloatStats
fs FloatBinds
floats, FloatBinds
floats, [Id] -> Expr Id -> Expr Id
forall b. [b] -> Expr b -> Expr b
mkLams [Id]
bndrs Expr Id
body') }

floatExpr (Tick CoreTickish
tickish Expr (TaggedBndr FloatSpec)
expr)
  | CoreTickish
tickish CoreTickish -> TickishScoping -> Bool
forall (pass :: TickishPass).
GenTickish pass -> TickishScoping -> Bool
`tickishScopesLike` TickishScoping
SoftScope -- not scoped, can just float
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (FloatStats
fs, FloatBinds
floating_defns, Expr Id
expr') ->
    (FloatStats
fs, FloatBinds
floating_defns, CoreTickish -> Expr Id -> Expr Id
forall b. CoreTickish -> Expr b -> Expr b
Tick CoreTickish
tickish Expr Id
expr') }

  | Bool -> Bool
not (CoreTickish -> Bool
forall (pass :: TickishPass). GenTickish pass -> Bool
tickishCounts CoreTickish
tickish) Bool -> Bool -> Bool
|| CoreTickish -> Bool
forall (pass :: TickishPass). GenTickish pass -> Bool
tickishCanSplit CoreTickish
tickish
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (FloatStats
fs, FloatBinds
floating_defns, Expr Id
expr') ->
    let -- Annotate bindings floated outwards past an scc expression
        -- with the cc.  We mark that cc as "duplicated", though.
        annotated_defns :: FloatBinds
annotated_defns = CoreTickish -> FloatBinds -> FloatBinds
wrapTick (CoreTickish -> CoreTickish
forall (pass :: TickishPass). GenTickish pass -> GenTickish pass
mkNoCount CoreTickish
tickish) FloatBinds
floating_defns
    in
    (FloatStats
fs, FloatBinds
annotated_defns, CoreTickish -> Expr Id -> Expr Id
forall b. CoreTickish -> Expr b -> Expr b
Tick CoreTickish
tickish Expr Id
expr') }

  -- See Note [Floating past breakpoints]
  | Breakpoint{} <- CoreTickish
tickish
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (FloatStats
fs, FloatBinds
floating_defns, Expr Id
expr') ->
    (FloatStats
fs, FloatBinds
floating_defns, CoreTickish -> Expr Id -> Expr Id
forall b. CoreTickish -> Expr b -> Expr b
Tick CoreTickish
tickish Expr Id
expr') }

  | Bool
otherwise
  = String -> SDoc -> (FloatStats, FloatBinds, Expr Id)
forall a. HasCallStack => String -> SDoc -> a
pprPanic String
"floatExpr tick" (CoreTickish -> SDoc
forall a. Outputable a => a -> SDoc
ppr CoreTickish
tickish)

floatExpr (Cast Expr (TaggedBndr FloatSpec)
expr Coercion
co)
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr Expr (TaggedBndr FloatSpec)
expr) of { (FloatStats
fs, FloatBinds
floating_defns, Expr Id
expr') ->
    (FloatStats
fs, FloatBinds
floating_defns, Expr Id -> Coercion -> Expr Id
forall b. Expr b -> Coercion -> Expr b
Cast Expr Id
expr' Coercion
co) }

floatExpr (Let LevelledBind
bind Expr (TaggedBndr FloatSpec)
body)
  = case FloatSpec
bind_spec of
      FloatMe Level
dest_lvl
        -> case (LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
floatBind LevelledBind
bind) of { (FloatStats
fsb, FloatBinds
bind_floats, [CoreBind]
binds') ->
           case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatExpr Expr (TaggedBndr FloatSpec)
body) of { (FloatStats
fse, FloatBinds
body_floats, Expr Id
body') ->
           let new_bind_floats :: FloatBinds
new_bind_floats = (FloatBinds -> FloatBinds -> FloatBinds)
-> FloatBinds -> [FloatBinds] -> FloatBinds
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr FloatBinds -> FloatBinds -> FloatBinds
plusFloats FloatBinds
emptyFloats
                                   ((CoreBind -> FloatBinds) -> [CoreBind] -> [FloatBinds]
forall a b. (a -> b) -> [a] -> [b]
map (Level -> CoreBind -> FloatBinds
unitLetFloat Level
dest_lvl) [CoreBind]
binds') in
           ( FloatStats -> FloatStats -> FloatStats
add_stats FloatStats
fsb FloatStats
fse
           , FloatBinds
bind_floats FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
new_bind_floats
                         FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
body_floats
           , Expr Id
body') }}

      StayPut Level
bind_lvl  -- See Note [Avoiding unnecessary floating]
        -> case (LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
floatBind LevelledBind
bind)          of { (FloatStats
fsb, FloatBinds
bind_floats, [CoreBind]
binds') ->
           case (Level
-> Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Id)
floatBody Level
bind_lvl Expr (TaggedBndr FloatSpec)
body) of { (FloatStats
fse, FloatBinds
body_floats, Expr Id
body') ->
           ( FloatStats -> FloatStats -> FloatStats
add_stats FloatStats
fsb FloatStats
fse
           , FloatBinds
bind_floats FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
body_floats
           , (CoreBind -> Expr Id -> Expr Id)
-> Expr Id -> [CoreBind] -> Expr Id
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr CoreBind -> Expr Id -> Expr Id
forall b. Bind b -> Expr b -> Expr b
Let Expr Id
body' [CoreBind]
binds' ) }}
  where
    bind_spec :: FloatSpec
bind_spec = case LevelledBind
bind of
                 NonRec (TB Id
_ FloatSpec
s) Expr (TaggedBndr FloatSpec)
_     -> FloatSpec
s
                 Rec ((TB Id
_ FloatSpec
s, Expr (TaggedBndr FloatSpec)
_) : [(TaggedBndr FloatSpec, Expr (TaggedBndr FloatSpec))]
_) -> FloatSpec
s
                 Rec []                -> String -> FloatSpec
forall a. HasCallStack => String -> a
panic String
"floatExpr:rec"

-- A case expression.  The FloatSpec on the case binder decides its fate:
--
--  * FloatMe dest_lvl: the whole case moves out as a case-float at
--    dest_lvl.  That is only possible for a single DataAlt alternative;
--    the expression left behind is that alternative's (floated) RHS.
--  * StayPut bind_lvl: the case stays where it is, and the RHS of every
--    alternative is floated with 'floatBody' at bind_lvl.
floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
  = case case_spec of
      FloatMe dest_lvl -> move_case dest_lvl    -- Case expression moves
      StayPut bind_lvl -> keep_case bind_lvl    -- Case expression stays put
  where
    untag tagged = [b | TB b _ <- tagged]

    move_case dest_lvl
      | [Alt con@(DataAlt {}) bndrs rhs] <- alts
      = case floatExpr scrut of
          (fs_scrut, floats_scrut, scrut') ->
            case floatExpr rhs of
              (fs_rhs, floats_rhs, rhs') ->
                let case_float = unitCaseFloat dest_lvl scrut' case_bndr con
                                               (untag bndrs)
                in ( add_stats fs_scrut fs_rhs
                     -- Within one level's bag the order is: scrutinee
                     -- floats, then the case-float, then the RHS floats
                   , floats_scrut `plusFloats` case_float `plusFloats` floats_rhs
                   , rhs' )
      | otherwise
      = pprPanic "Floating multi-case" (ppr alts)

    keep_case bind_lvl
      = case floatExpr scrut of
          (fs_scrut, floats_scrut, scrut') ->
            case floatList (float_alt bind_lvl) alts of
              (fs_alts, floats_alts, alts') ->
                ( add_stats fs_scrut fs_alts
                , floats_alts `plusFloats` floats_scrut
                , Case scrut' case_bndr ty alts' )

    -- Float out of one alternative's RHS, stopping at bind_lvl.
    float_alt bind_lvl (Alt con bs rhs)
      = case floatBody bind_lvl rhs of
          (fs, rhs_floats, rhs') -> (fs, rhs_floats, Alt con (untag bs) rhs')

-- | Float out of the right-hand side of a binding.
--
-- If the binder is a join point and its RHS begins with at least
-- @join_arity@ lambdas, and that arity is non-zero, those lambdas are left
-- in place. The body beneath them is floated with 'floatBody' at the level
-- recorded on the first of those lambda binders.  In every other case the
-- RHS is floated as an ordinary expression with 'floatExpr'.
floatRhs :: CoreBndr
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)
floatRhs bndr rhs
  | JoinPoint join_arity <- idJoinPointHood bndr
  , Just (bndrs@(TB _ lam_spec : _), body) <- peel_lams join_arity rhs []
  = case floatBody (floatSpecLevel lam_spec) body of
      (fs, floats, body') -> (fs, floats, mkLams [b | TB b _ <- bndrs] body')
  | otherwise
  -- Not a join point, too few leading lambdas, or arity zero
  = floatExpr rhs
  where
    -- Strip exactly n leading lambdas, returning their binders in
    -- source order together with the remaining body; Nothing if the
    -- expression has fewer than n leading lambdas.
    peel_lams :: Int -> Expr b -> [b] -> Maybe ([b], Expr b)
    peel_lams 0 expr      acc = Just (reverse acc, expr)
    peel_lams n (Lam b e) acc = peel_lams (n - 1) e (b : acc)
    peel_lams _ _         _   = Nothing

{-
Note [Avoiding unnecessary floating]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In general we want to avoid floating a let unnecessarily, because
it might worsen strictness:
    let
       x = ...(let y = e in y+y)....
Here y is demanded.  If we float it outside the lazy 'x=..' then
we'd have to zap its demand info, and it may never be restored.

So at a 'let' we leave the binding right where it is unless
the binding will escape a value lambda, e.g.

(\x -> let y = fac 100 in y)

That's what the partitionByMajorLevel does in the floatExpr (Let ...)
case.

Notice, though, that we must take care to drop any bindings
from the body of the let that depend on the staying-put bindings.

We used instead to do the partitionByMajorLevel on the RHS of an '=',
in floatRhs.  But that was quite tiresome.  We needed to test for
values or trivial rhss, because (in particular) we don't want to insert
new bindings between the "=" and the "\".  E.g.
        f = \x -> let <bind> in <body>
We do not want
        f = let <bind> in \x -> <body>
(a) The simplifier will immediately float it further out, so we may
        as well do so right now; in general, keeping rhss as manifest
        values is good
(b) If a float-in pass follows immediately, it might add yet more
        bindings just after the '='.  And some of them might (correctly)
        be strict even though the 'let f' is lazy, because f, being a value,
        gets its demand-info zapped by the simplifier.
And even all that turned out to be very fragile, and broke
altogether when profiling got in the way.

So now we do the partition right at the (Let..) itself.

************************************************************************
*                                                                      *
\subsection{Utility bits for floating stats}
*                                                                      *
************************************************************************

I didn't implement this with unboxed numbers.  I don't want to be too
strict in this stuff, as it is rarely turned on.  (WDP 95/09)
-}

-- | Counters collected by the float-out pass.  The fields are deliberately
-- lazy; these statistics are rarely turned on.
data FloatStats
  = FlS Int  -- Number of top-floats * lambda groups they've been past
        Int  -- Number of non-top-floats * lambda groups they've been past
        Int  -- Number of lambda (groups) seen

-- | The three counters, in field order.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats stats = case stats of
  FlS n_tops n_others n_lams -> (n_tops, n_others, n_lams)

-- | Every counter at zero; the unit of 'add_stats'.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0

-- | Pointwise sum of a list of statistics (zero for the empty list).
sum_stats :: [FloatStats] -> FloatStats
sum_stats = foldr add_stats zeroStats

-- | Pointwise sum of two sets of statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS x1 y1 z1) (FlS x2 y2 z2) = FlS (x1 + x2) (y1 + y2) (z1 + z2)

-- | Update the counters with a set of floats.  The number of top-level
-- floats is added to the first counter, and the number of all other floats
-- (across every level) to the second.  The third counter, lambda groups
-- seen, goes up by one.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS n_tops n_others n_lams) (FB tops others)
  = FlS n_tops' n_others' n_lams'
  where
    n_tops'   = n_tops   + lengthBag tops
    n_others' = n_others + lengthBag (flattenMajor others)
    n_lams'   = n_lams   + 1

{-
************************************************************************
*                                                                      *
\subsection{Utility bits for floating}
*                                                                      *
************************************************************************

Note [Representation of FloatBinds]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The FloatBinds type is somewhat important.  We can get very large numbers
of floating bindings, often all destined for the top level.  A typical example
is     x = [4,2,5,2,5, .... ]
Then we get lots of small expressions like (fromInteger 4), which all get
lifted to top level.

The trouble is that
  (a) we partition these floating bindings *at every binding site*
  (b) GHC.Core.Opt.SetLevels introduces a new binding site for every float
So we had better not look at each binding at each binding site!

That is why MajorEnv is represented as a finite map.

We keep the bindings destined for the *top* level separate, because
we float them out even if they don't escape a *value* lambda; see
partitionByMajorLevel.
-}

-- | A let-binding being floated.  This is the element type of the
-- top-level bag in 'FloatBinds' and the argument type of 'unitLetFloat'.
type FloatLet = CoreBind        -- INVARIANT: a FloatLet is always lifted
-- | Non-top-level floats, indexed first by major level and then (via
-- 'MinorEnv') by minor level.
type MajorEnv = M.IntMap MinorEnv         -- Keyed by major level
-- | Floats that share one major level, indexed by minor level.  Each entry
-- is a bag of 'FloatBind's (let-floats or case-floats).
type MinorEnv = M.IntMap (Bag FloatBind)  -- Keyed by minor level

-- | All the bindings floating out of an expression.  Top-level floats live
-- in their own bag, separate from the level-indexed map, because they are
-- floated out even when they do not escape a value lambda.
data FloatBinds  = FB !(Bag FloatLet)           -- Destined for top level
                      !MajorEnv                 -- Other levels
     -- See Note [Representation of FloatBinds]

-- | Debug rendering: the top-level bag followed by the level-indexed map.
instance Outputable FloatBinds where
  ppr (FB fbs defs)
    = text "FB" <+> braces (vcat [tops_doc, non_tops_doc])
    where
      tops_doc     = text "tops ="     <+> ppr fbs
      non_tops_doc = text "non-tops =" <+> ppr defs

-- | Return the top-level floats.  Uses 'assertPpr' to check that no
-- non-top-level floats are left over; if any are, they are printed.
flattenTopFloats :: FloatBinds -> Bag CoreBind
flattenTopFloats (FB tops defs)
  = assertPpr no_other_floats (ppr defs) tops
  where
    no_other_floats = isEmptyBag (flattenMajor defs)

-- | Prepend the bindings in the bag to the given (binder, rhs) pairs, in
-- bag order.  A NonRec becomes one pair; a Rec group contributes all of
-- its pairs.
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
addTopFloatPairs float_bag prs = foldr prepend prs float_bag
  where
    prepend bind rest = case bind of
      NonRec b r -> (b, r) : rest
      Rec pairs  -> pairs ++ rest

-- | All non-top-level floats as one bag, in ascending major-level order
-- (and, within a major level, ascending minor-level order).
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor major_env = M.foldr add_minor emptyBag major_env
  where
    add_minor minor_env acc = flattenMinor minor_env `unionBags` acc

-- | The floats of one major level as a single bag, in ascending
-- minor-level order.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor minor_env = M.foldr unionBags emptyBag minor_env

-- | No floats: an empty top-level bag and an empty level map.
emptyFloats :: FloatBinds
emptyFloats = FB no_tops no_others
  where
    no_tops   = emptyBag
    no_others = M.empty

-- | A single case-float, @case e of b { con bs -> ... }@, filed under the
-- given (major, minor) level.  It never goes into the top-level bag.
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
unitCaseFloat (Level major minor) e b con bs
  = FB emptyBag major_env
  where
    case_float = FloatCase e b con bs
    major_env  = M.singleton major (M.singleton minor (unitBag case_float))

-- | A single let-float.  If the level is the top level it goes into the
-- top-level bag; otherwise it is filed under its (major, minor) level.
unitLetFloat :: Level -> FloatLet -> FloatBinds
unitLetFloat lvl@(Level major minor) b
  = if isTopLvl lvl
      then FB (unitBag b) M.empty
      else FB emptyBag (M.singleton major minor_env)
  where
    minor_env = M.singleton minor (unitBag (FloatLet b))

-- | Combine two sets of floats.  Within each bag, the left argument's
-- floats come first.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB tops1 others1) (FB tops2 others2)
  = FB tops others
  where
    tops   = unionBags tops1 tops2
    others = plusMajor others1 others2

-- | Union of two level maps; where both have the same major level, their
-- minor maps are merged with 'plusMinor'.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor env1 env2 = M.unionWith plusMinor env1 env2

-- | Union of two minor-level maps.  Where both have the same minor level,
-- the bags are concatenated with the left argument's floats first.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor env1 env2 = M.unionWith unionBags env1 env2

-- | Wrap an expression in a bag of floats.  The first float in the bag
-- ends up outermost.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install floats body = foldr wrapFloat body floats

partitionByLevel
        :: Level                -- Partitioning level
        -> FloatBinds           -- Defns to be divided into 2 piles...
        -> (FloatBinds,         -- Defns with level strictly < partition level,
            Bag FloatBind)      -- The rest (level >= partition level)

{-
--       ---- partitionByMajorLevel ----
-- Float it if we escape a value lambda,
--     *or* if we get to the top level
--     *or* if it's a case-float and its minor level is < current
--
-- If we can get to the top level, say "yes" anyway. This means that
--      x = f e
-- transforms to
--    lvl = e
--    x = f lvl
-- which is as it should be

partitionByMajorLevel (Level major _) (FB tops defns)
  = (FB tops outer, heres `unionBags` flattenMajor inner)
  where
    (outer, mb_heres, inner) = M.splitLookup major defns
    heres = case mb_heres of
               Nothing -> emptyBag
               Just h  -> flattenMinor h
-}

-- Split the floats at the partitioning level (major, minor):
--
--  * The first component keeps floating.  It holds all top-level floats,
--    plus floats at a major level < major, plus floats at this major
--    level whose minor level is < minor.
--  * The second component is installed here.  Its bag holds floats at
--    exactly (major, minor), then higher minor levels within this major
--    level, then higher major levels, in that order.
partitionByLevel (Level major minor) (FB tops defns)
  = ( FB tops outer
    , here_min `unionBags` flattenMinor inner_min
               `unionBags` flattenMajor inner_maj )
  where
    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
    (outer_min, mb_here_min, inner_min) = case mb_here_maj of
        Nothing        -> (M.empty, Nothing, M.empty)
        Just min_defns -> M.splitLookup minor min_defns
    here_min = mb_here_min `orElse` emptyBag

    -- Every key of outer_maj is < major, so inserting at major cannot
    -- collide.  We only insert when outer_min is non-empty.  Otherwise we
    -- would leave an empty MinorEnv in the map at every binding site (see
    -- Note [Representation of FloatBinds]).
    outer | M.null outer_min = outer_maj
          | otherwise        = M.insert major outer_min outer_maj

-- | Floats escaping from inside a tick: wrap the tick around every
-- floated RHS, and around the scrutinee of every case-float, at every
-- level.
wrapTick :: CoreTickish -> FloatBinds -> FloatBinds
wrapTick t (FB tops defns)
  = FB tops' defns'
  where
    tops'  = mapBag wrap_bind tops
    defns' = M.map (M.map (mapBag wrap_one)) defns

    wrap_bind bind = case bind of
      NonRec binder rhs -> NonRec binder (maybe_tick rhs)
      Rec pairs         -> Rec (mapSnd maybe_tick pairs)

    wrap_one float = case float of
      FloatLet bind        -> FloatLet (wrap_bind bind)
      FloatCase e b con bs -> FloatCase (maybe_tick e) b con bs

    maybe_tick e = if exprIsHNF e then tickHNFArgs t e else mkTick t e
      -- we don't need to wrap a tick around an HNF when we float it
      -- outside a tick: that is an invariant of the tick semantics
      -- Conversely, inlining of HNFs inside an SCC is allowed, and
      -- indeed the HNF we're floating here might well be inlined back
      -- again, and we don't want to end up with duplicate ticks.