1{-# LANGUAGE CPP #-}
2{-# LANGUAGE TypeFamilies #-}
3module Codec.Picture.Jpg.Internal.FastDct( referenceDct, fastDctLibJpeg ) where
4
5#if !MIN_VERSION_base(4,8,0)
6import Control.Applicative( (<$>) )
7#endif
8
9import Data.Int( Int16, Int32 )
10import Data.Bits( unsafeShiftR, unsafeShiftL )
11import Control.Monad.ST( ST )
12
13import qualified Data.Vector.Storable.Mutable as M
14
15import Codec.Picture.Jpg.Internal.Types
16import Control.Monad( forM, forM_ )
17
-- | Reference implementation of the forward DCT, written directly from the
-- formula of ITU T.81.  It is slow and performs far too many operations,
-- but it is accurate and makes a good point of comparison for the faster
-- implementations.
--
-- The samples are read from @block@; every coefficient is truncated toward
-- zero and written into @workData@, which is also the value returned.
referenceDct :: MutableMacroBlock s Int32
             -> MutableMacroBlock s Int16
             -> ST s (MutableMacroBlock s Int32)
referenceDct workData block = do
  forM_ indices $ \(u, v) -> do
    total <- basisSum u v
    let coefficient = (1 / 4) * normalisation u * normalisation v * total
    M.unsafeWrite workData (v * dctBlockSize + u) (truncate coefficient)
  return workData
 where
   lastIdx :: Int
   lastIdx = dctBlockSize - 1

   -- Every (column, row) pair of the block; the first component is the
   -- outer loop.  Used both for frequencies and for sample positions.
   indices :: [(Int, Int)]
   indices = [(a, b) | a <- [0 .. lastIdx], b <- [0 .. lastIdx]]

   -- Sum over all samples of the sample weighted by the two cosine basis
   -- functions selected by the frequency pair (u, v).
   basisSum u v =
     fmap sum . forM indices $ \(x, y) -> do
       sample <- fromIntegral <$> M.unsafeRead block (y * dctBlockSize + x)
       return (sample * cos (angle x u) * cos (angle y v))

   -- Argument of the cosine for sample position `pos` and frequency `freq`.
   angle :: Int -> Int -> Float
   angle pos freq = (2 * fromIntegral pos + 1) * fromIntegral freq * pi / 16

   -- Normalisation factor: 1 / sqrt 2 for frequency 0, 1 otherwise.
   normalisation :: Int -> Float
   normalisation 0 = 1 / sqrt 2
   normalisation _ = 1
41
-- | Number of fractional bits carried by the @fIX_*@ fixed-point constants.
cONST_BITS :: Int
cONST_BITS = 13

-- | Number of extra bits of precision the first (row) pass of
-- 'fastDctLibJpeg' leaves in its results for the second pass to remove.
pASS1_BITS :: Int
pASS1_BITS = 2
45
46
-- Fixed-point multipliers used by 'fastDctLibJpeg'.  Each one is the real
-- number spelled out in its name, multiplied by 2^13 (13 being the value of
-- cONST_BITS) and rounded to the nearest integer.
fIX_0_298631336
  , fIX_0_390180644
  , fIX_0_541196100
  , fIX_0_765366865
  , fIX_0_899976223
  , fIX_1_175875602
  , fIX_1_501321110
  , fIX_1_847759065
  , fIX_1_961570560
  , fIX_2_053119869
  , fIX_2_562915447
  , fIX_3_072711026 :: Int32
fIX_0_298631336 =  2446  -- FIX(0.298631336)
fIX_0_390180644 =  3196  -- FIX(0.390180644)
fIX_0_541196100 =  4433  -- FIX(0.541196100)
fIX_0_765366865 =  6270  -- FIX(0.765366865)
fIX_0_899976223 =  7373  -- FIX(0.899976223)
fIX_1_175875602 =  9633  -- FIX(1.175875602)
fIX_1_501321110 = 12299  -- FIX(1.501321110)
fIX_1_847759065 = 15137  -- FIX(1.847759065)
fIX_1_961570560 = 16069  -- FIX(1.961570560)
fIX_2_053119869 = 16819  -- FIX(2.053119869)
fIX_2_562915447 = 20995  -- FIX(2.562915447)
fIX_3_072711026 = 25172  -- FIX(3.072711026)
63
-- | Offset removed from the samples by 'fastDctLibJpeg': its first pass
-- subtracts @dctBlockSize * cENTERJSAMPLE@ from the first output of every
-- row.  Presumably the mid-point of the 8-bit sample range, as in libjpeg;
-- confirm against the callers.
cENTERJSAMPLE :: Int32
cENTERJSAMPLE = 128
66
-- | Integer forward DCT of one macroblock, ported from libjpeg.
--
-- The samples are read from @sample_block@; the coefficients are written
-- into @workData@, which is also the value returned.  The transform runs as
-- two one-dimensional passes: every row first, then every column.
--
-- Both buffers are accessed with unchecked reads and writes, at indices up
-- to @(dctBlockSize - 1) * dctBlockSize + 7@, so each of them must hold at
-- least that many elements plus one.
fastDctLibJpeg :: MutableMacroBlock s Int32
        -> MutableMacroBlock s Int16
        -> ST s (MutableMacroBlock s Int32)
fastDctLibJpeg workData sample_block = do
  forM_ [0 :: Int .. dctBlockSize - 1] rowPass
  forM_ [0 :: Int .. dctBlockSize - 1] columnPass
  return workData
 where
   -- Right-shift amounts used to descale the results of each pass.
   rowShift, columnShift, dcShift :: Int
   rowShift    = cONST_BITS - pASS1_BITS
   columnShift = cONST_BITS + pASS1_BITS + 3
   dcShift     = pASS1_BITS + 3

   -- Rounding terms ("fudge factors") added before the descaling shifts.
   rowRound, columnRound, dcRound :: Int32
   rowRound    = 1 `unsafeShiftL` (cONST_BITS - pASS1_BITS - 1)
   columnRound = 1 `unsafeShiftL` (cONST_BITS + pASS1_BITS - 1)
   dcRound     = 1 `unsafeShiftL` (pASS1_BITS - 1)

   -- Pass 1: transform one row of samples into one row of workData.
   -- Results are scaled up by sqrt(8) compared to a true DCT and,
   -- furthermore, by 2**pASS1_BITS.
   rowPass row = do
     let base = row * dctBlockSize
         -- Samples are widened before any arithmetic takes place.
         sampleAt k = fromIntegral <$> M.unsafeRead sample_block (base + k)
         put k = M.unsafeWrite workData (base + k)
         descale n = n `unsafeShiftR` rowShift

     s0 <- sampleAt 0
     s1 <- sampleAt 1
     s2 <- sampleAt 2
     s3 <- sampleAt 3
     s4 <- sampleAt 4
     s5 <- sampleAt 5
     s6 <- sampleAt 6
     s7 <- sampleAt 7

     let -- Sums feed the even-indexed outputs, differences the odd ones.
         a0 = s0 + s7
         a1 = s1 + s6
         a2 = s2 + s5
         a3 = s3 + s4

         b0 = s0 - s7
         b1 = s1 - s6
         b2 = s2 - s5
         b3 = s3 - s4

         outerSum  = a0 + a3
         innerSum  = a1 + a2
         outerDiff = a0 - a3
         innerDiff = a1 - a2

         -- Term shared by outputs 2 and 6, rounding term included.
         evenRot = (outerDiff + innerDiff) * fIX_0_541196100 + rowRound

         b03 = b0 + b3
         b12 = b1 + b2
         b02 = b0 + b2
         b13 = b1 + b3

         -- Term shared by all the odd outputs, rounding term included.
         oddRot = (b02 + b13) * fIX_1_175875602 + rowRound

         p0 = b0 * fIX_1_501321110
         p1 = b1 * fIX_3_072711026
         p2 = b2 * fIX_2_053119869
         p3 = b3 * fIX_0_298631336

         q03 = b03 * negate fIX_0_899976223
         q12 = b12 * negate fIX_2_562915447
         q02 = b02 * negate fIX_0_390180644 + oddRot
         q13 = b13 * negate fIX_1_961570560 + oddRot

     -- The sample offset is removed on the first output only.
     put 0 $ (outerSum + innerSum - dctBlockSize * cENTERJSAMPLE)
                 `unsafeShiftL` pASS1_BITS
     put 1 $ descale (p0 + q03 + q02)
     put 2 $ descale (evenRot + outerDiff * fIX_0_765366865)
     put 3 $ descale (p1 + q12 + q13)
     put 4 $ (outerSum - innerSum) `unsafeShiftL` pASS1_BITS
     put 5 $ descale (p2 + q12 + q02)
     put 6 $ descale (evenRot - innerDiff * fIX_1_847759065)
     put 7 $ descale (p3 + q03 + q13)

   -- Pass 2: transform one column of workData in place.
   -- The pASS1_BITS scaling is removed and the results are shifted right
   -- by 3 more bits, i.e. divided by 8.
   -- NOTE(review): the rounding terms are sized for shifts of pASS1_BITS
   -- and cONST_BITS + pASS1_BITS, without those 3 extra bits; confirm that
   -- this is intentional.
   columnPass col = do
     let at k = col + k * dctBlockSize
         put k = M.unsafeWrite workData (at k)
         descale n = n `unsafeShiftR` columnShift

     -- The whole column is read before any element of it is overwritten.
     d0 <- M.unsafeRead workData (at 0)
     d1 <- M.unsafeRead workData (at 1)
     d2 <- M.unsafeRead workData (at 2)
     d3 <- M.unsafeRead workData (at 3)
     d4 <- M.unsafeRead workData (at 4)
     d5 <- M.unsafeRead workData (at 5)
     d6 <- M.unsafeRead workData (at 6)
     d7 <- M.unsafeRead workData (at 7)

     let a0 = d0 + d7
         a1 = d1 + d6
         a2 = d2 + d5
         a3 = d3 + d4

         b0 = d0 - d7
         b1 = d1 - d6
         b2 = d2 - d5
         b3 = d3 - d4

         -- The rounding term of outputs 0 and 4 is folded in here.
         outerSum  = a0 + a3 + dcRound
         innerSum  = a1 + a2
         outerDiff = a0 - a3
         innerDiff = a1 - a2

         evenRot = (outerDiff + innerDiff) * fIX_0_541196100 + columnRound

         b03 = b0 + b3
         b12 = b1 + b2
         b02 = b0 + b2
         b13 = b1 + b3

         oddRot = (b02 + b13) * fIX_1_175875602 + columnRound

         p0 = b0 * fIX_1_501321110
         p1 = b1 * fIX_3_072711026
         p2 = b2 * fIX_2_053119869
         p3 = b3 * fIX_0_298631336

         q03 = b03 * negate fIX_0_899976223
         q12 = b12 * negate fIX_2_562915447
         q02 = b02 * negate fIX_0_390180644 + oddRot
         q13 = b13 * negate fIX_1_961570560 + oddRot

     put 0 $ (outerSum + innerSum) `unsafeShiftR` dcShift
     put 1 $ descale (p0 + q03 + q02)
     put 2 $ descale (evenRot + outerDiff * fIX_0_765366865)
     put 3 $ descale (p1 + q12 + q13)
     put 4 $ (outerSum - innerSum) `unsafeShiftR` dcShift
     put 5 $ descale (p2 + q12 + q02)
     put 6 $ descale (evenRot - innerDiff * fIX_1_847759065)
     put 7 $ descale (p3 + q03 + q13)
215
216{-# ANN module "HLint: ignore Use camelCase" #-}
217{-# ANN module "HLint: ignore Reduce duplication" #-}
218
219