{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Control.Monad.Trans.UnionFind
  ( UnionFindT, runUnionFind
  , Point, fresh, repr, descriptor, union, equivalent
  ) where

import Control.Applicative (Applicative)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.State (StateT(..), evalStateT)
import Data.UnionFind.IntMap (Point)
import qualified Control.Monad.Trans.State as State
import qualified Data.UnionFind.IntMap as UF

-- | A monad transformer that adds union-find operations to a monad @m@.
--
-- The @p@ parameter is the type of the descriptors attached to points
-- (a point created by 'fresh' from a descriptor of type @p@ has type
-- @'Point' p@).  The state threaded through the computation is a
-- 'UF.PointSupply' from "Data.UnionFind.IntMap", which is the
-- underlying union-find implementation.
--
-- All instances are obtained from the wrapped 'StateT' via
-- @GeneralizedNewtypeDeriving@.
newtype UnionFindT p m a = UnionFindT
  { unUnionFindT :: StateT (UF.PointSupply p) m a
  } deriving (Functor, Applicative, Monad, MonadTrans)

-- | Run a union-find computation in the underlying monad.
--
-- The computation starts from 'UF.newPointSupply'.  Only its result is
-- returned; the final point supply is discarded ('evalStateT').
runUnionFind :: Monad m => UnionFindT p m a -> m a
runUnionFind action = evalStateT (unUnionFindT action) UF.newPointSupply

-- | Exchange the two components of a pair.
--
-- Used to adapt the @(supply, result)@ pair produced by 'UF.fresh' to
-- the @(result, state)@ order that 'StateT' expects.
swap :: (a, b) -> (b, a)
swap (x, y) = (y, x)

-- | Create a new point with the given descriptor.  The returned point
-- is only equivalent to itself.
--
-- Note that a 'Point' has its own identity.  That is, if two points
-- are equivalent then their descriptors are equal, but not vice
-- versa.
fresh :: Monad m => p -> UnionFindT p m (Point p)
fresh x = UnionFindT . StateT $ \supply ->
  -- 'UF.fresh' yields (updated supply, new point); 'StateT' wants
  -- (result, new state), hence the 'swap'.
  return (swap (UF.fresh supply x))

-- | @repr point@ returns the representative point of @point@'s
-- equivalence class.
--
-- This only reads the point supply; the state is left unchanged.
repr :: Monad m => Point p -> UnionFindT p m (Point p)
repr point =
  -- NOTE(review): this was previously documented as /O(1)/; the real
  -- cost is whatever 'UF.repr' costs on the IntMap-backed supply.
  -- Confirm before advertising a bound.
  UnionFindT (State.gets (\supply -> UF.repr supply point))

-- | Return the descriptor of the given point's equivalence class.
--
-- After a 'union', this is the descriptor the merged class was given
-- (that of the second argument to 'union').  This only reads the point
-- supply; the state is left unchanged.
descriptor :: Monad m => Point p -> UnionFindT p m p
descriptor point =
  UnionFindT (State.gets (\supply -> UF.descriptor supply point))

-- | Join the equivalence classes of the two points.  The resulting
-- equivalence class gets the descriptor of the second argument.
union :: Monad m => Point p -> Point p -> UnionFindT p m ()
union p1 p2 =
  UnionFindT (State.modify (\supply -> UF.union supply p1 p2))

-- | Test whether the two points are in the same equivalence class.
-- Semantically this is
--
-- @
-- liftA2 (==) (repr p1) (repr p2)
-- @
--
-- It only reads the point supply; the state is left unchanged.
equivalent :: Monad m => Point p -> Point p -> UnionFindT p m Bool
equivalent p1 p2 =
  UnionFindT (State.gets (\supply -> UF.equivalent supply p1 p2))