--------------------------------------------------------------------------------
-- | This provides a simple stand-alone server for 'WebSockets' applications.
-- Note that in production you want to use a real webserver such as snap or
-- warp.
{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Server
    ( ServerApp
    , runServer
    , ServerOptions (..)
    , defaultServerOptions
    , runServerWithOptions
    , runServerWith
    , makeListenSocket
    , makePendingConnection
    , makePendingConnectionFromStream

    , PongTimeout
    ) where


--------------------------------------------------------------------------------
import           Control.Concurrent            (threadDelay)
import qualified Control.Concurrent.Async      as Async
import           Control.Exception             (Exception, allowInterrupt,
                                                bracket, bracketOnError,
                                                finally, mask_, throwIO)
import           Control.Monad                 (forever, void, when)
import qualified Data.IORef                    as IORef
import           Data.Maybe                    (isJust)
import           Network.Socket                (Socket)
import qualified Network.Socket                as S
import qualified System.Clock                  as Clock


--------------------------------------------------------------------------------
import           Network.WebSockets.Connection
import           Network.WebSockets.Http
import qualified Network.WebSockets.Stream     as Stream
import           Network.WebSockets.Types


--------------------------------------------------------------------------------
-- | A WebSockets application that can be run by a server.  It is handed the
-- not-yet-accepted 'PendingConnection'; once this 'IO' action finishes, the
-- underlying socket is closed automatically.
type ServerApp = PendingConnection -> IO ()


--------------------------------------------------------------------------------
-- | Run a simple stand-alone server on the given address and port.  This
-- function blocks forever.  It is meant for quick-and-dirty or internal
-- applications only; real deployments should use a real server.
--
-- Its limitations include:
--
-- * Performance is reasonable under load, but
-- * no protection against DoS attacks is provided, and
-- * no logging is performed.
--
-- Glue for using this package with real servers is provided by:
--
-- * <https://hackage.haskell.org/package/wai-websockets>
--
-- * <https://hackage.haskell.org/package/websockets-snap>
--
-- Every setting other than the host and port is taken from
-- 'defaultServerOptions'.
runServer :: String     -- ^ Address to bind
          -> Int        -- ^ Port to listen on
          -> ServerApp  -- ^ Application
          -> IO ()      -- ^ Never returns
runServer host port = runServerWithOptions defaultServerOptions
    { serverHost = host
    , serverPort = port
    }


--------------------------------------------------------------------------------
-- | A version of 'runServer' which additionally lets you choose the
-- 'ConnectionOptions'.  All remaining settings come from
-- 'defaultServerOptions'.
runServerWith :: String -> Int -> ConnectionOptions -> ServerApp -> IO ()
runServerWith host port opts app = runServerWithOptions serverOpts app
  where
    serverOpts = defaultServerOptions
        { serverHost              = host
        , serverPort              = port
        , serverConnectionOptions = opts
        }
{-# DEPRECATED runServerWith "Use 'runServerWithOptions' instead" #-}


--------------------------------------------------------------------------------
-- | Settings for 'runServerWithOptions'.  Start from 'defaultServerOptions' and
-- override the fields you need with record-update syntax.
data ServerOptions = ServerOptions
    { serverHost              :: String
      -- ^ Address the listening socket is bound to.
    , serverPort              :: Int
      -- ^ Port to listen on.
    , serverConnectionOptions :: ConnectionOptions
      -- ^ Options applied to every accepted connection.
    , serverRequirePong       :: Maybe Int
      -- ^ Require a pong from the client every N seconds; otherwise kill the
      -- connection.  If you use this, you should also use 'withPingThread' to
      -- send a ping at a smaller interval; for example N/2.
    }


--------------------------------------------------------------------------------
-- | Default server settings: listen on @127.0.0.1:8080@, use
-- 'defaultConnectionOptions', and do not require pongs from clients.
defaultServerOptions :: ServerOptions
defaultServerOptions = ServerOptions
    { serverRequirePong       = Nothing
    , serverConnectionOptions = defaultConnectionOptions
    , serverPort              = 8080
    , serverHost              = "127.0.0.1"
    }


--------------------------------------------------------------------------------
-- | Customizable version of 'runServer'.  Never returns until killed.
--
-- Please use the 'defaultServerOptions' combined with record updates to set the
-- fields you want.  This way your code is unlikely to break on future changes.
--
-- Each accepted connection is handled in its own 'Async.Async' thread, and the
-- client socket is closed when that thread ends, whether normally or through
-- an exception.  When 'serverRequirePong' is @Just n@, a watchdog thread is
-- started per connection.  It cancels the handler with 'PongTimeout' once @n@
-- seconds have passed since the connection was accepted or since the last
-- pong, whichever is later.
runServerWithOptions :: ServerOptions -> ServerApp -> IO a
runServerWithOptions opts app = S.withSocketsDo $
    bracket (makeListenSocket host port) S.close $ \sock ->
        -- The accept loop runs masked, so no asynchronous exception can land
        -- between accepting 'conn' and installing the 'finally' that closes
        -- it.  'allowInterrupt' gives pending asynchronous exceptions (e.g. a
        -- 'killThread' aimed at the server) a delivery point each iteration.
        mask_ $ forever $ do
            allowInterrupt
            (conn, _) <- S.accept sock

            -- Deadline (whole seconds on the monotonic clock) after which the
            -- handler may be killed.  'tickle' pushes it 'killDelay' seconds
            -- into the future.
            killRef <- IORef.newIORef =<< (+ killDelay) <$> getSecs
            let tickle = IORef.writeIORef killRef =<< (+ killDelay) <$> getSecs

            -- When the watchdog is active, extend the deadline on every pong,
            -- in addition to running the user's own 'connectionOnPong'.
            let connOpts'
                    | not useKiller = connOpts
                    | otherwise     = connOpts
                        { connectionOnPong = tickle >> connectionOnPong connOpts
                        }

            -- Run the application with asynchronous exceptions re-enabled;
            -- the client socket is closed however the handler finishes.
            appAsync <- Async.asyncWithUnmask $ \unmask ->
                unmask (runApp conn connOpts' app) `finally` S.close conn

            -- Install the watchdog if required.
            when useKiller $ void $ Async.async (killer killRef appAsync)
  where
    host     = serverHost opts
    port     = serverPort opts
    connOpts = serverConnectionOptions opts

    -- Current number of whole seconds on the monotonic clock.
    getSecs = Clock.sec <$> Clock.getTime Clock.Monotonic

    -- Interpret the 'serverRequirePong' option.
    useKiller = isJust $ serverRequirePong opts
    killDelay = maybe 0 fromIntegral (serverRequirePong opts)

    -- Watchdog loop: reads the deadline and cancels the handler once the
    -- deadline has passed without being extended.
    killer deadlineRef handler = do
        killAt   <- IORef.readIORef deadlineRef
        now      <- getSecs
        appState <- Async.poll handler
        case appState of
            -- Already finished/killed/crashed, we can give up.
            Just _ -> return ()
            -- Not due yet.  Sleep only until the current deadline, not a full
            -- 'killDelay'; otherwise a silent client could survive almost
            -- twice the requested interval.  Then re-check, because a pong may
            -- have moved the deadline in the meantime.
            Nothing | now < killAt -> do
                threadDelay (fromIntegral (killAt - now) * 1000 * 1000)
                killer deadlineRef handler
            -- Deadline passed without a pong: kill the handler.
            _ -> Async.cancelWith handler PongTimeout


--------------------------------------------------------------------------------
-- | Create a standardized socket on which you can listen for incoming
-- connections.  Should only be used for a quick and dirty solution!  Should be
-- preceded by a call to 'Network.Socket.withSocketsDo'.
--
-- Behavior:
--
-- * The first address that 'S.getAddrInfo' returns for @host@ and @port@ is
--   used.
-- * The returned socket has @ReuseAddr@ and @NoDelay@ set, and is already
--   bound and listening.
-- * The caller owns the returned socket and must close it.
-- * If binding or listening fails, the socket is closed and the exception is
--   rethrown.
makeListenSocket :: String -> Int -> IO Socket
makeListenSocket host port = do
    addrs <- S.getAddrInfo (Just hints) (Just host) (Just (show port))
    -- 'S.getAddrInfo' is documented to throw rather than return an empty
    -- list.  Matching explicitly turns a violation of that contract into a
    -- descriptive error instead of an opaque pattern-match failure.
    addr <- case addrs of
        a : _ -> return a
        []    -> ioError $ userError $
            "makeListenSocket: no address found for "
                ++ host ++ ":" ++ show port
    bracketOnError
        (S.socket (S.addrFamily addr) S.Stream S.defaultProtocol)
        S.close
        (\sock -> do
            S.setSocketOption sock S.ReuseAddr 1
            S.setSocketOption sock S.NoDelay   1
            S.bind sock (S.addrAddress addr)
            -- Use the system's maximum accept backlog instead of a tiny fixed
            -- queue, so bursts of simultaneous connection attempts are not
            -- refused while the accept loop is busy.
            S.listen sock S.maxListenQueue
            return sock
        )
  where
    hints = S.defaultHints { S.addrSocketType = S.Stream }


--------------------------------------------------------------------------------
-- | Wrap an accepted client socket in a 'PendingConnection' and run the
-- application on it.  The connection's stream is closed with 'Stream.close'
-- after the application returns, or if it throws.
runApp :: Socket
       -> ConnectionOptions
       -> ServerApp
       -> IO ()
runApp socket opts = bracket acquire release
  where
    acquire = makePendingConnection socket opts
    release = Stream.close . pendingStream


--------------------------------------------------------------------------------
-- | Turn a socket connected to some client into a 'PendingConnection' by
-- wrapping it in a socket 'Stream.Stream' and reading the request head from
-- it.  Close the connection's 'pendingStream' with 'Stream.close' when you are
-- done.
makePendingConnection
    :: Socket -> ConnectionOptions -> IO PendingConnection
makePendingConnection socket opts =
    Stream.makeSocketStream socket >>= \stream ->
        makePendingConnectionFromStream stream opts


-- | More general version of 'makePendingConnection' for 'Stream.Stream'
-- instead of a 'Socket'.
--
-- Parses an HTTP request head from the stream.  If 'Stream.parse' yields no
-- request, this throws 'ConnectionClosed'.  Otherwise the result carries the
-- given options and stream, and its 'pendingOnAccept' does nothing.
makePendingConnectionFromStream
    :: Stream.Stream -> ConnectionOptions -> IO PendingConnection
makePendingConnectionFromStream stream opts =
    -- TODO: we probably want to send a 40x if the request is bad?
    Stream.parse stream (decodeRequestHead False) >>=
        maybe (throwIO ConnectionClosed) (return . toPending)
  where
    toPending request = PendingConnection
        { pendingOptions  = opts
        , pendingRequest  = request
        , pendingOnAccept = \_ -> return ()
        , pendingStream   = stream
        }


--------------------------------------------------------------------------------
-- | Internal exception used to cancel a connection's handler thread when the
-- client has not sent a pong before the deadline configured through
-- 'serverRequirePong'.
data PongTimeout = PongTimeout
    deriving (Show)


--------------------------------------------------------------------------------
instance Exception PongTimeout