1-- |
2-- Module      : Crypto.Number.Basic
3-- License     : BSD-style
4-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
5-- Stability   : experimental
6-- Portability : Good
7
8{-# LANGUAGE BangPatterns #-}
9module Crypto.Number.Basic
10    ( sqrti
11    , gcde
12    , areEven
13    , log2
14    , numBits
15    , numBytes
16    , asPowerOf2AndOdd
17    ) where
18
19import Data.Bits
20
21import Crypto.Number.Compat
22
-- | @sqrti i@ returns a pair of integers @(l, b)@ that bracket the square
-- root of @i@, i.e. @l <= sqrt i <= b@.
--
-- An initial guess is derived from the number of decimal digits of @i@.
-- The guess is doubled or halved until it brackets the root, and the
-- bracket is then narrowed by bisection until the two bounds are adjacent.
-- For a non-square @i@ the result is @(floor (sqrt i), floor (sqrt i) + 1)@.
-- For a perfect square @k*k@ with @k >= 2@ the result is @(k-1, k)@ rather
-- than @(k, k)@; this still satisfies the contract above.
--
-- Calls 'error' when @i@ is negative.
sqrti :: Integer -> (Integer, Integer)
sqrti i
    | i < 0     = error "cannot compute negative square root"
    | i == 0    = (0,0)
    | i == 1    = (1,1)
    | i == 2    = (1,2)
    | otherwise = start guess
  where
    digits  = length (show i)
    evenLen = even digits
    expo    = (if evenLen then digits - 2 else digits - 1) `div` 2
    guess   = (if evenLen then 2 else 6) * 10 ^ expo
    square v = v * v
    -- Decide which way the first guess is off.
    start g = case compare (square g) i of
        LT -> grow g
        EQ -> (g, g)
        GT -> shrink g
    -- Double a too-small bound until its double squares to at least i.
    grow lo
        | square hi >= i = bisect lo hi
        | otherwise      = grow hi
      where hi = lo * 2
    -- Halve a too-large bound until its half squares to less than i.
    shrink hi
        | square lo >= i = shrink lo
        | otherwise      = bisect lo hi
      where lo = hi `div` 2
    -- Narrow [lo, hi] (square lo < i <= square hi) until the bounds touch.
    bisect lo hi
        | lo == hi || lo + 1 == hi = (lo, hi)
        | square mid >= i          = bisect lo (hi - half)
        | otherwise                = bisect mid hi
      where
        half = (hi - lo) `div` 2
        mid  = lo + half
54
-- | Extended greatest common divisor of two integers, computed with
-- 'divMod'.
--
-- @gcde a b@ returns @(x, y, d)@ where @d@ is the GCD of @a@ and @b@ and
-- @a*x + b*y == d@.
--
-- 'gmpGcde' is used unless 'onGmpUnsupported' selects the pure Haskell
-- fallback below. The fallback always returns a non-negative @d@; it
-- flips the signs of all three components when needed.
gcde :: Integer -> Integer -> (Integer, Integer, Integer)
gcde a b = onGmpUnsupported (gmpGcde a b) fallback
  where
    fallback
        | g < 0     = (negate s, negate t, negate g)
        | otherwise = (s, t, g)
    (g, s, t) = euclid (a, 1, 0) (b, 0, 1)
    -- Every triple (r, s, t) keeps the invariant r == a*s + b*t. The loop
    -- stops once the remainder reaches zero and returns the previous triple.
    euclid done (0, _, _) = done
    euclid (r0, s0, t0) cur@(r1, s1, t1) =
        let (q, r2) = r0 `divMod` r1
        in  euclid cur (r2, s0 - q * s1, t0 - q * t1)
68
-- | Check whether every integer in a list is even. An empty list yields
-- 'True'.
areEven :: [Integer] -> Bool
areEven = all even
72
-- | Compute the binary logarithm of an integer, rounded down.
--
-- 'gmpLog2' is used unless 'onGmpUnsupported' selects the pure Haskell
-- fallback below. The fallback returns 0 for every @n < 2@ (including
-- zero and negative inputs) and @floor (logBase 2 n)@ otherwise.
log2 :: Integer -> Int
log2 n = onGmpUnsupported (gmpLog2 n) $ intLog 2 n
  where
    -- Integer logarithm in base @base@, found by repeatedly squaring the
    -- base:
    -- http://www.haskell.org/pipermail/haskell-cafe/2008-February/039465.html
    intLog base v
        | v < base  = 0
        | otherwise = countDivs (v `div` base ^ guess) guess
      where
        -- twice the logarithm in base (base*base) is a lower bound here
        guess = 2 * intLog (base * base) v
        -- count the divisions by base still needed to drop below base
        countDivs q acc
            | q < base  = acc
            | otherwise = countDivs (q `div` base) (acc + 1)
{-# INLINE log2 #-}
83
-- | Compute the number of bits needed to represent an integer.
--
-- 'gmpSizeInBits' is used unless 'onGmpUnsupported' selects the pure
-- Haskell fallback below. The fallback reports 1 bit for zero. For any
-- other value it counts the bits of the absolute value, one byte at a
-- time. NOTE(review): the absolute-value treatment is assumed to match
-- 'gmpSizeInBits' for negative inputs -- confirm.
numBits :: Integer -> Int
numBits n = gmpSizeInBits n `onGmpUnsupported` fallback
  where
    fallback
        | n == 0    = 1
          -- Work on the magnitude: 'divMod' rounds toward negative
          -- infinity, so for a negative argument the quotient sticks at -1
          -- and the byte loop below would never terminate.
        | otherwise = computeBits 0 (abs n)
    -- Add 8 bits per full low byte, then the significant bits of the top
    -- byte.
    computeBits !acc i
        | q == 0    = acc + topBits r
        | otherwise = computeBits (acc + 8) q
      where (q, r) = i `divMod` 256
    -- Number of significant bits in a byte value. The final 0 case is not
    -- reached from computeBits, which only passes a positive top byte.
    topBits r
        | r >= 0x80 = 8
        | r >= 0x40 = 7
        | r >= 0x20 = 6
        | r >= 0x10 = 5
        | r >= 0x08 = 4
        | r >= 0x04 = 3
        | r >= 0x02 = 2
        | r >= 0x01 = 1
        | otherwise = 0
100
-- | Compute the number of bytes needed to hold an integer.
--
-- 'gmpSizeInBytes' is used unless 'onGmpUnsupported' selects the fallback,
-- which rounds 'numBits' up to a whole number of bytes.
numBytes :: Integer -> Int
numBytes n = onGmpUnsupported (gmpSizeInBytes n) fromBits
  where
    fromBits = (numBits n + 7) `div` 8
104
-- | Split an integer @a@ into @(e, m)@ such that @a == 2^e * m@ with @m@ odd.
--
-- Zero maps to @(0, 0)@. For a negative @a@ the sign is carried by @m@.
-- Exact powers of two are handled through 'log2'. Any other even value has
-- its factors of two stripped one at a time.
asPowerOf2AndOdd :: Integer -> (Int, Integer)
asPowerOf2AndOdd a
    | a == 0       = (0, 0)
    | odd a        = (0, a)
    | a < 0        = let (e, m) = asPowerOf2AndOdd (negate a)
                     in  (e, negate m)
    | isPowerOf2 a = (log2 a, 1)
    | otherwise    = stripTwos a 0
  where
    -- a positive power of two has exactly one set bit
    isPowerOf2 v = v /= 0 && (v .&. (v - 1)) == 0
    -- divide out 2 until the value is odd, counting the divisions
    stripTwos v count
        | even v    = stripTwos (v `div` 2) (count + 1)
        | otherwise = (count, v)