-- | This module provides a wrapper around a deque that can enforce additional
-- invariants at runtime for debugging purposes.

module Data.Concurrent.Deque.Debugger
       (DebugDeque(DebugDeque))
       where

import Data.IORef
import Control.Concurrent
import Data.Concurrent.Deque.Class

-- newtype DebugDeque d = DebugDeque d

-- | Warning: this enforces the excessively STRONG invariant that if either end
-- of the deque is not thread-safe, then that end may only ever be touched by
-- one thread during the deque's entire lifetime.
--
-- This extreme form of monogamy is easier to verify, because we don't have
-- enough information to know whether two operations on different threads are
-- racing with one another or are properly synchronized.
--
-- The wrapper carries a pair of IORefs recording which thread (if any) has
-- touched the left and the right end of the deque, respectively.
data DebugDeque d elt =
  DebugDeque
    (IORef (Maybe ThreadId), IORef (Maybe ThreadId))
      -- ^ Thread tracking for the left end and the right end.
    (d elt)
      -- ^ The wrapped deque that does the real work.


-- | Every operation is forwarded to the wrapped deque.  Operations that touch
-- one end of the deque first call 'markThread' with that end's thread-safety
-- flag (taken from the wrapped deque) and that end's tracking 'IORef'.
instance DequeClass d => DequeClass (DebugDeque d) where
  -- Left end: tracked in the first IORef of the pair.
  pushL (DebugDeque (ref, _) q) elt = do
    markThread (leftThreadSafe q) ref
    pushL q elt

  -- Right end: tracked in the second IORef of the pair.
  tryPopR (DebugDeque (_, ref) q) = do
    markThread (rightThreadSafe q) ref
    tryPopR q

  -- Both ends start out with no recorded thread.
  newQ = do
    l <- newIORef Nothing
    r <- newIORef Nothing
    fmap (DebugDeque (l, r)) newQ

  -- FIXME: What are the threadsafe rules for nullQ?
  -- For now it is forwarded without any thread check.
  nullQ (DebugDeque _ q) = nullQ q

  -- Thread-safety is a property of the wrapped deque; the wrapper adds none.
  leftThreadSafe  (DebugDeque _ q) = leftThreadSafe q
  rightThreadSafe (DebugDeque _ q) = rightThreadSafe q


-- | Popping from the left end goes through the same check as 'pushL': the
-- first IORef of the pair tracks the left end, and the wrapped deque's
-- 'leftThreadSafe' flag decides whether tracking happens at all.
instance PopL d => PopL (DebugDeque d) where
  tryPopL (DebugDeque (ref, _) q) = do
    markThread (leftThreadSafe q) ref
    tryPopL q

-- | Record that the current thread has touched one end of the deque.
--
-- If the first argument is 'True' (that end is thread-safe) nothing is
-- tracked.  Otherwise the 'IORef' holds the thread owning that end: the first
-- caller claims it, repeated calls from the owner are fine, and a call from
-- any other thread raises an 'error' naming @(current thread, owner)@.  On a
-- violation the recorded owner is left unchanged.
--
-- The verdict is computed inside the strict 'atomicModifyIORef'' and the error
-- is raised afterwards in 'IO'.  With the lazy 'atomicModifyIORef', both the
-- check and its 'error' would sit in unevaluated thunks.  Nobody forces them,
-- so violations would go undetected and thunks would pile up in the IORef.
markThread :: Bool -> IORef (Maybe ThreadId) -> IO ()
markThread True _ = return () -- Don't bother tracking.
markThread False ref = do
  tid <- myThreadId
  conflict <- atomicModifyIORef' ref $ \prev ->
    case prev of
      -- Another thread already owns this end: keep it, report the conflict.
      Just owner | owner /= tid -> (prev, Just owner)
      -- Unclaimed, or already ours: (re)claim it.
      _                         -> (Just tid, Nothing)
  case conflict of
    Nothing   -> return ()
    Just tid2 ->
      error $ "DebugDeque: invariant violated, thread safety not allowed but accessed by: "
              ++ show (tid, tid2)