module GHC.Data.UnionFind where
import GHC.Prelude
import Data.STRef
import Control.Monad.ST
import Control.Monad
-- | One element of a union-find structure living in the 'ST' state thread
-- @s@, whose equivalence class carries a descriptor of type @a@.
--
-- A point is a mutable reference to a 'Link'. The derived 'Eq' instance
-- compares the underlying 'STRef's, so two points are equal exactly when
-- they are the same reference cell. This is not the same as being in the
-- same equivalence class; 'equivalent' answers that question.
newtype Point s a = Point (STRef s (Link s a))
  deriving (Eq)
-- | Overwrite the 'Link' stored in a 'Point'. Only the point's own
-- reference cell is replaced; no other point is touched.
writePoint :: Point s a -> Link s a -> ST s ()
writePoint (Point cell) newLink = writeSTRef cell newLink
-- | Read the 'Link' currently stored in a 'Point'.
readPoint :: Point s a -> ST s (Link s a)
readPoint p = case p of
  Point cell -> readSTRef cell
-- | What a 'Point' currently holds. A point is either the representative
-- (root) of its equivalence class ('Info'), or it forwards to another point
-- of the same class ('Link').
data Link s a
    -- | Root of a class. Holds a reference to the class weight (initialised
    -- to 1 by 'fresh' and summed when two roots are merged by 'union') and a
    -- reference to the class descriptor.
    = Info {-# UNPACK #-} !(STRef s Int) {-# UNPACK #-} !(STRef s a)
    -- | Non-root: forwards to another point that is closer to the root.
    | Link {-# UNPACK #-} !(Point s a)
-- | Create a new singleton equivalence class whose descriptor is @desc@.
-- The returned point is its own root and starts with weight 1.
fresh :: a -> ST s (Point s a)
fresh desc = do
  sizeRef <- newSTRef 1
  descRef <- newSTRef desc
  rootCell <- newSTRef (Info sizeRef descRef)
  return (Point rootCell)
-- | Return the root ('Info'-holding point) of the class containing @point@.
--
-- Performs path compression along the way. Every point visited on the
-- chain that did not already link directly to the root is rewritten to
-- link straight to it. The recursion depth equals the length of the link
-- chain being walked.
repr :: Point s a -> ST s (Point s a)
repr point = do
  contents <- readPoint point
  case contents of
    Info _ _ -> return point
    Link parent -> do
      root <- repr parent
      -- When parent is not itself the root, the recursive call has already
      -- made parent link directly to root; copy that link into this point.
      unless (root == parent) $
        readPoint parent >>= writePoint point
      return root
-- | Read the descriptor of the class containing @point@.
--
-- The common cases do not mutate anything: @point@ is the root, or it is
-- one link away from the root. Only for longer chains does this fall back
-- to 'repr', which compresses the path, and then read from the root.
find :: Point s a -> ST s a
find point = do
  contents <- readPoint point
  case contents of
    Info _ descRef -> readSTRef descRef
    Link parent -> do
      parentContents <- readPoint parent
      case parentContents of
        Info _ descRef -> readSTRef descRef
        Link _ -> repr point >>= find
-- | Merge the classes containing @refpoint1@ and @refpoint2@. Does nothing
-- if they already share a root.
--
-- Uses union by weight. The root with the larger weight (ties favour the
-- root of @refpoint1@) becomes the root of the merged class, and its weight
-- becomes the sum of both weights. In either branch, the merged class keeps
-- the descriptor that @refpoint2@'s class had.
--
-- Calls 'error' if either representative returned by 'repr' is not an
-- 'Info' node. That would mean 'repr' broke its invariant.
union :: Point s a -> Point s a -> ST s ()
union refpoint1 refpoint2 = do
  point1 <- repr refpoint1
  point2 <- repr refpoint2
  when (point1 /= point2) $ do
    l1 <- readPoint point1
    l2 <- readPoint point2
    case (l1, l2) of
      (Info wref1 dref1, Info wref2 dref2) -> do
        weight1 <- readSTRef wref1
        weight2 <- readSTRef wref2
        -- The sum is forced ($!) before it is stored, so the weight ref
        -- never holds an unevaluated (+) thunk.
        if weight1 >= weight2
          then do
            writePoint point2 (Link point1)
            writeSTRef wref1 $! weight1 + weight2
            -- point1 stays the root, so copy point2's descriptor into it.
            writeSTRef dref1 =<< readSTRef dref2
          else do
            -- point2 stays the root and already holds its own descriptor.
            writePoint point1 (Link point2)
            writeSTRef wref2 $! weight1 + weight2
      _ -> error "UnionFind.union: repr invariant broken"
-- | True when both points belong to the same equivalence class, i.e. they
-- share a root. Both lookups go through 'repr', so this may also compress
-- paths. @point1@ is resolved before @point2@.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent point1 point2 = do
  root1 <- repr point1
  root2 <- repr point2
  return (root1 == root2)