{-# LANGUAGE DerivingStrategies #-}

module GHC.CmmToAsm.Reg.Regs (
        Regs(..),
        noRegs,
        addRegMaxFmt, addRegsMaxFmt,
        mkRegsMaxFmt,
        minusCoveredRegs,
        minusRegs,
        unionRegsMaxFmt,
        unionManyRegsMaxFmt,
        intersectRegsMaxFmt,
        shrinkingRegs,
        mapRegs,
        elemRegs, lookupReg,

  ) where

import GHC.Prelude

import GHC.Platform.Reg     ( Reg )
import GHC.CmmToAsm.Format  ( Format, RegWithFormat(..), isVecFormat )

import GHC.Utils.Outputable ( Outputable )
import GHC.Types.Unique     ( Uniquable(..) )
import GHC.Types.Unique.Set

import Data.Coerce ( coerce )

-----------------------------------------------------------------------------

-- | A set of registers, each tagged with a format. Used mostly by register
-- liveness analysis; see Note [Register formats in liveness analysis] in
-- GHC.CmmToAsm.Reg.Liveness.
--
-- The operations in this module whose names end in @MaxFmt@ resolve a
-- register that would otherwise appear twice by keeping its larger format.
newtype Regs = Regs { getRegs :: UniqSet RegWithFormat }
  deriving newtype Eq
  deriving newtype Outputable

-- | Of two registers-with-format, return the one whose format is larger.
-- On a tie the first argument wins.
--
-- One of the arguments is returned as-is rather than building
-- @RegWithFormat r1 (max fmt1 fmt2)@, which avoids allocating a fresh
-- 'RegWithFormat'.
maxRegWithFormat :: RegWithFormat -> RegWithFormat -> RegWithFormat
maxRegWithFormat first@(RegWithFormat _ fmtFirst) second@(RegWithFormat _ fmtSecond)
  | fmtFirst < fmtSecond = second
  | otherwise            = first

-- | The empty register set.
noRegs :: Regs
noRegs = Regs { getRegs = emptyUniqSet }

-- | Insert each register of the list into the set, from left to right.
-- A register that is already present (or repeated in the list) keeps the
-- larger of its formats.
addRegsMaxFmt :: Regs -> [RegWithFormat] -> Regs
addRegsMaxFmt start regs = foldl' addRegMaxFmt start regs

-- | Build a register set from a list. Duplicated registers keep their
-- largest format.
mkRegsMaxFmt :: [RegWithFormat] -> Regs
mkRegsMaxFmt regs = addRegsMaxFmt noRegs regs

-- | Insert one register into the set. If the register is already present,
-- the larger of the two formats is kept (via 'maxRegWithFormat').
addRegMaxFmt :: Regs -> RegWithFormat -> Regs
addRegMaxFmt (Regs live) reg =
  -- The strict combinator evaluates the 'maxRegWithFormat' result eagerly,
  -- so repeated insertions do not build up thunks.
  Regs (strictAddOneToUniqSet_C maxRegWithFormat live reg)

-- | Remove the registers of the second argument from the first argument,
-- taking formats into account.
--
-- A register of the first set that also occurs in the second set is removed
-- when EITHER of these holds:
--
--   * its format in the second set is at least as large as its format in
--     the first set (the second set fully covers it), or
--   * its format in the first set is not a vector format.
--
-- Otherwise (a vector-format register that is only partially covered) it is
-- kept, with its format from the first set. Registers of the first set that
-- do not occur in the second set are kept unchanged.
--
-- The scalar rule implements Wrinkle [Don't allow scalar partial writes] in
-- Note [Register formats in liveness analysis] in GHC.CmmToAsm.Reg.Liveness.
--
-- See also 'minusRegs', which ignores formats.
minusCoveredRegs :: Regs -> Regs -> Regs
minusCoveredRegs = coerce $ minusUniqSet_C keepUncovered
  where
    -- Decide the fate of a register present in both sets:
    -- 'Nothing' removes it, 'Just' keeps the first set's entry.
    keepUncovered :: RegWithFormat -> RegWithFormat -> Maybe RegWithFormat
    keepUncovered r1@(RegWithFormat _ fmt1) (RegWithFormat _ fmt2)
      | fmt2 >= fmt1 || not (isVecFormat fmt1) = Nothing
      | otherwise                               = Just r1

-- | Plain set difference: drop from the first set every register that occurs
-- in the second set, whatever formats either side records for it.
--
-- Compare 'minusCoveredRegs', which takes formats into account.
minusRegs :: Regs -> Regs -> Regs
minusRegs (Regs keep) (Regs remove) = Regs (keep `minusUniqSet` remove)

-- | Union of two register sets. A register present in both keeps the larger
-- of its two formats.
unionRegsMaxFmt :: Regs -> Regs -> Regs
unionRegsMaxFmt (Regs xs) (Regs ys) =
  -- Strict combination: 'maxRegWithFormat' results are forced immediately
  -- instead of being left as thunks inside the set.
  Regs (strictUnionUniqSets_C maxRegWithFormat xs ys)

-- | Union of a list of register sets. Each register keeps the largest format
-- it has in any of the sets.
unionManyRegsMaxFmt :: [Regs] -> Regs
unionManyRegsMaxFmt sets =
  -- 'coerce' unwraps the whole list at zero cost (no 'map getRegs' pass), and
  -- the strict combinator avoids building up 'maxRegWithFormat' thunks.
  Regs (strictUnionManyUniqSets_C maxRegWithFormat (coerce sets))

-- | Intersection of two register sets. Each surviving register keeps the
-- larger of its two formats.
intersectRegsMaxFmt :: Regs -> Regs -> Regs
intersectRegsMaxFmt (Regs xs) (Regs ys) =
  -- Strict combination so no 'maxRegWithFormat' thunks are stored.
  Regs (strictIntersectUniqSets_C maxRegWithFormat xs ys)

-- | Find registers whose format is strictly smaller in the second argument
-- than in the first. Each such register is reported with its (smaller)
-- format from the second argument. Registers present in both arguments whose
-- format does not shrink are omitted, as are registers only in the second
-- argument.
--
-- NOTE(review): this is built on 'minusUniqSet_C', which keeps registers that
-- occur only in the first argument (with their first-argument format). The
-- result is therefore not restricted to registers present in both
-- arguments. If that restriction is intended, those registers need filtering
-- out; confirm with the callers.
shrinkingRegs :: Regs -> Regs -> Regs
shrinkingRegs (Regs regs1) (Regs regs2) = Regs (minusUniqSet_C shrunk regs1 regs2)
  where
    -- For a register in both sets: keep the second set's entry if its format
    -- is smaller, otherwise drop the register.
    shrunk :: RegWithFormat -> RegWithFormat -> Maybe RegWithFormat
    shrunk (RegWithFormat _ fmtOld) new@(RegWithFormat _ fmtNew) =
      if fmtNew < fmtOld then Just new else Nothing

-- | Rename every register in the set with the given function, keeping each
-- register's format.
--
-- The function may change a register's 'Unique', so the set is rebuilt from
-- a list of its elements rather than mapped in place; see
-- Note [UniqSet invariant] in GHC.Types.Unique.Set.
--
-- If the function sends several registers to the same register, the result
-- keeps the largest of their formats, like the other @*MaxFmt@ operations in
-- this module. Re-inserting with plain insertion (as 'mapUniqSet' does)
-- would instead keep whichever colliding entry happened to be inserted last,
-- which depends on the set's internal element order.
mapRegs :: (Reg -> Reg) -> Regs -> Regs
mapRegs f (Regs live) =
  -- Combining with 'maxRegWithFormat' gives the same set whatever order the
  -- elements are visited in, so the non-deterministic order is harmless.
  mkRegsMaxFmt
    [ RegWithFormat (f r) fmt
    | RegWithFormat r fmt <- nonDetEltsUniqSet live
    ]

-- | Is the register in the set? The recorded format is ignored.
elemRegs :: Reg -> Regs -> Bool
elemRegs reg regs = getUnique reg `elemUniqSet_Directly` getRegs regs

-- | The format recorded for a register, or 'Nothing' if the register is not
-- in the set.
lookupReg :: Reg -> Regs -> Maybe Format
lookupReg reg regs = regWithFormat_format <$> entry
  where
    -- Look up by the register's unique.
    entry = lookupUniqSet_Directly (getRegs regs) (getUnique reg)