{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Vector.Algorithms.Radix (sort, sortBy, Radix(..)) where
import Prelude hiding (read, length)
import Control.Monad
import Control.Monad.Primitive
import qualified Data.Vector.Primitive.Mutable as PV
import Data.Vector.Generic.Mutable
import Data.Vector.Algorithms.Common
import Data.Bits
import Data.Int
import Data.Word
import Foreign.Storable
-- | Types that can be sorted with the radix sort in this module.
--
-- 'sort' applies 'passes' and 'size' to an 'undefined' value, so instances
-- must not evaluate that argument; it only selects the instance.
class Radix e where
  -- | The number of passes 'sort' makes over the array.
  passes :: e -> Int
  -- | The length of the auxiliary counting array that 'sort' allocates.
  size   :: e -> Int
  -- | @radix k e@ is the key of @e@ used during pass @k@ (passes are
  -- numbered from 0).
  --
  -- NOTE(review): the result is used as an index into the counting array,
  -- so it presumably has to lie in @[0, size e)@ -- confirm against
  -- @countLoop@ / @inc@ in "Data.Vector.Algorithms.Common".
  radix  :: Int -> e -> Int
-- | Machine-sized signed integers: one pass per byte of an 'Int', starting
-- with the least significant byte.
instance Radix Int where
  passes _ = sizeOf (undefined :: Int)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = e .&. 255
  radix i e
    -- Last pass (most significant byte): flip the sign bit so that negative
    -- values get smaller keys than non-negative ones.
    | i == passes e - 1 = radix' (e `xor` minBound)
    | otherwise         = radix' e
    where
      -- Byte number @i@ of @x@; @i `shiftL` 3@ is @8 * i@ bits.
      radix' x = (x `shiftR` (i `shiftL` 3)) .&. 255
  {-# INLINE radix #-}
-- | 8-bit signed integers: a single pass over the only byte.
instance Radix Int8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- Mask to the low byte, then flip bit 7 (the sign bit) so that negative
  -- values map to keys 0..127 and non-negative values to 128..255.
  -- The parentheses spell out the existing precedence (.&. binds tighter
  -- than `xor`).
  radix _ e = (255 .&. fromIntegral e) `xor` 128
  {-# INLINE radix #-}
-- | 16-bit signed integers: two passes, low byte first.
--
-- 'radix' is only defined for pass indices 0 and 1.
instance Radix Int16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  -- High byte: flip the sign bit first so negative values get smaller keys.
  radix 1 e = fromIntegral (((e `xor` minBound) `shiftR` 8) .&. 255)
  {-# INLINE radix #-}
-- | 32-bit signed integers: four passes, least significant byte first.
--
-- 'radix' is only defined for pass indices 0 through 3.
instance Radix Int32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  -- Top byte: flip the sign bit first so negative values get smaller keys.
  radix 3 e = fromIntegral (((e `xor` minBound) `shiftR` 24) .&. 255)
  {-# INLINE radix #-}
-- | 64-bit signed integers: eight passes, least significant byte first.
--
-- 'radix' is only defined for pass indices 0 through 7.
instance Radix Int64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  -- Top byte: flip the sign bit first so negative values get smaller keys.
  radix 7 e = fromIntegral (((e `xor` minBound) `shiftR` 56) .&. 255)
  {-# INLINE radix #-}
-- | Machine-sized unsigned integers: one pass per byte of a 'Word',
-- starting with the least significant byte.
instance Radix Word where
  passes _ = sizeOf (undefined :: Word)
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  -- Byte number @i@ of @e@; @i `shiftL` 3@ is @8 * i@ bits.
  radix i e = fromIntegral ((e `shiftR` (i `shiftL` 3)) .&. 255)
  {-# INLINE radix #-}
-- | 8-bit unsigned integers: a single pass, with the value itself as key.
instance Radix Word8 where
  passes _ = 1
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  -- The pass index is ignored; the key is the value widened to 'Int'.
  radix _ = fromIntegral
  {-# INLINE radix #-}
-- | 16-bit unsigned integers: two passes, low byte first.
--
-- 'radix' is only defined for pass indices 0 and 1.
instance Radix Word16 where
  passes _ = 2
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  {-# INLINE radix #-}
-- | 32-bit unsigned integers: four passes, least significant byte first.
--
-- 'radix' is only defined for pass indices 0 through 3.
instance Radix Word32 where
  passes _ = 4
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  {-# INLINE radix #-}
-- | 64-bit unsigned integers: eight passes, least significant byte first.
--
-- 'radix' is only defined for pass indices 0 through 7.
instance Radix Word64 where
  passes _ = 8
  {-# INLINE passes #-}
  size _ = 256
  {-# INLINE size #-}
  radix 0 e = fromIntegral (e .&. 255)
  radix 1 e = fromIntegral ((e `shiftR` 8) .&. 255)
  radix 2 e = fromIntegral ((e `shiftR` 16) .&. 255)
  radix 3 e = fromIntegral ((e `shiftR` 24) .&. 255)
  radix 4 e = fromIntegral ((e `shiftR` 32) .&. 255)
  radix 5 e = fromIntegral ((e `shiftR` 40) .&. 255)
  radix 6 e = fromIntegral ((e `shiftR` 48) .&. 255)
  radix 7 e = fromIntegral ((e `shiftR` 56) .&. 255)
  {-# INLINE radix #-}
-- | Pairs: the passes of the second component run first, followed by the
-- passes of the first component.
--
-- The tuple patterns are lazy because 'sort' applies 'passes' and 'size' to
-- 'undefined'; a strict match on the pair constructor would diverge there.
instance (Radix i, Radix j) => Radix (i, j) where
  passes ~(i, j) = passes i + passes j
  {-# INLINE passes #-}
  -- One counting array serves both components, so it must fit the larger.
  size ~(i, j) = size i `max` size j
  {-# INLINE size #-}
  radix k ~(i, j)
    -- Passes 0 .. passes j - 1 belong to the second component.
    | k < passes j = radix k j
    -- The remaining passes belong to the first component, renumbered from 0.
    | otherwise    = radix (k - passes j) i
  {-# INLINE radix #-}
-- | Sorts a mutable array using the 'Radix' instance of its element type.
--
-- The number of passes, the size of the counting array and the radix
-- function are all taken from the instance and handed to 'sortBy'.
sort :: forall e m v. (PrimMonad m, MVector v e, Radix e)
     => v (PrimState m) e -> m ()
sort arr = sortBy (passes e) (size e) radix arr
  where
    -- Never evaluated: it only tells 'passes' and 'size' which instance to
    -- use (the type variable is scoped by the explicit forall above).
    e :: e
    e = undefined
{-# INLINABLE sort #-}
-- | Radix sort with a caller-supplied pass count, counting-array size and
-- radix function.
--
-- Allocates a temporary array of the same length as the input and a
-- counting array of the requested size, then runs the passes.
sortBy :: (PrimMonad m, MVector v e)
       => Int               -- ^ the number of passes
       -> Int               -- ^ the length of the counting array
       -> (Int -> e -> Int) -- ^ the radix function, given the pass index
       -> v (PrimState m) e -- ^ the array to be sorted
       -> m ()
sortBy passes size rdx arr = do
  tmp   <- new (length arr)
  count <- new size
  radixLoop passes rdx arr tmp count
{-# INLINE sortBy #-}
-- | Runs passes @0 .. passes - 1@, alternating the direction of each pass
-- between the two arrays, and leaves the final result in @src@.
radixLoop :: (PrimMonad m, MVector v e)
          => Int                          -- ^ the number of passes
          -> (Int -> e -> Int)            -- ^ the radix function
          -> v (PrimState m) e            -- ^ the array to sort
          -> v (PrimState m) e            -- ^ scratch array
          -> PV.MVector (PrimState m) Int -- ^ counting array
          -> m ()
radixLoop passes rdx src dst count = go False 0
  where
    -- @swap@ is False when the current data is in @src@ and True when it is
    -- in @dst@; every pass moves the data to the other array.
    go swap k
      | k < passes = if swap
                       then body rdx dst src count k >> go (not swap) (k + 1)
                       else body rdx src dst count k >> go (not swap) (k + 1)
      -- After an odd number of passes the data sits in @dst@; copy it back.
      -- 'unsafeCopy' takes the target array first.
      | otherwise  = when swap (unsafeCopy src dst)
{-# INLINE radixLoop #-}
-- | One pass of the sort: count the keys of pass @k@ in @src@, turn the
-- counts into starting offsets, then move every element into @dst@.
body :: (PrimMonad m, MVector v e)
     => (Int -> e -> Int)            -- ^ the radix function
     -> v (PrimState m) e            -- ^ array read during this pass
     -> v (PrimState m) e            -- ^ array written during this pass
     -> PV.MVector (PrimState m) Int -- ^ counting array
     -> Int                          -- ^ the current pass
     -> m ()
body rdx src dst count k = do
  countLoop (rdx k) src count
  accumulate count
  moveLoop k rdx src dst count
{-# INLINE body #-}
-- | Replaces every entry of the array with the sum of all entries that
-- precede it, so entry 0 becomes 0 (an exclusive prefix sum, in place).
accumulate :: (PrimMonad m)
           => PV.MVector (PrimState m) Int -> m ()
accumulate count = go 0 0
  where
    len = length count
    -- @acc@ is the sum of the original entries at indices below @i@.
    go i acc
      | i < len   = do
          ci <- unsafeRead count i
          unsafeWrite count i acc
          go (i + 1) (acc + ci)
      | otherwise = return ()
{-# INLINE accumulate #-}
-- | Walks @src@ front to back and writes each element into @dst@ at the
-- position obtained from the offset array for the element's pass-@k@ key.
moveLoop :: (PrimMonad m, MVector v e)
         => Int -> (Int -> e -> Int) -> v (PrimState m) e
         -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
moveLoop k rdx src dst prefix = go 0
  where
    len = length src
    go i
      | i < len   = do
          srci <- unsafeRead src i
          -- NOTE(review): 'inc' comes from Data.Vector.Algorithms.Common;
          -- presumably it returns the current offset for this key and bumps
          -- it by one -- confirm there.
          pf <- inc prefix (rdx k srci)
          unsafeWrite dst pf srci
          go (i + 1)
      | otherwise = return ()
{-# INLINE moveLoop #-}