{-# LANGUAGE CPP #-}

module Distribution.Utils.MapAccum (mapAccumM) where

import Distribution.Compat.Prelude
import Prelude ()

-- | A state-passing computation over an underlying monad @m@.
--
-- Like 'StateT', but with the result tuple swapped: the threaded state
-- comes first and the produced value second. This matches the
-- @(accumulator, result)@ order used by 'mapAccumL' and 'mapAccumM'.
newtype StateM s m a = StateM {runStateM :: s -> m (s, a)}

-- | Maps over the produced value; the threaded state is passed through
-- unchanged.
instance Functor m => Functor (StateM s m) where
  fmap f (StateM x) = StateM $ \s ->
    fmap (\(s', a) -> (s', f a)) (x s)

-- | Sequences two computations left to right.
--
-- The state returned by the first computation is fed into the second, and
-- the function produced by the first is applied to the value produced by
-- the second.
instance Monad m => Applicative (StateM s m) where
  pure x = StateM $ \s -> pure (s, x)
  StateM f <*> StateM x = StateM $ \s -> do
    (s', f') <- f s
    (s'', x') <- x s'
    pure (s'', f' x')

-- | Monadic variant of 'mapAccumL'.
--
-- Traverses the structure from left to right. It threads an accumulator
-- through the monadic step function and collects the step outputs in a
-- structure of the same shape. The effects of @m@ run in traversal order.
--
-- >>> mapAccumM (\acc x -> Just (acc + x, acc * x)) 0 [1, 2, 3]
-- Just (6,[0,2,9])
mapAccumM
  :: (Monad m, Traversable t)
  => (a -> b -> m (a, c))
  -- ^ step: current accumulator and next element to
  --   (new accumulator, output)
  -> a
  -- ^ initial accumulator
  -> t b
  -- ^ structure to traverse
  -> m (a, t c)
  -- ^ final accumulator paired with the structure of outputs
mapAccumM f s t = runStateM (traverse step t) s
  where
    -- Wrap one element as a state step that feeds the incoming
    -- accumulator and the element to the user's function.
    step x = StateM (\s' -> f s' x)