-- |A small selection of utilities that might be of use to others working with bytestring/number combinations.
module Crypto.Util where

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Unsafe (unsafeIndex, unsafeUseAsCStringLen)
import Data.Bits (shiftL, shiftR)
import Data.Bits (xor, setBit, shiftR, shiftL)
import Control.Exception (Exception, throw)
import Data.Tagged
import System.IO.Unsafe
import Foreign.C.Types
import Foreign.Ptr

-- |@incBS bs@ computes the value @i2bs (8 * B.length bs) (bs2i bs + 1)@:
-- it reads @bs@ as a big-endian unsigned integer and adds one.  The result
-- always has the same length as the input; when every byte is 0xFF the
-- value wraps around to all zero bytes.
incBS :: B.ByteString -> B.ByteString
incBS bs
  -- Empty input, or every byte is 0xFF: the carry runs off the most
  -- significant end, so the result is all zeros of the same length.
  | B.null prefix = zeros
  -- Otherwise bump the right-most byte that is not 0xFF (so it cannot
  -- overflow) and turn the trailing 0xFF run into zeros.
  | otherwise     = B.concat [B.init prefix, B.singleton (B.last prefix + 1), zeros]
  where
    -- One O(n) split of the trailing 0xFF run.  This replaces a recursive
    -- formulation that appended to the end of a list once per 0xFF byte,
    -- which was quadratic in the length of that run.
    (prefix, carried) = B.spanEnd (== 0xFF) bs
    zeros             = B.replicate (B.length carried) 0
{-# INLINE incBS #-}


-- |@i2bs bitLen i@ converts @i@ to a big-endian 'B.ByteString' of
-- @bitLen@ bits (@bitLen@ must be a multiple of 8).  Only the low-order
-- bits of @i@ survive: higher bits are discarded, and negative values come
-- out in two's-complement form.
i2bs :: Int -> Integer -> B.ByteString
i2bs l i = B.pack [fromIntegral (i `shiftR` offset) | offset <- offsets]
  where
    -- Bit offset of each output byte, most significant byte first.
    offsets = [l - 8, l - 16 .. 0]
{-# INLINE i2bs #-}

-- |@i2bs_unsized i@ converts @i@ to a big-endian 'B.ByteString' using as
-- few bytes as are needed to express the integer.  Zero is encoded as the
-- single byte @0x00@.  @i@ must be non-negative; a negative argument
-- currently yields an empty string.
i2bs_unsized :: Integer -> B.ByteString
i2bs_unsized 0 = B.singleton 0
i2bs_unsized n = B.pack (peel n [])
  where
    -- Strip the least significant byte on each step and cons it onto the
    -- accumulator, so the finished list is most-significant-byte first.
    peel k acc
      | k <= 0    = acc
      | otherwise = peel (k `shiftR` 8) (fromIntegral k : acc)
{-# INLINE i2bs_unsized #-}

-- | Unwrap the result of a generator operation: a 'Right' yields its value,
-- while a 'Left' is raised as an exception via 'throw'.
throwLeft :: Exception e => Either e a -> a
throwLeft = either throw id

-- |Obtain a tagged value for a particular instantiated type.  The second
-- argument is never evaluated; it only fixes the tag type @a@, so a
-- placeholder such as @undefined@ is acceptable.
for :: Tagged a b -> a -> b
for tagged = const (unTagged tagged)

-- |Infix form of 'for': @t .::. x@ means @for t x@.
(.::.) :: Tagged a b -> a -> b
t .::. x = for t x

-- | Checks two bytestrings for equality without breaches for
-- timing attacks.
--
-- Semantically, @constTimeEq = (==)@.  However, @x == y@ takes less
-- time when the first byte is different than when the first byte
-- is equal.  This side channel allows an attacker to mount a
-- timing attack.  On the other hand, @constTimeEq@ always takes the
-- same time regardless of the bytestrings' contents, unless they are
-- of different sizes.
--
-- You should always use @constTimeEq@ when comparing secrets,
-- otherwise you may leave a significant security hole
-- (cf. <http://codahale.com/a-lesson-in-timing-attacks/>).
constTimeEq :: B.ByteString -> B.ByteString -> Bool
constTimeEq s1 s2 =
    -- ByteStrings are immutable and the foreign routine only reads the two
    -- buffers, so the result depends only on the inputs; this is why
    -- 'unsafePerformIO' is acceptable here.
    unsafePerformIO $
    unsafeUseAsCStringLen s1 $ \(s1_ptr, s1_len) ->
    unsafeUseAsCStringLen s2 $ \(s2_ptr, s2_len) ->
    if s1_len /= s2_len
      then return False
      else compareChunks s1_ptr s2_ptr s1_len True
  where
    -- Largest byte count that can be handed to the C routine as a 'CInt'.
    maxChunk :: Int
    maxChunk = fromIntegral (maxBound :: CInt)

    -- Compare the buffers in pieces of at most 'maxChunk' bytes.  Passing
    -- the whole length through a bare 'fromIntegral' would wrap for inputs
    -- longer than @maxBound :: CInt@, so the C routine would be given the
    -- wrong length and unequal inputs could be reported equal.  Inputs that
    -- fit in a 'CInt' (including empty ones) still take exactly one call.
    -- Every chunk is compared even after a mismatch, so the amount of work
    -- depends only on the length.
    compareChunks :: Ptr CChar -> Ptr CChar -> Int -> Bool -> IO Bool
    compareChunks p1 p2 remaining allEqual = do
      let n = min remaining maxChunk
      r <- c_constTimeEq p1 p2 (fromIntegral n)
      let allEqual' = allEqual && r == 0
          rest      = remaining - n
      if rest > 0
        then compareChunks (p1 `plusPtr` n) (p2 `plusPtr` n) rest allEqual'
        else return allEqual'

-- Binding to the C routine used by 'constTimeEq'.  It receives the two
-- buffers and the number of bytes to compare; 'constTimeEq' treats a
-- result of 0 as "equal".  Declared @unsafe@, so the C code must not call
-- back into Haskell.
-- NOTE(review): the C implementation is not visible here; presumably it
-- inspects every byte regardless of mismatches (the point of
-- 'constTimeEq') -- confirm against the C source.
foreign import ccall unsafe
   c_constTimeEq :: Ptr CChar -> Ptr CChar -> CInt -> IO CInt

-- |Interpret a 'B.ByteString' as a big-endian unsigned integer.  The
-- empty string maps to 0.
bs2i :: B.ByteString -> Integer
bs2i = B.foldl' pushByte 0
  where
    -- Shift what has been accumulated one byte to the left, then add the
    -- next byte in the low position.
    pushByte acc byte = (acc `shiftL` 8) + fromIntegral byte
{-# INLINE bs2i #-}

-- |XOR two strict 'B.ByteString's byte by byte (zipWith xor + pack).  The
-- result is as long as the shorter input.
--
-- @B.pack (B.zipWith f x y)@ is the shape targeted by the bytestring
-- library's 'zipWith' rewrite rule, so with optimisation enabled this
-- should compile to the packed 'zipWith'' variant without building an
-- intermediate list.
zwp' :: B.ByteString -> B.ByteString -> B.ByteString
zwp' x y = B.pack (B.zipWith xor x y)
{-# INLINE zwp' #-}

-- |XOR two lazy 'L.ByteString's byte by byte (zipWith xor + pack).  The
-- result is as long as the shorter input.
--
-- Each step XORs the longest prefix shared by the current chunks of both
-- inputs using 'zwp'', so that the bytestring library's 'zipWith' rewrite
-- rule can apply.  Any leftover part of the longer chunk is carried into the
-- next step.  The cost is that the result may be split into more chunks
-- than either input.
zwp :: L.ByteString -> L.ByteString -> L.ByteString
zwp a b = L.fromChunks (xorChunks (L.toChunks a) (L.toChunks b))
  where
    xorChunks :: [B.ByteString] -> [B.ByteString] -> [B.ByteString]
    xorChunks (x:xs) (y:ys) =
      let n              = min (B.length x) (B.length y)
          (xHead, xRest) = B.splitAt n x
          (yHead, yRest) = B.splitAt n y
      in zwp' xHead yHead : xorChunks (requeue xRest xs) (requeue yRest ys)
    xorChunks _ _ = []

    -- Put a leftover piece back at the front of its chunk list, unless it
    -- is empty.
    requeue rest more
      | B.null rest = more
      | otherwise   = rest : more
{-# INLINEABLE zwp #-}

-- | @collect n chunks@ gathers @n@ bytes from the front of @chunks@.
-- Whole chunks are taken while they fit, and the last chunk needed is
-- trimmed so the total is at most @n@ bytes.  If the list runs out first,
-- all of it is returned.  A count of zero or less yields @[]@.
collect :: Int -> [B.ByteString] -> [B.ByteString]
collect i _ | i <= 0 = []
collect _ [] = []
collect i (b:bs)
  | len < i   = b : collect (i - len) bs
  | otherwise = [B.take i b]
  where
    len = B.length b
{-# INLINE collect #-}