1------------------------------------------------------------------------
2-- |
3-- Module           : What4.Utils.Arithmetic
4-- Description      : Utility functions for computing arithmetic
5-- Copyright        : (c) Galois, Inc 2015-2020
6-- License          : BSD3
7-- Maintainer       : Joe Hendrix <jhendrix@galois.com>
8-- Stability        : provisional
9------------------------------------------------------------------------
10{-# LANGUAGE BangPatterns #-}
11{-# LANGUAGE CPP #-}
12module What4.Utils.Arithmetic
13  ( -- * Arithmetic utilities
14    isPow2
15  , lg
16  , lgCeil
17  , nextMultiple
18  , nextPow2Multiple
19  , tryIntSqrt
20  , tryRationalSqrt
21  , roundAway
22  , ctz
23  , clz
24  , rotateLeft
25  , rotateRight
26  ) where
27
28import Control.Exception (assert)
29import Data.Bits (Bits(..))
30import Data.Ratio
31
32import Data.Parameterized.NatRepr
33
-- | Returns 'True' if the argument is a power of two.
--
-- The test clears the lowest set bit (@x .&. (x - 1)@) and checks that
-- nothing remains.  Note that this also accepts zero (@isPow2 0 == True@),
-- so callers that must exclude zero need to check for it separately.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = clearLowestSetBit x == 0
  where
    clearLowestSetBit v = v .&. (v - 1)
37
-- | Returns the floor of the base-2 logarithm of a positive number, i.e. the
-- index of its most significant set bit (@lg 1 == 0@, @lg 8 == 3@,
-- @lg 9 == 3@).  Calls 'error' on zero or negative input.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg n
  | n > 0 = fst (until (\(_, rest) -> rest == 0) halve (0, n `shiftR` 1))
  | otherwise = error "lg given number that is not positive."
  where
    -- Every halving that still leaves a non-zero value adds one to the result.
    halve (k, rest) = (k + 1, rest `shiftR` 1)
44
-- | Returns the ceiling of the base-2 logarithm, i.e. the smallest @k@ with
-- @2^k >= i@ for positive @i@.  By convention @lgCeil 0 = 0@ (the same as
-- @lgCeil 1@); negative inputs call 'error'.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil i
  | i == 0 || i == 1 = 0
  | i > 1 = lg (i - 1) + 1
  | otherwise = error "lgCeil given number that is not positive."
52
-- | Count trailing zeros of @x@ viewed as a @w@-bit value: the number of
-- consecutive clear bits starting at bit 0, examining only bits
-- @0 .. w-1@.  Returns @w@ when all of those bits are clear.
ctz :: NatRepr w -> Integer -> Integer
ctz w x = toInteger (length (takeWhile isClear [0 .. width - 1]))
  where
    width = toInteger (natValue w)
    isClear i = not (testBit x (fromInteger i))
60
-- | Count leading zeros of @x@ viewed as a @w@-bit value: the number of
-- consecutive clear bits starting at bit @w-1@ and moving towards bit 0.
-- Bits at positions @>= w@ are ignored; returns @w@ when bits
-- @0 .. w-1@ are all clear.
clz :: NatRepr w -> Integer -> Integer
clz w x = toInteger (length (takeWhile isClear msbFirst))
  where
    top = widthVal w - 1
    -- bit indices from the most significant (w-1) down to 0
    msbFirst = [top, top - 1 .. 0]
    isClear j = not (testBit x j)
68
-- | Rotate the low @w@ bits of a value right by the given amount.
--
-- The value is first truncated to @w@ bits with 'toUnsigned', and the
-- amount is reduced modulo @w@.  'mod' (rather than 'rem') keeps the
-- effective amount in @[0, w)@ even when it is negative, so a negative
-- amount rotates left; otherwise a negative amount would reach
-- 'shiftR'/'shiftL', which require non-negative shift amounts.
-- A zero-width value is always 0.
rotateRight ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateRight w x n
  | width == 0 = 0 -- avoid `mod` by zero; the only 0-bit value is 0
    -- The low part moves down and the wrapped-around part moves to the top.
    -- They occupy disjoint bits, so xor acts as bitwise or here.
  | otherwise = xor (shiftR x' n') (toUnsigned w (shiftL x' (widthVal w - n')))
  where
    width = intValue w
    -- value truncated to w bits
    x' = toUnsigned w x
    -- effective rotation amount, always in [0, w)
    n' = fromInteger (n `mod` width)
78
-- | Rotate the low @w@ bits of a value left by the given amount.
--
-- The value is first truncated to @w@ bits with 'toUnsigned', and the
-- amount is reduced modulo @w@.  'mod' (rather than 'rem') keeps the
-- effective amount in @[0, w)@ even when it is negative, so a negative
-- amount rotates right; otherwise a negative amount would reach
-- 'shiftR'/'shiftL', which require non-negative shift amounts.
-- A zero-width value is always 0.
rotateLeft ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateLeft w x n
  | width == 0 = 0 -- avoid `mod` by zero; the only 0-bit value is 0
    -- The high part wraps down to the bottom and the rest moves up.
    -- They occupy disjoint bits, so xor acts as bitwise or here.
  | otherwise = xor (shiftR x' (widthVal w - n')) (toUnsigned w (shiftL x' n'))
  where
    width = intValue w
    -- value truncated to w bits
    x' = toUnsigned w x
    -- effective rotation amount, always in [0, w)
    n' = fromInteger (n `mod` width)
88
89
-- | @nextMultiple x y@ computes the smallest multiple @m@ of @x@ such that
-- @m >= y@.  E.g., @nextMultiple 4 8 = 8@ (8 is already a multiple of 4),
-- @nextMultiple 4 7 = 8@ and @nextMultiple 8 6 = 8@.
--
-- The formula assumes @x@ is positive: @x = 0@ makes 'div' fail, and for
-- negative @x@ the result can be smaller than @y@.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = quotientUp * x
  where
    -- ceiling (y / x) computed with floor division (valid for positive x)
    quotientUp = (y + x - 1) `div` x
95
-- | @nextPow2Multiple x n@ returns the smallest multiple of @2^n@ that is
-- not less than @x@ (e.g. @nextPow2Multiple 9 3 = 16@).  Both arguments
-- must be non-negative; otherwise 'error' is called.
--
-- NOTE: for fixed-width types the intermediate @x + 2^n - 1@ can wrap around.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x < 0 || n < 0 = error "nextPow2Multiple given negative value."
    -- Round up by adding 2^n - 1, then clear the low n bits.
  | otherwise = (bumped `shiftR` n) `shiftL` n
  where
    bumped = x + 2 ^ n - 1
101
102------------------------------------------------------------------------
103-- Sqrt operators.
104
-- | Return the exact square root of an integer if it is a perfect square,
-- and 'Nothing' otherwise.  Negative numbers have no (real) square root, so
-- they also yield 'Nothing'.
tryIntSqrt :: Integer -> Maybe Integer
-- Negative inputs must not reach the Newton loop below, which is only valid
-- for n >= 4 (e.g. -1 would end in a division by zero).
tryIntSqrt n | n < 0 = Nothing
tryIntSqrt 0 = Just 0
tryIntSqrt 1 = Just 1
tryIntSqrt 2 = Nothing
tryIntSqrt 3 = Nothing
tryIntSqrt n = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- Integer Newton iteration starting from n/2, which is >= sqrt n when
    -- n >= 4.  Guesses decrease towards floor (sqrt n) without dropping
    -- below it, so the first guess whose square is <= n is floor (sqrt n).
    go x
      | x2 < n = Nothing   -- Reached floor (sqrt n) and it is not exact.
      | x2 == n = Just x'  -- Found the exact square root.
      | otherwise = go x'  -- Guess is still too large, so try again.
      where
        -- Next guess is floor (avg (x, n / x)).
        x' = (x + n `div` x) `div` 2
        x2 = x' * x'
118
-- | Return the exact square root of a rational number when it has one.
--
-- A 'Rational' is kept in lowest terms, so @r = p % q@ is the square of a
-- rational exactly when both @p@ and @q@ are perfect squares.  Each part is
-- checked with 'tryIntSqrt' and the roots are recombined.  The result is
-- 'Nothing' whenever 'tryIntSqrt' returns 'Nothing' for either part.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r =
  case (tryIntSqrt (numerator r), tryIntSqrt (denominator r)) of
    (Just p, Just q) -> Just (p % q)
    _ -> Nothing
124
125------------------------------------------------------------------------
126-- Conversion
127
-- | Evaluate a real to an integer with rounding away from zero.
--
-- If the fractional part has magnitude of at least 1/2, the value rounds to
-- the neighbouring integer further from zero; otherwise it rounds towards
-- zero (e.g. 2.5 -> 3, -2.5 -> -3, 2.4 -> 2).
--
-- The fractional part is taken exactly with 'properFraction' and compared
-- against 1/2, rather than computing @truncate (r + signum r * 0.5)@.  For
-- floating-point types that addition can itself round: it gives 1 for the
-- largest Double below 0.5, and it bumps odd integers between 2^52 and 2^53.
roundAway :: (RealFrac a) => a -> Integer
roundAway r
  | abs frac >= 0.5 = if frac > 0 then whole + 1 else whole - 1
  | otherwise = whole
  where
    -- whole is r truncated towards zero; frac = r - whole carries r's sign.
    (whole, frac) = properFraction r
131