{-
(c) The GRASP/AQUA Project, Glasgow University, 1992-1998

\section[FloatOut]{Float bindings outwards (towards the top level)}

``Long-distance'' floating of bindings towards the top level.
-}



module GHC.Core.Opt.FloatOut ( floatOutwards ) where

import GHC.Prelude

import GHC.Core
import GHC.Core.Utils
import GHC.Core.Make
-- import GHC.Core.Opt.Arity ( exprArity, etaExpand )
import GHC.Core.Opt.Monad ( FloatOutSwitches(..) )

import GHC.Driver.Flags  ( DumpFlag (..) )
import GHC.Utils.Logger
import GHC.Types.Id      ( Id, idType,
--                           idArity, isDeadEndId,
                           isJoinId, idJoinPointHood )
import GHC.Types.Tickish
import GHC.Core.Opt.SetLevels
import GHC.Types.Unique.Supply ( UniqSupply )
import GHC.Data.Bag
import GHC.Utils.Misc
import GHC.Data.Maybe
import GHC.Utils.Outputable
import GHC.Utils.Panic
import GHC.Core.Type
import qualified Data.IntMap as M

import Data.List        ( partition )

{-
        -----------------
        Overall game plan
        -----------------

The Big Main Idea is:

        To float out sub-expressions that can thereby get outside
        a non-one-shot value lambda, and hence may be shared.


To achieve this we may need to do two things:

   a) Let-bind the sub-expression:

        f (g x)  ==>  let lvl = f (g x) in lvl

      Now we can float the binding for 'lvl'.

   b) More than that, we may need to abstract wrt a type variable

        \x -> ... /\a -> let v = ...a... in ....

      Here the binding for v mentions 'a' but not 'x'.  So we
      abstract wrt 'a', to give this binding for 'v':

            vp = /\a -> ...a...
            v  = vp a

      Now the binding for vp can float out unimpeded.
      I can't remember why this case seemed important enough to
      deal with, but I certainly found cases where important floats
      didn't happen if we did not abstract wrt tyvars.

With this in mind we can also achieve another goal: lambda lifting.
We can make an arbitrary (function) binding float to top level by
abstracting wrt *all* local variables, not just type variables, leaving
a binding that can be floated right to top level.  Whether or not this
happens is controlled by a flag.


Random comments
~~~~~~~~~~~~~~~

At the moment we never float a binding out to between two adjacent
lambdas.  For example:

@
        \x y -> let t = x+x in ...
===>
        \x -> let t = x+x in \y -> ...
@
Reason: this is less efficient in the case where the original lambda
is never partially applied.

But there's a case I've seen where this might not be true.  Consider:
@
elEm2 x ys
  = elem' x ys
  where
    elem' _ []  = False
    elem' x (y:ys)      = x==y || elem' x ys
@
It turns out that this generates a subexpression of the form
@
        \deq x ys -> let eq = eqFromEqDict deq in ...
@
which might usefully be separated to
@
        \deq -> let eq = eqFromEqDict deq in \x ys -> ...
@
Well, maybe.  We don't do this at the moment.


************************************************************************
*                                                                      *
\subsection[floatOutwards]{@floatOutwards@: let-floating interface function}
*                                                                      *
************************************************************************
-}

-- | Interface function of the float-out pass.
--
-- The program is first annotated with levels by 'setLevels'; each annotated
-- top-level binding is then processed by 'floatTopBind', and the resulting
-- bags of bindings are concatenated, in order, into the output program.
--
-- Two dumps are handed to 'putDumpFileMaybe' along the way: the levelled
-- program (under 'Opt_D_verbose_core2core') and the accumulated floating
-- statistics (under 'Opt_D_dump_simpl_stats').
floatOutwards :: Logger
              -> FloatOutSwitches
              -> UniqSupply
              -> CoreProgram -> IO CoreProgram
floatOutwards logger float_sws us pgm
  = do {
        let { annotated_w_levels = setLevels float_sws pgm us ;
              -- One (stats, bindings) pair per top-level binding
              (fss, binds_s')    = unzip (map floatTopBind annotated_w_levels)
            } ;

        putDumpFileMaybe logger Opt_D_verbose_core2core "Levels added:"
                  FormatCore
                  (vcat (map ppr annotated_w_levels));

        let { (tlets, ntlets, lams) = get_stats (sum_stats fss) };

        putDumpFileMaybe logger Opt_D_dump_simpl_stats "FloatOut stats:"
                FormatText
                (hcat [ int tlets,  text " Lets floated to top level; ",
                        int ntlets, text " Lets floated elsewhere; from ",
                        int lams,   text " Lambda groups"]);

        return (bagToList (unionManyBags binds_s'))
    }

-- | Float a single top-level binding.
--
-- Runs 'floatBind' on the binding and flattens the floats it produced with
-- 'flattenTopFloats'.  For a recursive group the floats are merged into the
-- group by 'addTopFloatPairs'; for a non-recursive binding the binding is
-- added to the bag of floats with 'snocBag'.  Any other shape of result from
-- 'floatBind' is reported with 'pprPanic'.
floatTopBind :: LevelledBind -> (FloatStats, Bag CoreBind)
floatTopBind bind
  = case (floatBind bind) of { (fs, floats, bind') ->
    let float_bag = flattenTopFloats floats
    in case bind' of
      -- bind' can't have unlifted values or join points, so can only be one
      -- value bind, rec or non-rec (see comment on floatBind)
      [Rec prs]    -> (fs, unitBag (Rec (addTopFloatPairs float_bag prs)))
      [NonRec b e] -> (fs, float_bag `snocBag` NonRec b e)
      _            -> pprPanic "floatTopBind" (ppr bind') }

{-
************************************************************************
*                                                                      *
\subsection[FloatOut-Bind]{Floating in a binding (the business end)}
*                                                                      *
************************************************************************
-}

floatBind :: LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
  -- Float the right-hand side(s) of one binding.
  --
  -- Returns a list with either
  --   * A single non-recursive binding (value or join point), or
  --   * The following, in order:
  --     * Zero or more non-rec unlifted bindings
  --     * One or both of:
  --       * A recursive group of join binds
  --       * A recursive group of value binds
  -- See Note [Floating out of Rec rhss] for why things get arranged this way.
floatBind (NonRec (TB var _) rhs)
  = case (floatRhs var rhs) of { (fs, rhs_floats, rhs') ->
      (fs, rhs_floats, [NonRec var rhs']) }

floatBind (Rec pairs)
  = case floatList do_pair pairs of { (fs, rhs_floats, new_pairs) ->
    let (new_ul_pairss, new_other_pairss) = unzip new_pairs
        (new_join_pairs, new_l_pairs)     = partition (isJoinId . fst)
                                                      (concat new_other_pairss)
        -- Can't put the join points and the values in the same rec group
        new_rec_binds | null new_join_pairs = [ Rec new_l_pairs    ]
                      | null new_l_pairs    = [ Rec new_join_pairs ]
                      | otherwise           = [ Rec new_l_pairs
                                              , Rec new_join_pairs ]
        -- The unlifted bindings stay non-recursive and come first
        new_non_rec_binds = [ NonRec b e | (b, e) <- concat new_ul_pairss ]
    in
    (fs, rhs_floats, new_non_rec_binds ++ new_rec_binds) }
  where
    -- Process one (binder, rhs) pair of the recursive group
    do_pair :: (LevelledBndr, LevelledExpr)
            -> (FloatStats, FloatBinds,
                ([(Id,CoreExpr)],  -- Non-recursive unlifted value bindings
                 [(Id,CoreExpr)])) -- Join points and lifted value bindings
    do_pair (TB name spec, rhs)
      | isTopLvl dest_lvl  -- See Note [floatBind for top level]
      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
        -- All the floats are absorbed into the pairs; nothing floats further
        (fs, emptyFloats, ([], addTopFloatPairs (flattenTopFloats rhs_floats)
                                                [(name, rhs')]))}
      | otherwise         -- Note [Floating out of Rec rhss]
      = case (floatRhs name rhs) of { (fs, rhs_floats, rhs') ->
        case (partitionByLevel dest_lvl rhs_floats) of { (rhs_floats', heres) ->
        case (splitRecFloats heres) of { (ul_pairs, pairs, case_heres) ->
        -- The tail of floats that could not join the group is pushed back
        -- inside the lambdas of this rhs
        let pairs' = (name, installUnderLambdas case_heres rhs') : pairs in
        (fs, rhs_floats', (ul_pairs, pairs')) }}}
      where
        dest_lvl = floatSpecLevel spec

-- | Split the floats destined for a recursive group into the leading
-- let-bindings (which can be turned into pairs) and whatever follows them.
splitRecFloats :: Bag FloatBind
               -> ([(Id,CoreExpr)], -- Non-recursive unlifted value bindings
                   [(Id,CoreExpr)], -- Join points and lifted value bindings
                   Bag FloatBind)   -- A tail of further bindings
-- The "tail" begins with a case
-- See Note [Floating out of Rec rhss]
splitRecFloats fs
  = go [] [] (bagToList fs)
  where
    -- Both accumulators are built by consing; only the unlifted one is
    -- reversed at the end, to restore the original order
    go ul_prs prs (FloatLet (NonRec b r) : fs) | isUnliftedType (idType b)
                                               -- NB: isUnliftedType is OK here as binders always
                                               -- have a fixed RuntimeRep.
                                               , not (isJoinId b)
                                               = go ((b,r):ul_prs) prs fs
                                               | otherwise
                                               = go ul_prs ((b,r):prs) fs
    go ul_prs prs (FloatLet (Rec prs')   : fs) = go ul_prs (prs' ++ prs) fs
    -- First float that is not a FloatLet (or end of list): stop here
    go ul_prs prs fs                           = (reverse ul_prs, prs,
                                                  listToBag fs)
                                                   -- Order only matters for
                                                   -- non-rec

-- | Wrap the given floats around the body of an expression, underneath any
-- leading lambdas, using 'install'.  When there are no floats the expression
-- is returned unchanged.
installUnderLambdas :: Bag FloatBind -> CoreExpr -> CoreExpr
-- See Note [Floating out of Rec rhss]
installUnderLambdas floats e
  | isEmptyBag floats = e
  | otherwise         = go e
  where
    -- Walk past each leading Lam, then install the floats on what is left
    go (Lam b e)                 = Lam b (go e)
    go e                         = install floats e

---------------
-- | Apply a floating function to every element of a list, combining the
-- statistics with 'add_stats' and the floats with 'plusFloats', and
-- collecting the per-element results in order.  The empty list yields
-- 'zeroStats' and 'emptyFloats'.
floatList :: (a -> (FloatStats, FloatBinds, b)) -> [a] -> (FloatStats, FloatBinds, [b])
floatList _ [] = (zeroStats, emptyFloats, [])
floatList f (a:as) = case f a            of { (fs_a,  binds_a,  b)  ->
                     case floatList f as of { (fs_as, binds_as, bs) ->
                     (fs_a `add_stats` fs_as, binds_a `plusFloats`  binds_as, b:bs) }}

{-
Note [Floating out of Rec rhss]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Consider   Rec { f<1,0> = \xy. body }
From the body we may get some floats. The ones with level <1,0> must
stay here, since they may mention f.  Ideally we'd like to make them
part of the Rec block pairs -- but we can't if there are any
FloatCases involved.

Nor is it a good idea to dump them in the rhs, but outside the lambda
    f = case x of I# y -> \xy. body
because now f's arity might get worse, which is Not Good. (And if
there's an SCC around the RHS it might not get better again.
See #5342.)

So, gruesomely, we split the floats into
 * the outer FloatLets, which can join the Rec, and
 * an inner batch starting in a FloatCase, which are then
   pushed *inside* the lambdas.
This loses full-laziness in the rare situation where there is a
FloatCase and a Rec interacting.

If there are unlifted FloatLets (that *aren't* join points) among the floats,
we can't add them to the recursive group without angering Core Lint, but since
they must be ok-for-speculation, they can't actually be making any recursive
calls, so we can safely pull them out and keep them non-recursive.

(Why is something getting floated to <1,0> that doesn't make a recursive call?
The case that came up in testing was that f *and* the unlifted binding were
getting floated *to the same place*:

  \x<2,0> ->
    ... <3,0>
    letrec { f<F<2,0>> =
      ... let x'<F<2,0>> = x +# 1# in ...
    } in ...

Everything gets labeled "float to <2,0>" because it all depends on x, but this
makes f and x' look mutually recursive when they're not.

The test was shootout/k-nucleotide, as compiled using commit 47d5dd68 on the
wip/join-points branch.

TODO: This can probably be solved somehow in GHC.Core.Opt.SetLevels. The difference between
"this *is at* level <2,0>" and "this *depends on* level <2,0>" is very
important.)

Note [floatBind for top level]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We may have a *nested* binding whose destination level is (FloatMe tOP_LEVEL), thus
         letrec { foo <0,0> = .... (let bar<0,0> = .. in ..) .... }
The binding for bar will be in the "tops" part of the floating binds,
and thus not partitioned by floatBody.

We could perhaps get rid of the 'tops' component of the floating binds,
but this case works just as well.


************************************************************************
*                                                                      *
\subsection[FloatOut-Expr]{Floating in expressions}
*                                                                      *
************************************************************************
-}

-- | Float an expression with 'floatExpr', then use 'partitionByLevel' to
-- split the resulting floats at the given level: the ones to be bound here
-- are wrapped around the expression with 'install', and the remainder is
-- returned to the caller.
floatBody :: Level
          -> LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)

floatBody lvl arg       -- Used for rec rhss, and case-alternative rhss
  = case (floatExpr arg) of { (fsa, floats, arg') ->
    case (partitionByLevel lvl floats) of { (floats', heres) ->
        -- Dump bindings are bound here
    (fsa, floats', install heres arg') }}

-----------------

{- Note [Floating past breakpoints]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
We used to disallow floating out of breakpoint ticks (see #10052). However, I
think this is too restrictive.

Consider the case of an expression scoped over by a breakpoint tick,

  tick<...> (let x = ... in f x)

In this case it is completely legal to float out x, despite the fact that
breakpoint ticks are scoped,

  let x = ... in (tick<...>  f x)

The reason here is that we know that the breakpoint will still be hit when the
expression is entered since the tick still scopes over the RHS.

-}

floatExpr :: LevelledExpr
          -> (FloatStats, FloatBinds, CoreExpr)
floatExpr :: Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Var)
floatExpr (Var Var
v)   = (FloatStats
zeroStats, FloatBinds
emptyFloats, Var -> Expr Var
forall b. Var -> Expr b
Var Var
v)
floatExpr (Type Type
ty) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Type -> Expr Var
forall b. Type -> Expr b
Type Type
ty)
floatExpr (Coercion Coercion
co) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Coercion -> Expr Var
forall b. Coercion -> Expr b
Coercion Coercion
co)
floatExpr (Lit Literal
lit) = (FloatStats
zeroStats, FloatBinds
emptyFloats, Literal -> Expr Var
forall b. Literal -> Expr b
Lit Literal
lit)

floatExpr (App Expr (TaggedBndr FloatSpec)
e Expr (TaggedBndr FloatSpec)
a)
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Var)
floatExpr  Expr (TaggedBndr FloatSpec)
e) of { (FloatStats
fse, FloatBinds
floats_e, Expr Var
e') ->
    case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Var)
floatExpr  Expr (TaggedBndr FloatSpec)
a) of { (FloatStats
fsa, FloatBinds
floats_a, Expr Var
a') ->
    (FloatStats
fse FloatStats -> FloatStats -> FloatStats
`add_stats` FloatStats
fsa, FloatBinds
floats_e FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
floats_a, Expr Var -> Expr Var -> Expr Var
forall b. Expr b -> Expr b -> Expr b
App Expr Var
e' Expr Var
a') }}

floatExpr lam :: Expr (TaggedBndr FloatSpec)
lam@(Lam (TB Var
_ FloatSpec
lam_spec) Expr (TaggedBndr FloatSpec)
_)
  = let ([TaggedBndr FloatSpec]
bndrs_w_lvls, Expr (TaggedBndr FloatSpec)
body) = Expr (TaggedBndr FloatSpec)
-> ([TaggedBndr FloatSpec], Expr (TaggedBndr FloatSpec))
forall b. Expr b -> ([b], Expr b)
collectBinders Expr (TaggedBndr FloatSpec)
lam
        bndrs :: [Var]
bndrs                = [Var
b | TB Var
b FloatSpec
_ <- [TaggedBndr FloatSpec]
bndrs_w_lvls]
        bndr_lvl :: Level
bndr_lvl             = FloatSpec -> Level
floatSpecLevel FloatSpec
lam_spec
        -- All the binders have the same level
        -- See GHC.Core.Opt.SetLevels.lvlLamBndrs
        -- Use asJoinCeilLvl to make this the join ceiling
    in
    case (Level
-> Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr Var)
floatBody Level
bndr_lvl Expr (TaggedBndr FloatSpec)
body) of { (FloatStats
fs, FloatBinds
floats, Expr Var
body') ->
    (FloatStats -> FloatBinds -> FloatStats
add_to_stats FloatStats
fs FloatBinds
floats, FloatBinds
floats, [Var] -> Expr Var -> Expr Var
forall b. [b] -> Expr b -> Expr b
mkLams [Var]
bndrs Expr Var
body') }

floatExpr (Tick CoreTickish
tickish Expr (TaggedBndr FloatSpec)
expr)
  -- If possible, float out past the tick
  | let float_out_of_tick :: Bool
float_out_of_tick
          -- See Note [Floating past breakpoints]
          | Breakpoint{} <- CoreTickish
tickish
          = Bool
True
          | Bool
otherwise
          -- We can float code out of non-scoped ticks
          = CoreTickish -> Bool
forall (pass :: TickishPass). GenTickish pass -> Bool
tickishHasNoScope CoreTickish
tickish
  , Bool
float_out_of_tick
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Var)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (FloatStats
fs, FloatBinds
floating_defns, Expr Var
expr') ->
    (FloatStats
fs, FloatBinds
floating_defns, CoreTickish -> Expr Var -> Expr Var
forall b. CoreTickish -> Expr b -> Expr b
Tick CoreTickish
tickish Expr Var
expr') }

  -- We can't move code out of the tick
  | Bool
otherwise
  = Bool
-> (FloatStats, FloatBinds, Expr Var)
-> (FloatStats, FloatBinds, Expr Var)
forall a. HasCallStack => Bool -> a -> a
assert (Bool -> Bool
not (CoreTickish -> Bool
forall (pass :: TickishPass). GenTickish pass -> Bool
tickishCounts CoreTickish
tickish) Bool -> Bool -> Bool
|| CoreTickish -> Bool
forall (pass :: TickishPass). GenTickish pass -> Bool
tickishCanSplit CoreTickish
tickish) ((FloatStats, FloatBinds, Expr Var)
 -> (FloatStats, FloatBinds, Expr Var))
-> (FloatStats, FloatBinds, Expr Var)
-> (FloatStats, FloatBinds, Expr Var)
forall a b. (a -> b) -> a -> b
$
    case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Var)
floatExpr Expr (TaggedBndr FloatSpec)
expr)    of { (FloatStats
fs, FloatBinds
floating_defns, Expr Var
expr') ->
        -- Wrap floated code with the correct tick scope, but using 'mkNoCount'
        -- to ensure we don't duplicate counters.
    let annotated_defns :: FloatBinds
annotated_defns = CoreTickish -> FloatBinds -> FloatBinds
wrapTick (CoreTickish -> CoreTickish
forall (pass :: TickishPass). GenTickish pass -> GenTickish pass
mkNoCount CoreTickish
tickish) FloatBinds
floating_defns
    in
    (FloatStats
fs, FloatBinds
annotated_defns, CoreTickish -> Expr Var -> Expr Var
forall b. CoreTickish -> Expr b -> Expr b
Tick CoreTickish
tickish Expr Var
expr') }


floatExpr (Cast Expr (TaggedBndr FloatSpec)
expr Coercion
co)
  = case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Var)
floatExpr Expr (TaggedBndr FloatSpec)
expr) of { (FloatStats
fs, FloatBinds
floating_defns, Expr Var
expr') ->
    (FloatStats
fs, FloatBinds
floating_defns, Expr Var -> Coercion -> Expr Var
forall b. Expr b -> Coercion -> Expr b
Cast Expr Var
expr' Coercion
co) }

floatExpr (Let LevelledBind
bind Expr (TaggedBndr FloatSpec)
body)
  = case FloatSpec
bind_spec of
      FloatMe Level
dest_lvl
        -> case (LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
floatBind LevelledBind
bind) of { (FloatStats
fsb, FloatBinds
bind_floats, [CoreBind]
binds') ->
           case (Expr (TaggedBndr FloatSpec) -> (FloatStats, FloatBinds, Expr Var)
floatExpr Expr (TaggedBndr FloatSpec)
body) of { (FloatStats
fse, FloatBinds
body_floats, Expr Var
body') ->
           let new_bind_floats :: FloatBinds
new_bind_floats = (FloatBinds -> FloatBinds -> FloatBinds)
-> FloatBinds -> [FloatBinds] -> FloatBinds
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr FloatBinds -> FloatBinds -> FloatBinds
plusFloats FloatBinds
emptyFloats
                                   ((CoreBind -> FloatBinds) -> [CoreBind] -> [FloatBinds]
forall a b. (a -> b) -> [a] -> [b]
map (Level -> CoreBind -> FloatBinds
unitLetFloat Level
dest_lvl) [CoreBind]
binds') in
           ( FloatStats -> FloatStats -> FloatStats
add_stats FloatStats
fsb FloatStats
fse
           , FloatBinds
bind_floats FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
new_bind_floats
                         FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
body_floats
           , Expr Var
body') }}

      StayPut Level
bind_lvl  -- See Note [Avoiding unnecessary floating]
        -> case (LevelledBind -> (FloatStats, FloatBinds, [CoreBind])
floatBind LevelledBind
bind)          of { (FloatStats
fsb, FloatBinds
bind_floats, [CoreBind]
binds') ->
           case (Level
-> Expr (TaggedBndr FloatSpec)
-> (FloatStats, FloatBinds, Expr Var)
floatBody Level
bind_lvl Expr (TaggedBndr FloatSpec)
body) of { (FloatStats
fse, FloatBinds
body_floats, Expr Var
body') ->
           ( FloatStats -> FloatStats -> FloatStats
add_stats FloatStats
fsb FloatStats
fse
           , FloatBinds
bind_floats FloatBinds -> FloatBinds -> FloatBinds
`plusFloats` FloatBinds
body_floats
           , (CoreBind -> Expr Var -> Expr Var)
-> Expr Var -> [CoreBind] -> Expr Var
forall a b. (a -> b -> b) -> b -> [a] -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr CoreBind -> Expr Var -> Expr Var
CoreBind -> Expr Var -> Expr Var
forall b. Bind b -> Expr b -> Expr b
Let Expr Var
body' [CoreBind]
binds' ) }}
  where
    bind_spec :: FloatSpec
bind_spec = case LevelledBind
bind of
                 NonRec (TB Var
_ FloatSpec
s) Expr (TaggedBndr FloatSpec)
_     -> FloatSpec
s
                 Rec ((TB Var
_ FloatSpec
s, Expr (TaggedBndr FloatSpec)
_) : [(TaggedBndr FloatSpec, Expr (TaggedBndr FloatSpec))]
_) -> FloatSpec
s
                 Rec []                -> String -> FloatSpec
forall a. HasCallStack => String -> a
panic String
"floatExpr:rec"

-- Case expressions.  The float spec attached to the case binder decides
-- whether the whole case floats away (FloatMe) or stays here (StayPut).
floatExpr (Case scrut (TB case_bndr case_spec) ty alts)
  = case case_spec of
      FloatMe dest_lvl  -- Case expression moves
        -- Only a case with exactly one data-constructor alternative can
        -- be turned into a float; anything else is a panic.
        | [Alt con@(DataAlt {}) bndrs rhs] <- alts
        -> case floatExpr scrut of { (fse, fde, scrut') ->
           case floatExpr rhs   of { (fsb, fdb, rhs') ->
           let
             -- The case itself becomes a single case-float at dest_lvl
             float = unitCaseFloat dest_lvl scrut'
                          case_bndr con [b | TB b _ <- bndrs]
           in
           -- The result expression is just the (floated) alternative body;
           -- the floats are combined as scrutinee, case-float, body.
           (add_stats fse fsb, fde `plusFloats` float `plusFloats` fdb, rhs') }}
        | otherwise
        -> pprPanic "Floating multi-case" (ppr alts)

      StayPut bind_lvl  -- Case expression stays put
        -> case floatExpr scrut of { (fse, fde, scrut') ->
           case floatList (float_alt bind_lvl) alts of { (fsa, fda, alts')  ->
           (add_stats fse fsa, fda `plusFloats` fde, Case scrut' case_bndr ty alts')
           }}
  where
    -- Process one alternative with floatBody at the case's own level,
    -- and strip the float specs from its binders.
    float_alt bind_lvl (Alt con bs rhs)
        = case (floatBody bind_lvl rhs) of { (fs, rhs_floats, rhs') ->
          (fs, rhs_floats, Alt con [b | TB b _ <- bs] rhs') }

-- | Float the right-hand side of a binding.
--
-- When the binder is a join point (per 'idJoinPointHood') and the RHS has
-- at least @join_arity@ leading lambdas, those lambdas are peeled off, the
-- body is processed with 'floatBody' at the level recorded on the first
-- lambda binder, and the lambdas are put back with 'mkLams'.  In every
-- other case the RHS is simply handed to 'floatExpr'.
floatRhs :: CoreBndr
         -> LevelledExpr
         -> (FloatStats, FloatBinds, CoreExpr)
floatRhs bndr rhs
  | JoinPoint join_arity <- idJoinPointHood bndr
  , Just (bndrs, body) <- try_collect join_arity rhs []
  = case bndrs of
      []                -> floatExpr rhs   -- nullary join point: no lambdas
      (TB _ lam_spec):_ ->
        let lvl = floatSpecLevel lam_spec in
        case floatBody lvl body of { (fs, floats, body') ->
        (fs, floats, mkLams [b | TB b _ <- bndrs] body') }
  | otherwise
  = floatExpr rhs
  where
    -- Collect exactly n leading lambda binders (in source order) together
    -- with the remaining body; Nothing if there are fewer than n lambdas.
    try_collect :: Int -> Expr b -> [b] -> Maybe ([b], Expr b)
    try_collect 0 expr      acc = Just (reverse acc, expr)
    try_collect n (Lam b e) acc = try_collect (n-1) e (b:acc)
    try_collect _ _         _   = Nothing

{-
Note [Avoiding unnecessary floating]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
In general we want to avoid floating a let unnecessarily, because
it might worsen strictness:
    let
       x = ...(let y = e in y+y)....
Here y is demanded.  If we float it outside the lazy 'x=..' then
we'd have to zap its demand info, and it may never be restored.

So at a 'let' we leave the bindings right where they are unless
the binding will escape a value lambda, e.g.

(\x -> let y = fac 100 in y)

That's what the partitionByMajorLevel does in the floatExpr (Let ...)
case.

Notice, though, that we must take care to drop any bindings
from the body of the let that depend on the staying-put bindings.

We used instead to do the partitionByMajorLevel on the RHS of an '=',
in floatRhs.  But that was quite tiresome.  We needed to test for
values or trivial rhss, because (in particular) we don't want to insert
new bindings between the "=" and the "\".  E.g.
        f = \x -> let <bind> in <body>
We do not want
        f = let <bind> in \x -> <body>
(a) The simplifier will immediately float it further out, so we may
        as well do so right now; in general, keeping rhss as manifest
        values is good
(b) If a float-in pass follows immediately, it might add yet more
        bindings just after the '='.  And some of them might (correctly)
        be strict even though the 'let f' is lazy, because f, being a value,
        gets its demand-info zapped by the simplifier.
And even all that turned out to be very fragile, and broke
altogether when profiling got in the way.

So now we do the partition right at the (Let..) itself.

************************************************************************
*                                                                      *
\subsection{Utility bits for floating stats}
*                                                                      *
************************************************************************

I didn't implement this with unboxed numbers.  I don't want to be too
strict in this stuff, as it is rarely turned on.  (WDP 95/09)
-}

-- | Counters gathered while floating.  Values are combined
-- component-wise by 'add_stats'; 'zeroStats' is the all-zero value.
data FloatStats
  = FlS Int  -- Number of top-floats * lambda groups they've been past
        Int  -- Number of non-top-floats * lambda groups they've been past
        Int  -- Number of lambda (groups) seen

-- | Unpack the three counters of a 'FloatStats' as a tuple, in field order.
get_stats :: FloatStats -> (Int, Int, Int)
get_stats (FlS a b c) = (a, b, c)

-- | Statistics with every counter at zero.
zeroStats :: FloatStats
zeroStats = FlS 0 0 0

-- | Add up a list of statistics with 'add_stats', starting from 'zeroStats'.
sum_stats :: [FloatStats] -> FloatStats
sum_stats xs = foldr add_stats zeroStats xs

-- | Component-wise sum of two sets of statistics.
add_stats :: FloatStats -> FloatStats -> FloatStats
add_stats (FlS a1 b1 c1) (FlS a2 b2 c2)
  = FlS (a1 + a2) (b1 + b2) (c1 + c2)

-- | Account for a batch of floats in the statistics: the first counter
-- grows by the number of top-level bindings in the 'FloatBinds', the
-- second by the number of all its other floats, and the third by one.
add_to_stats :: FloatStats -> FloatBinds -> FloatStats
add_to_stats (FlS a b c) (FB tops others)
  = FlS (a + lengthBag tops)
        (b + lengthBag (flattenMajor others))
        (c + 1)

{-
************************************************************************
*                                                                      *
\subsection{Utility bits for floating}
*                                                                      *
************************************************************************

Note [Representation of FloatBinds]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The FloatBinds type is somewhat important.  We can get very large numbers
of floating bindings, often all destined for the top level.  A typical example
is     x = [4,2,5,2,5, .... ]
Then we get lots of small expressions like (fromInteger 4), which all get
lifted to top level.

The trouble is that
  (a) we partition these floating bindings *at every binding site*
  (b) GHC.Core.Opt.SetLevels introduces a new binding site for every float
So we had better not look at each binding at each binding site!

That is why MajorEnv is represented as a finite map.

We keep the bindings destined for the *top* level separate, because
we float them out even if they don't escape a *value* lambda; see
partitionByMajorLevel.
-}

-- Non-top-level floats live in a two-level finite map: major level first,
-- then minor level, ending in a bag of floats for that exact level.
type FloatLet = CoreBind        -- INVARIANT: a FloatLet is always lifted
type MajorEnv = M.IntMap MinorEnv         -- Keyed by major level
type MinorEnv = M.IntMap (Bag FloatBind)  -- Keyed by minor level

-- | The floats collected so far; both fields are strict.
data FloatBinds  = FB !(Bag FloatLet)           -- Destined for top level
                      !MajorEnv                 -- Other levels
     -- See Note [Representation of FloatBinds]

-- Debug printing: shows the top-level bag and the per-level map
-- inside braces, one per line.
instance Outputable FloatBinds where
  ppr (FB fbs defs)
      = text "FB" <+> (braces $ vcat
           [ text "tops ="     <+> ppr fbs
           , text "non-tops =" <+> ppr defs ])

-- | Extract the bindings destined for top level.  Asserts (via
-- 'assertPpr') that no floats for any other level are left, printing
-- the leftover map if the assertion fails.
flattenTopFloats :: FloatBinds -> Bag CoreBind
flattenTopFloats (FB tops defs)
  = assertPpr (isEmptyBag (flattenMajor defs)) (ppr defs) $
    tops

-- | Flatten a bag of bindings into (binder, rhs) pairs and put them in
-- front of the given list.  A 'NonRec' contributes one pair; a 'Rec'
-- contributes all of its pairs.
addTopFloatPairs :: Bag CoreBind -> [(Id,CoreExpr)] -> [(Id,CoreExpr)]
addTopFloatPairs float_bag prs
  = foldr add prs float_bag
  where
    add (NonRec b r) acc = (b,r):acc
    add (Rec prs1)   acc = prs1 ++ acc

-- | Collapse the whole major/minor level map into a single bag of floats.
flattenMajor :: MajorEnv -> Bag FloatBind
flattenMajor = M.foldr (unionBags . flattenMinor) emptyBag

-- | Collapse all minor levels of one major level into a single bag.
flattenMinor :: MinorEnv -> Bag FloatBind
flattenMinor = M.foldr unionBags emptyBag

-- | No floats at all: empty top-level bag, empty level map.
emptyFloats :: FloatBinds
emptyFloats = FB emptyBag M.empty

-- | A 'FloatBinds' holding exactly one case-float, filed under the given
-- major/minor level.  Case floats never go into the top-level bag here.
unitCaseFloat :: Level -> CoreExpr -> Id -> AltCon -> [Var] -> FloatBinds
unitCaseFloat (Level major minor) e b con bs
  = FB emptyBag (M.singleton major (M.singleton minor floats))
  where
    floats = unitBag (FloatCase e b con bs)

-- | A 'FloatBinds' holding exactly one let-float.  If the destination is
-- top level ('isTopLvl') the binding goes into the top-level bag;
-- otherwise it is filed under its major/minor level.
unitLetFloat :: Level -> FloatLet -> FloatBinds
unitLetFloat lvl@(Level major minor) b
  | isTopLvl lvl = FB (unitBag b) M.empty
  | otherwise    = FB emptyBag (M.singleton major (M.singleton minor floats))
  where
    floats = unitBag (FloatLet b)

-- | Combine two sets of floats: union of the top-level bags, and
-- level-wise union ('plusMajor') of the rest.
plusFloats :: FloatBinds -> FloatBinds -> FloatBinds
plusFloats (FB t1 l1) (FB t2 l2)
  = FB (t1 `unionBags` t2) (l1 `plusMajor` l2)

-- | Union of two major-level maps; entries sharing a major level are
-- merged with 'plusMinor'.
plusMajor :: MajorEnv -> MajorEnv -> MajorEnv
plusMajor = M.unionWith plusMinor

-- | Union of two minor-level maps; bags sharing a minor level are
-- merged with 'unionBags'.
plusMinor :: MinorEnv -> MinorEnv -> MinorEnv
plusMinor = M.unionWith unionBags

-- | Wrap an expression in the given floats using 'wrapFloat'.  This is a
-- right fold, so the first float in the bag ends up outermost.
install :: Bag FloatBind -> CoreExpr -> CoreExpr
install defn_groups expr
  = foldr wrapFloat expr defn_groups

partitionByLevel
        :: Level                -- Partitioning level
        -> FloatBinds           -- Defns to be divided into 2 piles...
        -> (FloatBinds,         -- Defns  with level strictly < partition level,
            Bag FloatBind)      -- The rest

{-
--       ---- partitionByMajorLevel ----
-- Float it if we escape a value lambda,
--     *or* if we get to the top level
--     *or* if it's a case-float and its minor level is < current
--
-- If we can get to the top level, say "yes" anyway. This means that
--      x = f e
-- transforms to
--    lvl = e
--    x = f lvl
-- which is as it should be

partitionByMajorLevel (Level major _) (FB tops defns)
  = (FB tops outer, heres `unionBags` flattenMajor inner)
  where
    (outer, mb_heres, inner) = M.splitLookup major defns
    heres = case mb_heres of
               Nothing -> emptyBag
               Just h  -> flattenMinor h
-}

-- First component: the top-level floats, plus everything at a strictly
-- smaller major level, plus (within the same major level) everything at
-- a strictly smaller minor level.  Second component: all remaining
-- floats, flattened -- those exactly at this level, then larger minor
-- levels of this major level, then all larger major levels.
partitionByLevel (Level major minor) (FB tops defns)
  = (FB tops (outer_maj `plusMajor` M.singleton major outer_min),
     here_min `unionBags` flattenMinor inner_min
              `unionBags` flattenMajor inner_maj)

  where
    -- Split first on the major level, then (if present) on the minor one
    (outer_maj, mb_here_maj, inner_maj) = M.splitLookup major defns
    (outer_min, mb_here_min, inner_min) = case mb_here_maj of
                                            Nothing -> (M.empty, Nothing, M.empty)
                                            Just min_defns -> M.splitLookup minor min_defns
    here_min = mb_here_min `orElse` emptyBag

-- | Attach a tick to the right-hand side of every floated let-binding and
-- to the scrutinee of every floated case, at top level and at all other
-- levels alike.  Asserts that the tick is not a counting one
-- ('tickishCounts').
wrapTick :: CoreTickish -> FloatBinds -> FloatBinds
wrapTick t (FB tops defns)
  = assert (not $ tickishCounts t) $
    FB (mapBag wrap_bind tops)
       (M.map (M.map wrap_defns) defns)
  where
    wrap_defns :: Bag FloatBind -> Bag FloatBind
    wrap_defns = mapBag wrap_one

    wrap_bind :: CoreBind -> CoreBind
    wrap_bind (NonRec binder rhs) = NonRec binder (maybe_tick rhs)
    wrap_bind (Rec pairs)         = Rec (mapSnd maybe_tick pairs)

    wrap_one :: FloatBind -> FloatBind
    wrap_one (FloatLet bind)      = FloatLet (wrap_bind bind)
    wrap_one (FloatCase e b c bs) = FloatCase (maybe_tick e) b c bs

    maybe_tick :: CoreExpr -> CoreExpr
    maybe_tick
      -- We don't need to wrap an SCC tick around HNFs that we floated out of
      -- the SCC, as that is an invariant of the semantics for SCCs.
      -- Conversely, inlining of HNFs inside an SCC is allowed, and
      -- indeed the HNF we're floating here might well be inlined back
      -- again, and we don't want to end up with duplicate ticks.
      | tickishPlace t == PlaceCostCentre
      = mkTickNoHNF t
      | otherwise
      = mkTick t