------------------------------------------------------------------------
-- |
-- Module      : What4.Utils.Arithmetic
-- Description : Utility functions for computing arithmetic
-- Copyright   : (c) Galois, Inc 2015-2020
-- License     : BSD3
-- Maintainer  : Joe Hendrix <jhendrix@galois.com>
-- Stability   : provisional
------------------------------------------------------------------------
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
module What4.Utils.Arithmetic
  ( -- * Arithmetic utilities
    isPow2
  , lg
  , lgCeil
  , nextMultiple
  , nextPow2Multiple
  , tryIntSqrt
  , tryRationalSqrt
  , roundAway
  , ctz
  , clz
  , rotateLeft
  , rotateRight
  ) where

import Control.Exception (assert)
import Data.Bits (Bits(..))
import Data.Ratio

import Data.Parameterized.NatRepr

-- | Returns true if the number is a power of two.
--
-- Zero is not a power of two and is rejected explicitly: the bit trick
-- @x .&. (x - 1) == 0@ alone would accept it.  For fixed-width signed
-- types such as 'Int', the sign bit alone (@minBound@) also passes the
-- bit test.
isPow2 :: (Bits a, Num a) => a -> Bool
isPow2 x = x /= 0 && x .&. (x - 1) == 0

-- | Returns the floor of the base-2 logarithm of a positive number,
-- i.e. the index of its most significant set bit.
--
-- Calls 'error' if the argument is not positive.
lg :: (Bits a, Num a, Ord a) => a -> Int
lg i0
  | i0 > 0 = go 0 (i0 `shiftR` 1)
  | otherwise = error "lg given number that is not positive."
  where
    -- Count the right shifts needed to reach zero.  The accumulator is
    -- strict so no chain of (+ 1) thunks is built.
    go !r 0 = r
    go !r n = go (r + 1) (n `shiftR` 1)

-- | Returns the ceiling of the base-2 logarithm of a number.
-- We define @lgCeil 0 = 0@ (and @lgCeil 1 = 0@).
--
-- Calls 'error' if the argument is negative.
lgCeil :: (Bits a, Num a, Ord a) => a -> Int
lgCeil 0 = 0
lgCeil 1 = 0
lgCeil i
  | i > 1 = 1 + lg (i - 1)
  | otherwise = error "lgCeil given number that is not positive."
-- | Count trailing zeros in the low @w@ bits of @x@.
--
-- Returns @w@ when all of those bits are zero.
ctz :: NatRepr w -> Integer -> Integer
ctz w x = go 0
  where
    width = intValue w
    go !i
      | i < width && not (testBit x (fromInteger i)) = go (i + 1)
      | otherwise = i

-- | Count leading zeros in the low @w@ bits of @x@, scanning downward
-- from bit @w - 1@.
--
-- Returns @w@ when all of those bits are zero.
clz :: NatRepr w -> Integer -> Integer
clz w x = go 0
  where
    width = intValue w
    go !i
      | i < width && not (testBit x (widthVal w - fromInteger i - 1)) = go (i + 1)
      | otherwise = i

-- | Rotate the low @w@ bits of a value right by the given amount.
--
-- The value is first truncated to @w@ bits.  The amount is reduced with
-- 'mod' (not 'rem') so the shift count is always in @[0, w)@.  A negative
-- amount therefore rotates left, instead of handing a negative count to
-- 'shiftR'.  A zero-width rotation yields 0 rather than dividing by zero.
rotateRight ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateRight w x n
  | intValue w == 0 = 0
  | otherwise = xor (shiftR x' n') (toUnsigned w (shiftL x' (widthVal w - n')))
  where
    x' = toUnsigned w x
    n' = fromInteger (n `mod` intValue w)

-- | Rotate the low @w@ bits of a value left by the given amount.
--
-- The value is first truncated to @w@ bits.  The amount is reduced with
-- 'mod' (not 'rem') so the shift count is always in @[0, w)@.  A negative
-- amount therefore rotates right, instead of handing a negative count to
-- 'shiftR'.  A zero-width rotation yields 0 rather than dividing by zero.
rotateLeft ::
  NatRepr w {- ^ width -} ->
  Integer {- ^ value to rotate -} ->
  Integer {- ^ amount to rotate -} ->
  Integer
rotateLeft w x n
  | intValue w == 0 = 0
  | otherwise = xor (shiftR x' (widthVal w - n')) (toUnsigned w (shiftL x' n'))
  where
    x' = toUnsigned w x
    n' = fromInteger (n `mod` intValue w)


-- | @nextMultiple x y@ computes the smallest multiple @m@ of @x@ such that
-- @m >= y@.  E.g., @nextMultiple 4 8 = 8@ since 8 is a multiple of 4;
-- @nextMultiple 4 7 = 8@; @nextMultiple 8 6 = 8@.
--
-- Assumes @x > 0@.
nextMultiple :: Integral a => a -> a -> a
nextMultiple x y = ((y + x - 1) `div` x) * x

-- | @nextPow2Multiple x n@ returns the smallest multiple of @2^n@
-- not less than @x@.
--
-- Both arguments must be non-negative; otherwise 'error' is called.
nextPow2Multiple :: (Bits a, Integral a) => a -> Int -> a
nextPow2Multiple x n
  | x >= 0 && n >= 0 = ((x + 2^n - 1) `shiftR` n) `shiftL` n
  | otherwise = error "nextPow2Multiple given negative value."

------------------------------------------------------------------------
-- Sqrt operators.

-- | This returns the sqrt of an integer if it is well-defined.
--
-- Concretely, @tryIntSqrt n@ is @Just r@ (with @r >= 0@) exactly when
-- @n == r * r@.  Negative inputs and non-squares yield 'Nothing'.
tryIntSqrt :: Integer -> Maybe Integer
tryIntSqrt n
  | n < 0 = Nothing  -- No real square root; also keeps go away from div by 0.
  | n < 2 = Just n   -- 0 and 1 are their own square roots.
  | n < 4 = Nothing  -- 2 and 3 are not perfect squares.
  | otherwise = assert (n >= 4) $ go (n `shiftR` 1)
  where
    -- Integer Newton iteration from above.  Each step never drops below
    -- the integer square root, so stop as soon as the square is <= n.
    go x
      | x2 < n = Nothing    -- Guess is below sqrt, so n is not a square.
      | x2 == n = Just x'   -- We have found sqrt.
      | otherwise = go x'   -- Guess is still too large, so try again.
      where
        -- Next guess is floor(avg(x, n/x)).
        x' = (x + n `div` x) `div` 2
        x2 = x' * x'

-- | Return the rational square root of a rational number, if it has one.
--
-- A 'Rational' is kept in lowest terms with a positive denominator, so its
-- square root is rational exactly when both numerator and denominator are
-- perfect squares.  Negative inputs yield 'Nothing'.
tryRationalSqrt :: Rational -> Maybe Rational
tryRationalSqrt r =
  (%) <$> tryIntSqrt (numerator r) <*> tryIntSqrt (denominator r)

------------------------------------------------------------------------
-- Conversion

-- | Evaluate a real to an integer with rounding away from zero.
-- Ties go outward, e.g. @2.5@ becomes @3@ and @-2.5@ becomes @-3@.
--
-- The value is split exactly into integer and fractional parts with
-- 'properFraction', and the tie test looks only at the fractional part.
-- This avoids the floating-point rounding of @r + 0.5@, which would turn
-- the largest 'Double' below 0.5 into 1.
roundAway :: (RealFrac a) => a -> Integer
roundAway r
  | abs f < 0.5 = n
  | r < 0 = n - 1
  | otherwise = n + 1
  where
    (n, f) = properFraction r