{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
module Codec.Picture.Jpg.Internal.FastDct( referenceDct, fastDctLibJpeg ) where

#if !MIN_VERSION_base(4,8,0)
import Control.Applicative( (<$>) )
#endif

import Data.Int( Int16, Int32 )
import Data.Bits( unsafeShiftR, unsafeShiftL )
import Control.Monad.ST( ST )

import qualified Data.Vector.Storable.Mutable as M

import Codec.Picture.Jpg.Internal.Types
import Control.Monad( forM, forM_ )

-- | Reference implementation of the forward DCT, written as a direct
-- transcription of the ITU T.81 formula. It performs far too many
-- operations to be fast, but it is accurate and serves as a good
-- point of comparison for the optimised version.
--
-- Every sample of the second argument is read for every output
-- coefficient; the coefficients are truncated to 'Int32' and stored in
-- the first argument, which is also the value returned.
referenceDct :: MutableMacroBlock s Int32
             -> MutableMacroBlock s Int16
             -> ST s (MutableMacroBlock s Int32)
referenceDct workData block = do
    forM_ [(u, v) | u <- axis, v <- axis] $ \(u, v) -> do
        total <- weightedSum u v
        let coefficient = (1 / 4) * scale u * scale v * total
        M.unsafeWrite workData (v * dctBlockSize + u) (truncate coefficient)
    return workData
  where
    -- Indices along one side of the block.
    axis :: [Int]
    axis = [0 .. dctBlockSize - 1]

    -- Cosine basis value for a sample position and a frequency index.
    basis :: Int -> Int -> Float
    basis position frequency =
        cos ((2 * fromIntegral position + 1) * fromIntegral frequency * pi / 16)

    -- Sum over the whole block of each sample weighted by the horizontal
    -- and vertical basis functions of frequency (u, v).
    weightedSum u v = do
        terms <- forM [(x, y) | x <- axis, y <- axis] $ \(x, y) -> do
            sample <- M.unsafeRead block (y * dctBlockSize + x)
            return $ fromIntegral sample * basis x u * basis y v
        return $ sum terms

    -- Normalisation factor: 1 / sqrt 2 for frequency 0, 1 otherwise.
    scale 0 = 1 / sqrt 2
    scale _ = 1

-- | Bit counts used by the fixed point arithmetic of 'fastDctLibJpeg':
-- 'cONST_BITS' is the number of fractional bits of the @fIX_*@
-- constants, 'pASS1_BITS' the extra scaling kept after the row pass.
pASS1_BITS, cONST_BITS :: Int
cONST_BITS = 13
pASS1_BITS = 2

-- | Fixed point multipliers; each one is the real number spelled in its
-- name, scaled by @2 ^ cONST_BITS@ and rounded.
fIX_0_298631336, fIX_0_390180644, fIX_0_541196100,
    fIX_0_765366865, fIX_0_899976223, fIX_1_175875602,
    fIX_1_501321110, fIX_1_847759065, fIX_1_961570560,
    fIX_2_053119869, fIX_2_562915447, fIX_3_072711026 :: Int32
fIX_0_298631336 = 2446    -- FIX(0.298631336)
fIX_0_390180644 = 3196    -- FIX(0.390180644)
fIX_0_541196100 = 4433    -- FIX(0.541196100)
fIX_0_765366865 = 6270    -- FIX(0.765366865)
fIX_0_899976223 = 7373    -- FIX(0.899976223)
fIX_1_175875602 = 9633    -- FIX(1.175875602)
fIX_1_501321110 = 12299   -- FIX(1.501321110)
fIX_1_847759065 = 15137   -- FIX(1.847759065)
fIX_1_961570560 = 16069   -- FIX(1.961570560)
fIX_2_053119869 = 16819   -- FIX(2.053119869)
fIX_2_562915447 = 20995   -- FIX(2.562915447)
fIX_3_072711026 = 25172   -- FIX(3.072711026)

-- | Value subtracted from every sample (through the first coefficient
-- of each row) to centre the input around zero.
cENTERJSAMPLE :: Int32
cENTERJSAMPLE = 128

-- | Fast integer DCT extracted from libjpeg.
--
-- The samples of the second argument are read row by row and the
-- intermediate result is stored in the first argument; the columns of
-- the first argument are then transformed in place. The first argument
-- is returned once both passes are done.
fastDctLibJpeg :: MutableMacroBlock s Int32
               -> MutableMacroBlock s Int16
               -> ST s (MutableMacroBlock s Int32)
fastDctLibJpeg workData sample_block = do
    rowPass 0
    columnPass 0
    return workData
  where
    -- Right shifts applied when storing a value that went through a
    -- fixed point multiplication (row pass, column pass), and when
    -- storing the multiplication-free outputs of the column pass.
    rowShift, columnShift, evenShift :: Int
    rowShift    = cONST_BITS - pASS1_BITS
    columnShift = cONST_BITS + pASS1_BITS + 3
    evenShift   = pASS1_BITS + 3

    -- Rounding terms ("fudge factors") added before the shifts above.
    -- NOTE(review): the column pass terms are sized for a shift without
    -- the extra 3 bits, so the final rounding is not to nearest. This
    -- looks deliberate (kept from the libjpeg constants) -- confirm
    -- before changing, as it affects the encoder output.
    rowBias, columnBias, evenBias :: Int32
    rowBias    = 1 `unsafeShiftL` (cONST_BITS - pASS1_BITS - 1)
    columnBias = 1 `unsafeShiftL` (cONST_BITS + pASS1_BITS - 1)
    evenBias   = 1 `unsafeShiftL` (pASS1_BITS - 1)

    -- Pass 1: process rows, reading the samples and writing workData.
    -- Note results are scaled up by sqrt(8) compared to a true DCT;
    -- furthermore, the results are scaled by 2**PASS1_BITS.
    rowPass row
      | row == dctBlockSize = return ()
      | otherwise = do
          let rowStart = row * dctBlockSize
              fetch k = fromIntegral <$> M.unsafeRead sample_block (rowStart + k)
              put k = M.unsafeWrite workData (rowStart + k)
              putScaled k value = put k (value `unsafeShiftR` rowShift)

          s0 <- fetch 0
          s1 <- fetch 1
          s2 <- fetch 2
          s3 <- fetch 3
          s4 <- fetch 4
          s5 <- fetch 5
          s6 <- fetch 6
          s7 <- fetch 7

          -- Even part: built from the sums of mirrored samples.
          let sum07 = s0 + s7
              sum16 = s1 + s6
              sum25 = s2 + s5
              sum34 = s3 + s4

              outer     = sum07 + sum34
              outerDiff = sum07 - sum34
              inner     = sum16 + sum25
              innerDiff = sum16 - sum25

              -- Odd part inputs: differences of mirrored samples.
              dif07 = s0 - s7
              dif16 = s1 - s6
              dif25 = s2 - s5
              dif34 = s3 - s4

          -- The level shift of the whole row is removed from the first
          -- coefficient only.
          put 0 $ (outer + inner - dctBlockSize * cENTERJSAMPLE) `unsafeShiftL` pASS1_BITS
          put 4 $ (outer - inner) `unsafeShiftL` pASS1_BITS

          let evenRot = (outerDiff + innerDiff) * fIX_0_541196100 + rowBias

          putScaled 2 $ evenRot + outerDiff * fIX_0_765366865
          putScaled 6 $ evenRot - innerDiff * fIX_1_847759065

          let pair03 = dif07 + dif34
              pair12 = dif16 + dif25
              pair02 = dif07 + dif25
              pair13 = dif16 + dif34

              -- Shared rotation term, already carrying the rounding bias.
              oddRot = (pair02 + pair13) * fIX_1_175875602 + rowBias

              m0 = dif07 * fIX_1_501321110
              m1 = dif16 * fIX_3_072711026
              m2 = dif25 * fIX_2_053119869
              m3 = dif34 * fIX_0_298631336

              n03 = pair03 * negate fIX_0_899976223
              n12 = pair12 * negate fIX_2_562915447
              n02 = pair02 * negate fIX_0_390180644 + oddRot
              n13 = pair13 * negate fIX_1_961570560 + oddRot

          putScaled 1 $ m0 + n03 + n02
          putScaled 3 $ m1 + n12 + n13
          putScaled 5 $ m2 + n12 + n02
          putScaled 7 $ m3 + n03 + n13

          rowPass (row + 1)

    -- Pass 2: process columns of workData in place.
    -- The PASS1_BITS scaling is removed here, and every shift carries 3
    -- extra bits, which also divides the results by 8.
    columnPass col
      | col > 7 = return ()
      | otherwise = do
          let fetch k = M.unsafeRead workData (col + k * dctBlockSize)
              put k = M.unsafeWrite workData (dctBlockSize * k + col)
              putScaled k value = put k (value `unsafeShiftR` columnShift)

          s0 <- fetch 0
          s1 <- fetch 1
          s2 <- fetch 2
          s3 <- fetch 3
          s4 <- fetch 4
          s5 <- fetch 5
          s6 <- fetch 6
          s7 <- fetch 7

          let sum07 = s0 + s7
              sum16 = s1 + s6
              sum25 = s2 + s5
              sum34 = s3 + s4

              -- The rounding bias of outputs 0 and 4 is folded in here.
              outer     = sum07 + sum34 + evenBias
              outerDiff = sum07 - sum34
              inner     = sum16 + sum25
              innerDiff = sum16 - sum25

              dif07 = s0 - s7
              dif16 = s1 - s6
              dif25 = s2 - s5
              dif34 = s3 - s4

          put 0 $ (outer + inner) `unsafeShiftR` evenShift
          put 4 $ (outer - inner) `unsafeShiftR` evenShift

          let evenRot = (outerDiff + innerDiff) * fIX_0_541196100 + columnBias

          putScaled 2 $ evenRot + outerDiff * fIX_0_765366865
          putScaled 6 $ evenRot - innerDiff * fIX_1_847759065

          let pair03 = dif07 + dif34
              pair12 = dif16 + dif25
              pair02 = dif07 + dif25
              pair13 = dif16 + dif34

              -- Shared rotation term, already carrying the rounding bias.
              oddRot = (pair02 + pair13) * fIX_1_175875602 + columnBias

              m0 = dif07 * fIX_1_501321110
              m1 = dif16 * fIX_3_072711026
              m2 = dif25 * fIX_2_053119869
              m3 = dif34 * fIX_0_298631336

              n03 = pair03 * negate fIX_0_899976223
              n12 = pair12 * negate fIX_2_562915447
              n02 = pair02 * negate fIX_0_390180644 + oddRot
              n13 = pair13 * negate fIX_1_961570560 + oddRot

          putScaled 1 $ m0 + n03 + n02
          putScaled 3 $ m1 + n12 + n13
          putScaled 5 $ m2 + n12 + n02
          putScaled 7 $ m3 + n03 + n13

          columnPass (col + 1)

{-# ANN module "HLint: ignore Use camelCase" #-}
{-# ANN module "HLint: ignore Reduce duplication" #-}
