{-# LANGUAGE ForeignFunctionInterface, OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.Socket.BufferPool.Recv (
receive
, receiveBuf
, makeReceiveN
, makePlainReceiveN
) where
import qualified Data.ByteString as BS
import Data.ByteString.Internal (ByteString(..))
import Data.IORef
import Foreign.C.Error (Errno, eAGAIN, eINTR, eWOULDBLOCK, getErrno, throwErrno)
import Foreign.C.Types
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import GHC.Conc (threadWaitRead)
import Network.Socket (Socket, withFdSocket)
import System.Posix.Types (Fd(..))
#ifdef mingw32_HOST_OS
import GHC.IO.FD (FD(..), readRawBufferPtr)
import Network.Socket.BufferPool.Windows
#endif
import Network.Socket.BufferPool.Types
import Network.Socket.BufferPool.Buffer
-- | Receive one chunk from @sock@ into space taken from @pool@.
--
-- 'withBufferPool' supplies a buffer pointer and its capacity. The helper
-- below reads into that space with 'tryRecv' and reports how many bytes
-- arrived. Presumably 'withBufferPool' turns that count into the resulting
-- 'ByteString' (see "Network.Socket.BufferPool.Buffer" to confirm).
-- How the raw descriptor is obtained depends on the @network@ version.
receive :: Socket -> BufferPool -> Recv
receive sock pool = withBufferPool pool fill
  where
    -- Read at most @size@ bytes into @ptr@; the result is the byte count
    -- returned by 'tryRecv'.
    fill ptr size = do
#if MIN_VERSION_network(3,1,0)
      withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
        fd <- fdSocket sock
#else
        let fd = fdSocket sock
#endif
        n <- tryRecv fd ptr (fromIntegral size :: CSize)
        return (fromIntegral n)
-- | Fill the caller-supplied buffer @buf0@ with exactly @siz0@ bytes from
-- @sock@.
--
-- Returns 'True' once the whole region has been filled. Returns 'False' as
-- soon as 'tryRecv' reports a zero-byte read (end of stream), leaving the
-- buffer only partially filled.
receiveBuf :: Socket -> RecvBuf
receiveBuf sock buf0 siz0 = do
#if MIN_VERSION_network(3,1,0)
  withFdSocket sock $ \fd -> do
#elif MIN_VERSION_network(3,0,0)
    fd <- fdSocket sock
#else
    let fd = fdSocket sock
#endif
    fillRest fd buf0 siz0
  where
    -- Keep reading into the unfilled tail of the buffer until nothing remains.
    fillRest :: CInt -> RecvBuf
    fillRest fd buf remaining
      | remaining == 0 = return True
      | otherwise = do
          got <- tryRecv fd buf (fromIntegral remaining)
          case fromIntegral got of
            0 -> return False
            n -> fillRest fd (buf `plusPtr` n) (remaining - n)
-- | Read up to @size@ bytes from the raw descriptor @sock@ into @ptr@.
--
-- Returns the byte count from the underlying read, where 0 means end of
-- stream. A result of -1 is handled by inspecting @errno@:
--
-- * @EINTR@: a signal interrupted the call before any data was read, so
--   retry immediately. Previously this was reported as a hard error.
--
-- * @EAGAIN@ / @EWOULDBLOCK@: no data is ready yet. Park this Haskell thread
--   with 'threadWaitRead' until the descriptor is readable, then retry.
--   POSIX allows either code, and they are not equal on every platform.
--
-- * Anything else is raised with 'throwErrno', using the location "tryRecv".
tryRecv :: CInt -> Buffer -> CSize -> IO CInt
tryRecv sock ptr size = go
  where
    go :: IO CInt
    go = do
#ifdef mingw32_HOST_OS
      bytes <- windowsThreadBlockHack $ fromIntegral <$> readRawBufferPtr "tryRecv" (FD sock 1) (castPtr ptr) 0 size
#else
      bytes <- c_recv sock (castPtr ptr) size 0
#endif
      if bytes /= -1
        then return bytes
        else getErrno >>= retryOrThrow

    -- Decide how to proceed after a failed read, based on errno.
    retryOrThrow :: Errno -> IO CInt
    retryOrThrow errno
      | errno == eINTR = go
      | errno == eAGAIN || errno == eWOULDBLOCK = threadWaitRead (Fd sock) >> go
      | otherwise = throwErrno "tryRecv"
-- | Build a 'RecvN' from an initial cached 'ByteString' and the two
-- low-level receivers.
--
-- The cache is held in a fresh 'IORef' that belongs to the returned
-- function. Through 'receiveN', leftover bytes from one call are carried
-- into the next.
makeReceiveN :: ByteString -> Recv -> RecvBuf -> IO RecvN
makeReceiveN bs0 recv recvBuf =
  fmap (\cacheRef -> receiveN cacheRef recv recvBuf) (newIORef bs0)
-- | Build a 'RecvN' that reads directly from the plain socket @s@.
--
-- A fresh 'BufferPool' is created with @newBufferPool l h@. Chunked reads go
-- through 'receive' on that pool. Reads into caller-provided buffers go
-- through 'receiveBuf'. @bs0@ seeds the leftover cache.
-- NOTE(review): the meaning of @l@ and @h@ is defined by 'newBufferPool'
-- (presumably a low-water mark and the allocation size); confirm there.
makePlainReceiveN :: Socket -> Int -> Int -> ByteString -> IO RecvN
makePlainReceiveN s l h bs0 = do
  cacheRef <- newIORef bs0
  bufPool <- newBufferPool l h
  let recvChunk = receive s bufPool
      recvExact = receiveBuf s
  pure (receiveN cacheRef recvChunk recvExact)
-- | Receive exactly @size@ bytes, consuming the cached bytes in @ref@ first.
--
-- Any bytes 'tryRecvN' hands back beyond the request are stored in @ref@ for
-- the next call. Reading and writing @ref@ are separate steps, so callers
-- that share one @ref@ across threads are not coordinated here.
receiveN :: IORef ByteString -> Recv -> RecvBuf -> RecvN
receiveN ref recv recvBuf size = do
  (wanted, rest) <- readIORef ref >>= \cached -> tryRecvN cached size recv recvBuf
  writeIORef ref rest
  pure wanted
-- | Produce exactly @siz0@ bytes, starting from the already-received
-- @init0@, and return them together with any surplus bytes.
--
-- There are three strategies:
--
-- * The cache already holds enough bytes: split it, with no I/O.
--
-- * The request is small (at most 'poolLimit' bytes): repeatedly call
--   @recv@ and stitch the chunks together.
--
-- * The request is large: allocate a dedicated buffer of @siz0@ bytes, copy
--   @init0@ to its start, and let @recvBuf@ fill the rest.
--
-- If the stream ends before @siz0@ bytes arrive, the result is @("", "")@.
-- This includes @recv@ returning an empty chunk or @recvBuf@ returning
-- 'False'. Any partially collected data is dropped.
tryRecvN :: ByteString -> Int -> Recv -> RecvBuf -> IO (ByteString, ByteString)
tryRecvN init0 siz0 recv recvBuf
  | siz0 <= len0 = return (BS.splitAt siz0 init0)
  | siz0 <= poolLimit = gather [init0] (siz0 - len0)
  | otherwise = fillFresh
  where
    len0 :: Int
    len0 = BS.length init0

    -- Requests up to this size are assembled from pooled chunks; larger ones
    -- get a buffer of their own.
    poolLimit :: Int
    poolLimit = 4096

    -- Collect chunks (newest first in @acc@) until @need@ more bytes are in
    -- hand. The last chunk is split, and its tail becomes the leftover.
    gather :: [ByteString] -> Int -> IO (ByteString, ByteString)
    gather acc need = recv >>= step
      where
        step chunk
          | BS.null chunk = return ("", "")
          | BS.length chunk >= need =
              let (taken, rest) = BS.splitAt need chunk
               in return (BS.concat (reverse (taken : acc)), rest)
          | otherwise = gather (chunk : acc) (need - BS.length chunk)

    -- Allocate a buffer of exactly @siz0@ bytes, copy the cached prefix into
    -- it, then have @recvBuf@ fill the remainder in place.
    fillFresh :: IO (ByteString, ByteString)
    fillFresh = do
      fresh@(PS fptr _ _) <- mallocBS siz0
      complete <- withForeignPtr fptr $ \ptr -> do
        afterCopy <- copy ptr init0
        recvBuf afterCopy (siz0 - len0)
      return $ if complete then (fresh, "") else ("", "")
#ifndef mingw32_HOST_OS
-- | Direct binding to the C library's @recv@. The arguments are the
-- descriptor, the destination buffer, its length, and the flags. 'tryRecv'
-- treats a result of -1 as failure and then reads @errno@.
--
-- Imported @unsafe@, which avoids the safe-call overhead but does not
-- release the capability during the call.
-- NOTE(review): this presumes the descriptor is non-blocking so the call
-- returns promptly ('tryRecv' handles EAGAIN by waiting with
-- 'threadWaitRead'); confirm sockets are always put into non-blocking mode.
foreign import ccall unsafe "recv"
  c_recv :: CInt -> Ptr CChar -> CSize -> CInt -> IO CInt
#endif