{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UnliftedFFITypes #-}
{-# LANGUAGE GHCForeignImportPrim #-}
{-# OPTIONS_HADDOCK not-home #-}
module GHC.Internal.AllocationLimitHandler
  ( runAllocationLimitHandler
  , setGlobalAllocationLimitHandler
  , AllocationLimitKillBehaviour(..)
  , getAllocationCounterFor
  , setAllocationCounterFor
  , enableAllocationLimitFor
  , disableAllocationLimitFor
  )
  where
import GHC.Internal.Base
import GHC.Internal.Conc.Sync (ThreadId(..))
import GHC.Internal.Data.IORef (IORef, readIORef, writeIORef, newIORef)
import GHC.Internal.Foreign.C.Types
import GHC.Internal.IO (unsafePerformIO)
import GHC.Internal.Int (Int64(..))


-- Process-wide slot holding the handler run when a thread exceeds its
-- allocation limit. It starts out as 'defaultHandler'.
--
-- Created with 'unsafePerformIO'; the NOINLINE pragma is what keeps this a
-- single shared IORef instead of letting GHC inline (and thus duplicate) the
-- 'newIORef' call at each use site.
{-# NOINLINE allocationLimitHandler #-}
allocationLimitHandler :: IORef (ThreadId -> IO ())
allocationLimitHandler = unsafePerformIO (newIORef defaultHandler)

-- Handler used when no user hook is installed: it ignores the thread and
-- does nothing.
defaultHandler :: ThreadId -> IO ()
defaultHandler _ = pure ()

-- C entry point. 'setGlobalAllocationLimitHandler' calls it as
-- @setAllocLimitKill shouldKill shouldRunHandler@, passing 1/0 for each flag.
foreign import ccall "setAllocLimitKill"
  setAllocLimitKill :: CBool -> CBool -> IO ()

-- | Run the currently installed allocation limit handler for a thread.
--
-- The raw 'ThreadId#' is wrapped in a 'ThreadId' before being handed to the
-- handler. The handler is read afresh on every call, so the most recent one
-- installed by 'setGlobalAllocationLimitHandler' is used.
runAllocationLimitHandler :: ThreadId# -> IO ()
runAllocationLimitHandler tid = do
  hook <- getAllocationLimitHandler
  hook (ThreadId tid)

-- Read the handler currently stored in 'allocationLimitHandler'.
getAllocationLimitHandler :: IO (ThreadId -> IO ())
getAllocationLimitHandler = readIORef allocationLimitHandler

-- | Whether a thread that exceeds its allocation limit is sent an exception.
-- Selected via 'setGlobalAllocationLimitHandler'.
data AllocationLimitKillBehaviour
  = KillOnAllocationLimit
    -- ^ Throw an @AllocationLimitExceeded@ asynchronous exception to the
    -- thread once its allocation limit is exceeded.
  | DontKillOnAllocationLimit
    -- ^ Do not throw any exception when the allocation limit is exceeded.

-- | Define the behaviour for handling allocation limits.
--
-- The default behaviour is to throw an @AllocationLimitExceeded@ async
-- exception to the thread. This can be overridden using
-- 'AllocationLimitKillBehaviour'.
--
-- A user-specified handler can also be set, which runs in addition to
-- or in place of the exception. This allows, for instance, logging that the
-- allocation limit was exceeded, or dynamically deciding whether to terminate
-- the thread. Passing 'Nothing' restores the default handler, which does
-- nothing. The handler is not guaranteed to run before the thread is
-- terminated or restarted.
--
-- Note that if you don't terminate the thread, its allocation limit is
-- removed. To keep the allocation limit you have to reset it, e.g. with
-- 'setAllocationCounterFor' and 'enableAllocationLimitFor' (or
-- @setAllocationCounter@ and @enableAllocationLimit@ from the thread itself).
setGlobalAllocationLimitHandler
  :: AllocationLimitKillBehaviour
  -> Maybe (ThreadId -> IO ())
  -> IO ()
setGlobalAllocationLimitHandler killBehaviour mHandler = do
  -- Store the hook and compute the "run handler" flag passed to the RTS.
  shouldRunHandler <- case mHandler of
    Just hook -> do
      writeIORef allocationLimitHandler hook
      pure 1
    Nothing -> do
      -- Reset to the no-op default so a previously installed hook is dropped.
      writeIORef allocationLimitHandler defaultHandler
      pure 0
  let shouldKill = case killBehaviour of
        KillOnAllocationLimit     -> 1
        DontKillOnAllocationLimit -> 0
  setAllocLimitKill shouldKill shouldRunHandler

-- | Retrieve the allocation counter of another thread, via the out-of-line
-- primitive @stg_getOtherThreadAllocationCounterzh@.
foreign import prim "stg_getOtherThreadAllocationCounterzh"
  getOtherThreadAllocationCounter#
    :: ThreadId# -> State# RealWorld -> (# State# RealWorld, Int64# #)

-- | Get the allocation counter for a different thread.
--
-- Note: this doesn't take the current nursery chunk into account.
-- If the thread is running, the result may underestimate its allocations by
-- up to the size of a nursery chunk.
getAllocationCounterFor :: ThreadId -> IO Int64
getAllocationCounterFor (ThreadId t#) = IO $ \s ->
  case getOtherThreadAllocationCounter# t# s of
    (# s', i# #) -> (# s', I64# i# #)

-- | Set the allocation counter for a different thread.
--
-- This can be combined with 'enableAllocationLimitFor' to enable allocation
-- limits for another thread. You may wish to do this from a user-specified
-- allocation limit handler.
--
-- Note: this doesn't take the current nursery chunk into account.
-- If the thread is running, it may overestimate allocations by up to the size
-- of a nursery chunk, and so trigger the limit sooner than expected.
setAllocationCounterFor :: Int64 -> ThreadId -> IO ()
setAllocationCounterFor (I64# i#) (ThreadId t#) = IO $ \s ->
  case setOtherThreadAllocationCounter# i# t# s of
    s' -> (# s', () #)


-- | Enable allocation limit processing for the given thread.
enableAllocationLimitFor :: ThreadId -> IO ()
enableAllocationLimitFor (ThreadId t) = rts_enableThreadAllocationLimit t

-- | Disable allocation limit processing for the given thread.
disableAllocationLimitFor :: ThreadId -> IO ()
disableAllocationLimitFor (ThreadId t) = rts_disableThreadAllocationLimit t

-- RTS entry point behind 'enableAllocationLimitFor'.
foreign import ccall unsafe "rts_enableThreadAllocationLimit" rts_enableThreadAllocationLimit
  :: ThreadId# -> IO ()

-- RTS entry point behind 'disableAllocationLimitFor'.
foreign import ccall unsafe "rts_disableThreadAllocationLimit" rts_disableThreadAllocationLimit
  :: ThreadId# -> IO ()