1{-|
2Module           : What4.Interface
3Description      : Main interface for constructing What4 formulae
4Copyright        : (c) Galois, Inc 2014-2020
5License          : BSD3
6Maintainer       : Joe Hendrix <jhendrix@galois.com>
7
Defines the interface between the simulator and the terms that are sent to
the SAT or SMT solver.  The simulator can use a richer set of types, but the
symbolic values must be representable by types supported by this interface.
11
12A solver backend is defined in terms of a type parameter @sym@, which
13is the type that tracks whatever state or context is needed by that
14particular backend. To instantiate the solver interface, one must
15provide several type family definitions and class instances for @sym@:
16
17  [@type 'SymExpr' sym :: 'BaseType' -> *@]
18  Type of symbolic expressions.
19
20  [@type 'BoundVar' sym :: 'BaseType' -> *@]
21  Representation of bound variables in symbolic expressions.
22
23  [@type 'SymFn' sym :: Ctx BaseType -> BaseType -> *@]
24  Representation of symbolic functions.
25
26  [@instance 'IsExprBuilder' sym@]
27  Functions for building expressions of various types.
28
29  [@instance 'IsSymExprBuilder' sym@]
30  Functions for building expressions with bound variables and quantifiers.
31
32  [@instance 'IsExpr' ('SymExpr' sym)@]
33  Recognizers for various kinds of literal expressions.
34
35  [@instance 'OrdF' ('SymExpr' sym)@]
36
37  [@instance 'TestEquality' ('SymExpr' sym)@]
38
39  [@instance 'HashableF' ('SymExpr' sym)@]
40
41The canonical implementation of these interface classes is found in "What4.Expr.Builder".
42-}
43{-# LANGUAGE CPP #-}
44{-# LANGUAGE ConstraintKinds #-}
45{-# LANGUAGE DataKinds #-}
46{-# LANGUAGE DeriveGeneric #-}
47{-# LANGUAGE DoAndIfThenElse #-}
48{-# LANGUAGE FlexibleContexts #-}
49{-# LANGUAGE FlexibleInstances #-}
50{-# LANGUAGE GADTs #-}
51{-# LANGUAGE LambdaCase #-}
52{-# LANGUAGE LiberalTypeSynonyms #-}
53{-# LANGUAGE MultiParamTypeClasses #-}
54{-# LANGUAGE PatternGuards #-}
55{-# LANGUAGE PolyKinds #-}
56{-# LANGUAGE RankNTypes #-}
57{-# LANGUAGE ScopedTypeVariables #-}
58{-# LANGUAGE TypeApplications #-}
59{-# LANGUAGE TypeFamilies #-}
60{-# LANGUAGE TypeOperators #-}
61
62{-# LANGUAGE UndecidableInstances #-}
63
64module What4.Interface
65  ( -- * Interface classes
66    -- ** Type Families
67    SymExpr
68  , BoundVar
69  , SymFn
70  , SymAnnotation
71
72    -- ** Expression recognizers
73  , IsExpr(..)
74  , IsSymFn(..)
75  , UnfoldPolicy(..)
76  , shouldUnfold
77
78    -- ** IsExprBuilder
79  , IsExprBuilder(..)
80  , IsSymExprBuilder(..)
81  , SolverEvent(..)
82
83    -- ** Bitvector operations
84  , bvJoinVector
85  , bvSplitVector
86  , bvSwap
87  , bvBitreverse
88
89    -- ** Floating-point rounding modes
90  , RoundingMode(..)
91
92    -- ** Run-time statistics
93  , Statistics(..)
94  , zeroStatistics
95
96    -- * Type Aliases
97  , Pred
98  , SymInteger
99  , SymReal
100  , SymFloat
101  , SymString
102  , SymCplx
103  , SymStruct
104  , SymBV
105  , SymArray
106
107    -- * Natural numbers
108  , SymNat
109  , asNat
110  , natLit
111  , natAdd
112  , natSub
113  , natMul
114  , natDiv
115  , natMod
116  , natIte
117  , natEq
118  , natLe
119  , natLt
120  , natToInteger
121  , bvToNat
122  , natToReal
123  , integerToNat
124  , realToNat
125  , freshBoundedNat
126  , freshNat
127  , printSymNat
128
129    -- * Array utility types
130  , IndexLit(..)
131  , indexLit
132  , ArrayResultWrapper(..)
133
134    -- * Concrete values
135  , asConcrete
136  , concreteToSym
137  , baseIsConcrete
138  , baseDefaultValue
139  , realExprAsInteger
140  , rationalAsInteger
141  , cplxExprAsRational
142  , cplxExprAsInteger
143
144    -- * SymEncoder
145  , SymEncoder(..)
146
147    -- * Utility combinators
148    -- ** Boolean operations
149  , backendPred
150  , andAllOf
151  , orOneOf
152  , itePredM
153  , iteM
154  , iteList
155  , predToReal
156
157    -- ** Complex number operations
158  , cplxDiv
159  , cplxLog
160  , cplxLogBase
161  , mkRational
162  , mkReal
163  , isNonZero
164  , isReal
165
166    -- ** Indexing
167  , muxRange
168
169    -- * Exceptions
170  , InvalidRange(..)
171
172    -- * Reexports
173  , module Data.Parameterized.NatRepr
174  , module What4.BaseTypes
175  , HasAbsValue
176  , What4.Symbol.SolverSymbol
177  , What4.Symbol.emptySymbol
178  , What4.Symbol.userSymbol
179  , What4.Symbol.safeSymbol
180  , ValueRange(..)
181  , StringLiteral(..)
182  , stringLiteralInfo
183  ) where
184
185#if !MIN_VERSION_base(4,13,0)
186import Control.Monad.Fail( MonadFail )
187#endif
188
189import           Control.Exception (assert, Exception)
190import           Control.Lens
191import           Control.Monad
192import           Control.Monad.IO.Class
193import qualified Data.BitVector.Sized as BV
194import           Data.Coerce (coerce)
195import           Data.Foldable
196import           Data.Kind ( Type )
197import qualified Data.Map as Map
198import           Data.Parameterized.Classes
199import qualified Data.Parameterized.Context as Ctx
200import           Data.Parameterized.Ctx
201import           Data.Parameterized.Utils.Endian (Endian(..))
202import           Data.Parameterized.NatRepr
203import           Data.Parameterized.TraversableFC
204import qualified Data.Parameterized.Vector as Vector
205import           Data.Ratio
206import           Data.Scientific (Scientific)
207import           GHC.Generics (Generic)
208import           Numeric.Natural
209import           LibBF (BigFloat)
210import           Prettyprinter (Doc)
211
212import           What4.BaseTypes
213import           What4.Config
214import qualified What4.Expr.ArrayUpdateMap as AUM
215import           What4.IndexLit
216import           What4.ProgramLoc
217import           What4.Concrete
218import           What4.SatResult
219import           What4.Symbol
220import           What4.Utils.AbstractDomains
221import           What4.Utils.Arithmetic
222import           What4.Utils.Complex
223import           What4.Utils.FloatHelpers (RoundingMode(..))
224import           What4.Utils.StringLiteral
225
226------------------------------------------------------------------------
227-- SymExpr names
228
-- | Symbolic boolean values, AKA predicates.
type Pred sym = SymExpr sym BaseBoolType

-- | Symbolic integers.
type SymInteger sym = SymExpr sym BaseIntegerType

-- | Symbolic real numbers.
type SymReal sym = SymExpr sym BaseRealType

-- | Symbolic floating point numbers; @fpp@ indexes the precision
--   (see 'floatPrecision').
type SymFloat sym fpp = SymExpr sym (BaseFloatType fpp)

-- | Symbolic complex numbers.
type SymCplx sym = SymExpr sym BaseComplexType

-- | Symbolic structures whose field types are given by @flds@
--   (see 'asStruct').
type SymStruct sym flds = SymExpr sym (BaseStructType flds)

-- | Symbolic arrays indexed by @idx@ with elements of type @b@.
type SymArray sym idx b = SymExpr sym (BaseArrayType idx b)

-- | Symbolic bitvectors of width @n@ (see 'bvWidth').
type SymBV sym n = SymExpr sym (BaseBVType n)

-- | Symbolic strings; @si@ is the string-info index (see 'stringInfo').
type SymString sym si = SymExpr sym (BaseStringType si)
255
256------------------------------------------------------------------------
257-- Type families for the interface.
258
-- | The type of symbolic expressions for a backend @sym@, indexed by the
--   'BaseType' of the value each expression denotes.
type family SymExpr (sym :: Type) :: BaseType -> Type

------------------------------------------------------------------------
-- | Type of bound variable associated with symbolic state.
--
-- This type is used by some methods in class 'IsSymExprBuilder'.
type family BoundVar (sym :: Type) :: BaseType -> Type


------------------------------------------------------------------------
-- | Type used to uniquely identify expressions that have been annotated.
--
-- Values of this type are produced by 'annotateTerm' and can be
-- recovered from an annotated term with 'getAnnotation'.
type family SymAnnotation (sym :: Type) :: BaseType -> Type
272
273------------------------------------------------------------------------
274-- IsBoolSolver
275
-- | Lazy if-then-else on predicates.
--
--   When the condition is a concrete boolean (according to
--   'asConstantPred'), only the selected branch action is run.  Otherwise
--   both branch actions are run, then-branch first, and the results are
--   combined with 'itePred'.
itePredM :: (IsExpr (SymExpr sym), IsExprBuilder sym, MonadIO m)
         => sym
         -> Pred sym
         -> m (Pred sym)
         -> m (Pred sym)
         -> m (Pred sym)
itePredM sym c thenM elseM
  | Just b <- asConstantPred c = if b then thenM else elseM
  | otherwise = do
      t <- thenM
      e <- elseM
      liftIO (itePred sym c t e)
291
292------------------------------------------------------------------------
293-- IsExpr
294
-- | Operations for inspecting symbolic expressions: recognizing when an
--   expression denotes a concrete value, recovering its 'BaseTypeRepr',
--   reporting any bounds information the backend has, and pretty-printing.
--
--   Most recognizers default to 'Nothing', meaning "not known to be a
--   concrete value".
class HasAbsValue e => IsExpr e where
  -- | Return the truth value if this predicate is a constant.
  asConstantPred :: e BaseBoolType -> Maybe Bool
  asConstantPred = const Nothing

  -- | Return the integer value if this is a constant integer.
  asInteger :: e BaseIntegerType -> Maybe Integer
  asInteger = const Nothing

  -- | Return any bounding information known about an integer term.
  integerBounds :: e BaseIntegerType -> ValueRange Integer

  -- | Return the rational value if this is a constant real.
  asRational :: e BaseRealType -> Maybe Rational
  asRational = const Nothing

  -- | Return the floating-point value if this is a constant.
  asFloat :: e (BaseFloatType fpp) -> Maybe BigFloat

  -- | Return any bounding information known about a real term.
  rationalBounds :: e BaseRealType -> ValueRange Rational

  -- | Return the complex value if this is a constant.
  asComplex :: e BaseComplexType -> Maybe (Complex Rational)
  asComplex = const Nothing

  -- | Return the bitvector value if this is a constant bitvector.
  asBV :: e (BaseBVType w) -> Maybe (BV.BV w)
  asBV = const Nothing

  -- | If bounds information is known for the term, return its unsigned
  --   bounds as a pair of integers (presumably @(lower, upper)@; confirm
  --   against the backend).
  unsignedBVBounds :: (1 <= w) => e (BaseBVType w) -> Maybe (Integer, Integer)

  -- | If bounds information is known for the term, return its signed
  --   bounds as a pair of integers (presumably @(lower, upper)@; confirm
  --   against the backend).
  signedBVBounds :: (1 <= w) => e (BaseBVType w) -> Maybe (Integer, Integer)

  -- | If this expression syntactically has an "affine" shape, return its
  --   components: @asAffineVar x = Just (c,r,o)@ means @x == c*r + o@.
  asAffineVar :: e tp -> Maybe (ConcreteVal tp, e tp, ConcreteVal tp)

  -- | Return the string literal if this is a constant string.
  asString :: e (BaseStringType si) -> Maybe (StringLiteral si)
  asString = const Nothing

  -- | Return the string-info representation of a string-typed term.
  --   The default reads it off 'exprType'.
  stringInfo :: e (BaseStringType si) -> StringInfoRepr si
  stringInfo x =
    case exprType x of
      BaseStringRepr info -> info

  -- | Return the single element value if this is a constant array,
  --   such as one built with 'constantArray'.
  asConstantArray :: e (BaseArrayType idx bt) -> Maybe (e bt)
  asConstantArray = const Nothing

  -- | Return the field values if this is a concrete struct.
  asStruct :: e (BaseStructType flds) -> Maybe (Ctx.Assignment e flds)
  asStruct = const Nothing

  -- | Return the type representative of an expression.
  exprType :: e tp -> BaseTypeRepr tp

  -- | Return the width of a bitvector expression.  The default reads it
  --   off 'exprType'.
  bvWidth :: e (BaseBVType w) -> NatRepr w
  bvWidth x =
    case exprType x of
      BaseBVRepr n -> n

  -- | Return the precision of a floating-point expression.  The default
  --   reads it off 'exprType'.
  floatPrecision :: e (BaseFloatType fpp) -> FloatPrecisionRepr fpp
  floatPrecision x =
    case exprType x of
      BaseFloatRepr p -> p

  -- | Pretty-print an expression, for debugging or display purposes.
  printSymExpr :: e tp -> Doc ann
376
377
-- | Wraps an array-typed value so that it is indexed by the array's
--   element type @tp@, with the index type(s) @idx@ held fixed.
newtype ArrayResultWrapper f idx tp =
  ArrayResultWrapper { unwrapArrayResult :: f (BaseArrayType idx tp) }

-- | Wrapped arrays are compared by comparing the underlying array values;
--   a proof that the array types coincide yields a proof that the element
--   types coincide.
instance TestEquality f => TestEquality (ArrayResultWrapper f idx) where
  testEquality (ArrayResultWrapper a) (ArrayResultWrapper b) =
    case testEquality a b of
      Just Refl -> Just Refl
      Nothing   -> Nothing

-- | Hash a wrapped array by hashing the underlying array value.
instance HashableF e => HashableF (ArrayResultWrapper e idx) where
  hashWithSaltF salt (ArrayResultWrapper arr) = hashWithSaltF salt arr
388
389
-- | This datatype describes events that involve interacting with
--   solvers.  A @SolverEvent@ will be provided to the action
--   installed via 'setSolverLogListener' whenever an interesting
--   event occurs.
data SolverEvent
  = SolverStartSATQuery
    { satQuerySolverName :: !String
      -- ^ Presumably the name of the solver being queried; it is an
      --   uninterpreted 'String' here.
    , satQueryReason     :: !String
      -- ^ Presumably a free-form description of why the query is issued.
    }
  | SolverEndSATQuery
    { satQueryResult     :: !(SatResult () ())
      -- ^ Outcome of the query, with any model and core payloads erased
      --   to @()@.
    , satQueryError      :: !(Maybe String)
      -- ^ An error message, if one was reported.
    }
 deriving (Show, Generic)
404
405------------------------------------------------------------------------
406-- SymNat
407
-- | Symbolic natural numbers.
--
--   A 'SymNat' wraps a symbolic integer.  The functions in this section
--   keep the wrapped value nonnegative, for example by clamping with
--   'intMax' against 0 where an operation could go negative.
newtype SymNat sym =
  SymNat
  { -- Internal Invariant: the value in a SymNat is always nonnegative
    _symNat :: SymExpr sym BaseIntegerType
  }
414
-- | Return the natural number denoted by a 'SymNat', if it is a constant.
--
--   As a defensive measure the integer is clamped to 0 before being
--   converted, so a (supposedly impossible) negative value cannot cause
--   an underflow in 'fromInteger'.
asNat :: IsExpr (SymExpr sym) => SymNat sym -> Maybe Natural
asNat (SymNat x) = do
  i <- asInteger x
  pure (fromInteger (max 0 i))
418
-- | Construct a concrete natural-number literal.
natLit :: IsExprBuilder sym => sym -> Natural -> IO (SymNat sym)
-- A 'Natural' is nonnegative by construction, so the invariant holds.
natLit sym = fmap SymNat . intLit sym . toInteger
423
-- | Add two natural numbers.
natAdd :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (SymNat sym)
-- The sum of two nonnegative integers is nonnegative.
natAdd sym (SymNat a) (SymNat b) = fmap SymNat (intAdd sym a b)
428
-- | Truncated subtraction of natural numbers.
--
-- @natSub sym x y@ computes @max (x - y) 0@, so the result is 0 whenever
-- the difference would otherwise be negative.
natSub :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (SymNat sym)
natSub sym (SymNat x) (SymNat y) = do
  diff <- intSub sym x y
  zero <- intLit sym 0
  SymNat <$> intMax sym diff zero
436
-- | Multiply two natural numbers.
natMul :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (SymNat sym)
-- The product of two nonnegative integers is nonnegative.
natMul sym (SymNat a) (SymNat b) = fmap SymNat (intMul sym a b)
441
-- | @'natDiv' sym x y@ divides the natural number @x@ by @y@.
--
-- The result is undefined when @y@ is @0@.
--
-- Together with 'natMod', for @y > 0@ the results of
--
-- @
--   d <- natDiv sym x y
--   m <- natMod sym x y
-- @
--
-- satisfy @y * d + m = x@ and @m < y@.
natDiv :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (SymNat sym)
-- Dividing a nonnegative integer by a positive one gives a nonnegative result.
natDiv sym (SymNat a) (SymNat b) = fmap SymNat (intDiv sym a b)
457
-- | @'natMod' sym x y@ computes @x@ modulo @y@.
--
-- The result is undefined when @y@ is @0@.  See 'natDiv' for the
-- relationship the two results are expected to satisfy.
natMod :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (SymNat sym)
-- The integer modulus of nonnegative arguments is nonnegative.
natMod sym (SymNat a) (SymNat b) = fmap SymNat (intMod sym a b)
465
-- | If-then-else on natural numbers.
natIte :: IsExprBuilder sym => sym -> Pred sym -> SymNat sym -> SymNat sym -> IO (SymNat sym)
-- Either branch is nonnegative, so the selected value is too.
natIte sym c (SymNat a) (SymNat b) = fmap SymNat (intIte sym c a b)
470
-- | Equality of natural numbers, decided on the underlying integers.
natEq :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (Pred sym)
natEq sym m n = intEq sym (_symNat m) (_symNat n)
474
-- | @'natLe' sym x y@ is the predicate @x <= y@, decided on the
--   underlying integers.
natLe :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (Pred sym)
natLe sym m n = intLe sym (_symNat m) (_symNat n)
478
-- | @'natLt' sym x y@ returns @true@ if @x < y@.
--
-- Like 'natLe' and 'natEq', this delegates directly to the corresponding
-- integer operation ('intLt'), so a backend that specializes the
-- 'intLt' method is also used for natural numbers.
natLt :: IsExprBuilder sym => sym -> SymNat sym -> SymNat sym -> IO (Pred sym)
natLt sym (SymNat x) (SymNat y) = intLt sym x y
482
-- | View a natural number as an integer.  This only unwraps the 'SymNat';
--   no new term is built.
natToInteger :: IsExprBuilder sym => sym -> SymNat sym -> IO (SymInteger sym)
natToInteger _ n = pure (_symNat n)
486
-- | Interpret a bitvector as an unsigned value and convert it to a natural.
bvToNat :: (IsExprBuilder sym, 1 <= w) => sym -> SymBV sym w -> IO (SymNat sym)
-- An unsigned bitvector value is never negative.
bvToNat sym = fmap SymNat . bvToInteger sym
491
-- | Convert a natural number to a real number by converting the
--   underlying integer with 'integerToReal'.
natToReal :: IsExprBuilder sym => sym -> SymNat sym -> IO (SymReal sym)
natToReal sym (SymNat i) = integerToReal sym i
495
-- | Convert an integer to a natural number, clamping negative integers
--   to 0.
integerToNat :: IsExprBuilder sym => sym -> SymInteger sym -> IO (SymNat sym)
integerToNat sym x = do
  zero <- intLit sym 0
  SymNat <$> intMax sym x zero
501
-- | Convert a real number to a natural number, via 'realToInteger'
--   followed by 'integerToNat'.
--
-- The result is undefined if the given real number does not represent
-- a natural number.
realToNat :: IsExprBuilder sym => sym -> SymReal sym -> IO (SymNat sym)
realToNat sym = realToInteger sym >=> integerToNat sym
507
-- | Create a fresh natural number constant, optionally bounded.
--
--   Both bounds are inclusive.  An omitted lower bound is taken to be 0,
--   which keeps the 'SymNat' invariant.  If the bounds are inconsistent,
--   an 'InvalidRange' exception will be thrown.
freshBoundedNat ::
  IsSymExprBuilder sym =>
  sym ->
  SolverSymbol ->
  Maybe Natural {- ^ lower bound -} ->
  Maybe Natural {- ^ upper bound -} ->
  IO (SymNat sym)
freshBoundedNat sym s lo hi =
  SymNat <$> freshBoundedInt sym s (Just loInt) hiInt
  where
    loInt = maybe 0 toInteger lo
    hiInt = fmap toInteger hi
522
-- | Create a fresh, otherwise unconstrained, natural number constant.
freshNat :: IsSymExprBuilder sym => sym -> SolverSymbol -> IO (SymNat sym)
-- 'freshBoundedNat' already uses 0 when no lower bound is given.
freshNat sym s = freshBoundedNat sym s Nothing Nothing
526
-- | Pretty-print a natural number by printing its underlying integer
--   expression with 'printSymExpr'.
printSymNat :: IsExpr (SymExpr sym) => SymNat sym -> Doc ann
printSymNat n = printSymExpr (_symNat n)
529
-- | Two 'SymNat's are equal when 'testEquality' identifies their
--   underlying integer expressions.
instance TestEquality (SymExpr sym) => Eq (SymNat sym) where
  m == n =
    case testEquality (_symNat m) (_symNat n) of
      Just _  -> True
      Nothing -> False

-- | Order 'SymNat's by the 'OrdF' ordering of their underlying expressions.
instance OrdF (SymExpr sym) => Ord (SymNat sym) where
  compare m n = toOrdering (compareF (_symNat m) (_symNat n))

-- | Hash a 'SymNat' by hashing its underlying expression.
instance HashableF (SymExpr sym) => Hashable (SymNat sym) where
  hashWithSalt salt = hashWithSaltF salt . _symNat
538
539------------------------------------------------------------------------
540-- IsExprBuilder
541
542-- | This class allows the simulator to build symbolic expressions.
543--
544-- Methods of this class refer to type families @'SymExpr' sym@
545-- and @'SymFn' sym@.
546--
547-- Note: Some methods in this class represent operations that are
548-- partial functions on their domain (e.g., division by 0).
549-- Such functions will have documentation strings indicating that they
550-- are undefined under some conditions.  When partial functions are applied
551-- outside their defined domains, they will silently produce an unspecified
552-- value of the expected type.  The unspecified value returned as the result
-- of an undefined function is /not/ guaranteed to be equivalent to a free
554-- constant, and no guarantees are made about what properties such values
555-- will satisfy.
556class ( IsExpr (SymExpr sym), HashableF (SymExpr sym)
557      , TestEquality (SymAnnotation sym), OrdF (SymAnnotation sym)
558      , HashableF (SymAnnotation sym)
559      ) => IsExprBuilder sym where
560
  -- | Retrieve the configuration object (a "What4.Config" 'Config')
  --   corresponding to this solver interface.
  getConfiguration :: sym -> Config


  -- | Install an action that will be invoked before and after calls to
  --   backend solvers (presumably with 'SolverStartSATQuery' and
  --   'SolverEndSATQuery' events, respectively).  This action is primarily
  --   intended to be used for logging\/profiling\/debugging purposes.
  --   Passing 'Nothing' to this function disables logging.
  setSolverLogListener :: sym -> Maybe (SolverEvent -> IO ()) -> IO ()

  -- | Get the currently-installed solver log listener, if one has been installed.
  getSolverLogListener :: sym -> IO (Maybe (SolverEvent -> IO ()))

  -- | Provide the given event to the currently installed
  --   solver log listener, if any.
  logSolverEvent :: sym -> SolverEvent -> IO ()
577
578  -- | Get statistics on execution from the initialization of the
579  -- symbolic interface to this point.  May return zeros if gathering
580  -- statistics isn't supported.
581  getStatistics :: sym -> IO Statistics
582  getStatistics _ = return zeroStatistics
583
584  ----------------------------------------------------------------------
585  -- Program location operations
586
  -- | Get the current program location, which is used when creating terms.
  getCurrentProgramLoc :: sym -> IO ProgramLoc

  -- | Set the program location to be used when creating subsequent terms.
  setCurrentProgramLoc :: sym -> ProgramLoc -> IO ()
592
593  -- | Return true if two expressions are equal. The default
594  -- implementation dispatches 'eqPred', 'bvEq', 'natEq', 'intEq',
595  -- 'realEq', 'cplxEq', 'structEq', or 'arrayEq', depending on the
596  -- type.
597  isEq :: sym -> SymExpr sym tp -> SymExpr sym tp -> IO (Pred sym)
598  isEq sym x y =
599    case exprType x of
600      BaseBoolRepr     -> eqPred sym x y
601      BaseBVRepr{}     -> bvEq sym x y
602      BaseIntegerRepr  -> intEq sym x y
603      BaseRealRepr     -> realEq sym x y
604      BaseFloatRepr{}  -> floatEq sym x y
605      BaseComplexRepr  -> cplxEq sym x y
606      BaseStringRepr{} -> stringEq sym x y
607      BaseStructRepr{} -> structEq sym x y
608      BaseArrayRepr{}  -> arrayEq sym x y
609
610  -- | Take the if-then-else of two expressions. The default
611  -- implementation dispatches 'itePred', 'bvIte', 'natIte', 'intIte',
612  -- 'realIte', 'cplxIte', 'structIte', or 'arrayIte', depending on
613  -- the type.
614  baseTypeIte :: sym
615              -> Pred sym
616              -> SymExpr sym tp
617              -> SymExpr sym tp
618              -> IO (SymExpr sym tp)
619  baseTypeIte sym c x y =
620    case exprType x of
621      BaseBoolRepr     -> itePred   sym c x y
622      BaseBVRepr{}     -> bvIte     sym c x y
623      BaseIntegerRepr  -> intIte    sym c x y
624      BaseRealRepr     -> realIte   sym c x y
625      BaseFloatRepr{}  -> floatIte  sym c x y
626      BaseStringRepr{} -> stringIte sym c x y
627      BaseComplexRepr  -> cplxIte   sym c x y
628      BaseStructRepr{} -> structIte sym c x y
629      BaseArrayRepr{}  -> arrayIte  sym c x y
630
  -- | Given a symbolic expression, annotate it with a unique identifier
  --   that can be used to maintain a connection with the given term.
  --   The 'SymAnnotation' is intended to be used as the key in a hash
  --   table or map, so that additional data can be maintained alongside
  --   the terms.  The returned 'SymExpr' has the same semantics as the
  --   argument, but has the 'SymAnnotation' value embedded in it so that
  --   it can be recovered later during term traversals.
  --
  --   Note, the returned annotation is not necessarily fresh; if an
  --   already-annotated term is passed in, the same annotation value will be
  --   returned.
  annotateTerm :: sym -> SymExpr sym tp -> IO (SymAnnotation sym tp, SymExpr sym tp)

  -- | Project an annotation from an expression.
  --
  -- It should be the case that using 'getAnnotation' on a term returned by
  -- 'annotateTerm' returns the same annotation that 'annotateTerm' did.
  -- Presumably 'Nothing' is returned for terms that carry no annotation;
  -- confirm against the backend.
  getAnnotation :: sym -> SymExpr sym tp -> Maybe (SymAnnotation sym tp)
649
650  ----------------------------------------------------------------------
651  -- Boolean operations.
652
653  -- | Constant true predicate
654  truePred  :: sym -> Pred sym
655
656  -- | Constant false predicate
657  falsePred :: sym -> Pred sym
658
659  -- | Boolean negation
660  notPred :: sym -> Pred sym -> IO (Pred sym)
661
662  -- | Boolean conjunction
663  andPred :: sym -> Pred sym -> Pred sym -> IO (Pred sym)
664
665  -- | Boolean disjunction
666  orPred  :: sym -> Pred sym -> Pred sym -> IO (Pred sym)
667
668  -- | Boolean implication
669  impliesPred :: sym -> Pred sym -> Pred sym -> IO (Pred sym)
670  impliesPred sym x y = do
671    nx <- notPred sym x
672    orPred sym y nx
673
674  -- | Exclusive-or operation
675  xorPred :: sym -> Pred sym -> Pred sym -> IO (Pred sym)
676
677  -- | Equality of boolean values
678  eqPred  :: sym -> Pred sym -> Pred sym -> IO (Pred sym)
679
680  -- | If-then-else on a predicate.
681  itePred :: sym -> Pred sym -> Pred sym -> Pred sym -> IO (Pred sym)
682
683  ----------------------------------------------------------------------
684  -- Integer operations
685
686  -- | Create an integer literal.
687  intLit :: sym -> Integer -> IO (SymInteger sym)
688
689  -- | Negate an integer.
690  intNeg :: sym -> SymInteger sym -> IO (SymInteger sym)
691
692  -- | Add two integers.
693  intAdd :: sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)
694
695  -- | Subtract one integer from another.
696  intSub :: sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)
697  intSub sym x y = intAdd sym x =<< intNeg sym y
698
699  -- | Multiply one integer by another.
700  intMul :: sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)
701
702  -- | Return the minimum value of two integers.
703  intMin :: sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)
704  intMin sym x y =
705    do p <- intLe sym x y
706       intIte sym p x y
707
708  -- | Return the maximum value of two integers.
709  intMax :: sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)
710  intMax sym x y =
711    do p <- intLe sym x y
712       intIte sym p y x
713
714  -- | If-then-else applied to integers.
715  intIte :: sym -> Pred sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)
716
717  -- | Integer equality.
718  intEq  :: sym -> SymInteger sym -> SymInteger sym -> IO (Pred sym)
719
720  -- | Integer less-than-or-equal.
721  intLe  :: sym -> SymInteger sym -> SymInteger sym -> IO (Pred sym)
722
723  -- | Integer less-than.
724  intLt  :: sym -> SymInteger sym -> SymInteger sym -> IO (Pred sym)
725  intLt sym x y = notPred sym =<< intLe sym y x
726
  -- | Compute the absolute value of an integer.
  intAbs :: sym -> SymInteger sym -> IO (SymInteger sym)

  -- | @intDiv x y@ computes the integer division of @x@ by @y@.  This division is
  --   interpreted the same way as the SMT-Lib integer theory, which states that
  --   @div@ and @mod@ are the unique Euclidean division operations satisfying the
  --   following for all @y /= 0@:
  --
  --   * @y * (div x y) + (mod x y) == x@
  --   * @0 <= mod x y < abs y@
  --
  --   The value of @intDiv x y@ is undefined when @y = 0@.
  --
  --   Integer division requires nonlinear support whenever the divisor is
  --   not a constant.
  --
  --   Note: @div x y@ is @floor (x/y)@ when @y@ is positive
  --   (regardless of the sign of @x@) and @ceiling (x/y)@ when @y@ is
  --   negative.  This matches neither of the more common "round toward
  --   zero" and "round toward -inf" definitions.
  --
  --   Some useful theorems that are true of this division/modulus pair:
  --
  --   * @mod x y == mod x (- y) == mod x (abs y)@
  --   * @div x (-y) == -(div x y)@
  intDiv :: sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)

  -- | @intMod x y@ computes the integer modulus of @x@ by @y@.  See 'intDiv' for
  --   more details.
  --
  --   The value of @intMod x y@ is undefined when @y = 0@.
  --
  --   Integer modulus requires nonlinear support whenever the divisor is
  --   not a constant.
  intMod :: sym -> SymInteger sym -> SymInteger sym -> IO (SymInteger sym)

  -- | @intDivisible x k@ is true whenever @x@ is an integer divisible
  --   by the known natural number @k@.  In other words, @intDivisible x k@
  --   holds if there exists an integer @z@ such that @x = k*z@.
  intDivisible :: sym -> SymInteger sym -> Natural -> IO (Pred sym)
767
768  ----------------------------------------------------------------------
769  -- Bitvector operations
770
771  -- | Create a bitvector with the given width and value.
772  bvLit :: (1 <= w) => sym -> NatRepr w -> BV.BV w -> IO (SymBV sym w)
773
774  -- | Concatenate two bitvectors.
775  bvConcat :: (1 <= u, 1 <= v)
776           => sym
777           -> SymBV sym u  -- ^ most significant bits
778           -> SymBV sym v  -- ^ least significant bits
779           -> IO (SymBV sym (u+v))
780
781  -- | Select a subsequence from a bitvector.
782  bvSelect :: forall idx n w. (1 <= n, idx + n <= w)
783           => sym
784           -> NatRepr idx  -- ^ Starting index, from 0 as least significant bit
785           -> NatRepr n    -- ^ Number of bits to take
786           -> SymBV sym w  -- ^ Bitvector to select from
787           -> IO (SymBV sym n)
788
789  -- | 2's complement negation.
790  bvNeg :: (1 <= w)
791        => sym
792        -> SymBV sym w
793        -> IO (SymBV sym w)
794
795  -- | Add two bitvectors.
796  bvAdd :: (1 <= w)
797        => sym
798        -> SymBV sym w
799        -> SymBV sym w
800        -> IO (SymBV sym w)
801
802  -- | Subtract one bitvector from another.
803  bvSub :: (1 <= w)
804        => sym
805        -> SymBV sym w
806        -> SymBV sym w
807        -> IO (SymBV sym w)
808  bvSub sym x y = bvAdd sym x =<< bvNeg sym y
809
  -- | Multiply one bitvector by another.
  bvMul :: (1 <= w)
        => sym
        -> SymBV sym w
        -> SymBV sym w
        -> IO (SymBV sym w)

  -- | Unsigned bitvector division.
  --
  --   The result of @bvUdiv x y@ is undefined when @y@ is zero,
  --   but is otherwise equal to @floor( x / y )@.
  bvUdiv :: (1 <= w)
         => sym
         -> SymBV sym w
         -> SymBV sym w
         -> IO (SymBV sym w)

  -- | Unsigned bitvector remainder.
  --
  --   The result of @bvUrem x y@ is undefined when @y@ is zero,
  --   but is otherwise equal to @x - (bvUdiv x y) * y@.
  bvUrem :: (1 <= w)
         => sym
         -> SymBV sym w
         -> SymBV sym w
         -> IO (SymBV sym w)

  -- | Signed bitvector division.  The quotient is rounded toward zero.
  --
  --   The result of @bvSdiv x y@ is undefined when @y@ is zero,
  --   but is equal to @floor(x/y)@ when @x@ and @y@ have the same sign,
  --   and equal to @ceiling(x/y)@ when @x@ and @y@ have opposite signs.
  --
  --   NOTE: there is a corner case when dividing @MIN_INT@ by @-1@: the
  --   true quotient overflows, and the result is instead @MIN_INT@.
  bvSdiv :: (1 <= w)
         => sym
         -> SymBV sym w
         -> SymBV sym w
         -> IO (SymBV sym w)

  -- | Signed bitvector remainder.
  --
  --   The result of @bvSrem x y@ is undefined when @y@ is zero, but is
  --   otherwise equal to @x - (bvSdiv x y) * y@.
  bvSrem :: (1 <= w)
         => sym
         -> SymBV sym w
         -> SymBV sym w
         -> IO (SymBV sym w)

  -- | Returns true if the corresponding bit in the bitvector is set.
  testBitBV :: (1 <= w)
            => sym
            -> Natural -- ^ Index of bit (0 is the least significant bit)
            -> SymBV sym w
            -> IO (Pred sym)
868
869  -- | Return true if bitvector is negative.
870  bvIsNeg :: (1 <= w) => sym -> SymBV sym w -> IO (Pred sym)
871  bvIsNeg sym x = bvSlt sym x =<< bvLit sym (bvWidth x) (BV.zero (bvWidth x))
872
873  -- | If-then-else applied to bitvectors.
874  bvIte :: (1 <= w)
875        => sym
876        -> Pred sym
877        -> SymBV sym w
878        -> SymBV sym w
879        -> IO (SymBV sym w)
880
881  -- | Return true if bitvectors are equal.
882  bvEq  :: (1 <= w)
883        => sym
884        -> SymBV sym w
885        -> SymBV sym w
886        -> IO (Pred sym)
887
888  -- | Return true if bitvectors are distinct.
889  bvNe  :: (1 <= w)
890        => sym
891        -> SymBV sym w
892        -> SymBV sym w
893        -> IO (Pred sym)
894  bvNe sym x y = notPred sym =<< bvEq sym x y
895
896  -- | Unsigned less-than.
897  bvUlt  :: (1 <= w)
898         => sym
899         -> SymBV sym w
900         -> SymBV sym w
901         -> IO (Pred sym)
902
903  -- | Unsigned less-than-or-equal.
904  bvUle  :: (1 <= w)
905         => sym
906         -> SymBV sym w
907         -> SymBV sym w
908         -> IO (Pred sym)
909  bvUle sym x y = notPred sym =<< bvUlt sym y x
910
911  -- | Unsigned greater-than-or-equal.
912  bvUge :: (1 <= w) => sym -> SymBV sym w -> SymBV sym w -> IO (Pred sym)
913  bvUge sym x y = bvUle sym y x
914
915  -- | Unsigned greater-than.
916  bvUgt :: (1 <= w) => sym -> SymBV sym w -> SymBV sym w -> IO (Pred sym)
917  bvUgt sym x y = bvUlt sym y x
918
919  -- | Signed less-than.
920  bvSlt :: (1 <= w) => sym -> SymBV sym w -> SymBV sym w -> IO (Pred sym)
921
922  -- | Signed greater-than.
923  bvSgt :: (1 <= w) => sym -> SymBV sym w -> SymBV sym w -> IO (Pred sym)
924  bvSgt sym x y = bvSlt sym y x
925
926  -- | Signed less-than-or-equal.
927  bvSle :: (1 <= w) => sym -> SymBV sym w -> SymBV sym w -> IO (Pred sym)
928  bvSle sym x y = notPred sym =<< bvSlt sym y x
929
930  -- | Signed greater-than-or-equal.
931  bvSge :: (1 <= w) => sym -> SymBV sym w -> SymBV sym w -> IO (Pred sym)
932  bvSge sym x y = notPred sym =<< bvSlt sym x y
933
  -- | Return true if the given bitvector is non-zero.
935  bvIsNonzero :: (1 <= w) => sym -> SymBV sym w -> IO (Pred sym)
936
937  -- | Left shift.  The shift amount is treated as an unsigned value.
938  bvShl :: (1 <= w) => sym ->
939                       SymBV sym w {- ^ Shift this -} ->
940                       SymBV sym w {- ^ Amount to shift by -} ->
941                       IO (SymBV sym w)
942
943  -- | Logical right shift.  The shift amount is treated as an unsigned value.
944  bvLshr :: (1 <= w) => sym ->
945                        SymBV sym w {- ^ Shift this -} ->
946                        SymBV sym w {- ^ Amount to shift by -} ->
947                        IO (SymBV sym w)
948
949  -- | Arithmetic right shift.  The shift amount is treated as an
950  -- unsigned value.
951  bvAshr :: (1 <= w) => sym ->
952                        SymBV sym w {- ^ Shift this -} ->
953                        SymBV sym w {- ^ Amount to shift by -} ->
954                        IO (SymBV sym w)
955
956  -- | Rotate left.  The rotate amount is treated as an unsigned value.
957  bvRol :: (1 <= w) =>
958    sym ->
959    SymBV sym w {- ^ bitvector to rotate -} ->
960    SymBV sym w {- ^ amount to rotate by -} ->
961    IO (SymBV sym w)
962
963  -- | Rotate right.  The rotate amount is treated as an unsigned value.
964  bvRor :: (1 <= w) =>
965    sym ->
966    SymBV sym w {- ^ bitvector to rotate -} ->
967    SymBV sym w {- ^ amount to rotate by -} ->
968    IO (SymBV sym w)
969
970  -- | Zero-extend a bitvector.
971  bvZext :: (1 <= u, u+1 <= r) => sym -> NatRepr r -> SymBV sym u -> IO (SymBV sym r)
972
973  -- | Sign-extend a bitvector.
974  bvSext :: (1 <= u, u+1 <= r) => sym -> NatRepr r -> SymBV sym u -> IO (SymBV sym r)
975
976  -- | Truncate a bitvector.
977  bvTrunc :: (1 <= r, r+1 <= w) -- Assert result is less than input.
978          => sym
979          -> NatRepr r
980          -> SymBV sym w
981          -> IO (SymBV sym r)
982  bvTrunc sym w x
983    | LeqProof <- leqTrans
984        (addIsLeq w (knownNat @1))
985        (leqProof (incNat w) (bvWidth x))
986    = bvSelect sym (knownNat @0) w x
987
988  -- | Bitwise logical and.
989  bvAndBits :: (1 <= w)
990            => sym
991            -> SymBV sym w
992            -> SymBV sym w
993            -> IO (SymBV sym w)
994
995  -- | Bitwise logical or.
996  bvOrBits  :: (1 <= w)
997            => sym
998            -> SymBV sym w
999            -> SymBV sym w
1000            -> IO (SymBV sym w)
1001
1002  -- | Bitwise logical exclusive or.
1003  bvXorBits :: (1 <= w)
1004            => sym
1005            -> SymBV sym w
1006            -> SymBV sym w
1007            -> IO (SymBV sym w)
1008
1009  -- | Bitwise complement.
1010  bvNotBits :: (1 <= w) => sym -> SymBV sym w -> IO (SymBV sym w)
1011
1012  -- | @bvSet sym v i p@ returns a bitvector @v'@ where bit @i@ of @v'@ is set to
1013  -- @p@, and the bits at the other indices are the same as in @v@.
1014  bvSet :: forall w
1015         . (1 <= w)
1016        => sym         -- ^ Symbolic interface
1017        -> SymBV sym w -- ^ Bitvector to update
1018        -> Natural     -- ^ 0-based index to set
1019        -> Pred sym    -- ^ Predicate to set.
1020        -> IO (SymBV sym w)
1021  bvSet sym v i p = assert (i < natValue (bvWidth v)) $
1022    -- NB, this representation based on AND/XOR structure is designed so that a
1023    -- sequence of bvSet operations will collapse nicely into a xor-linear combination
1024    -- of the original term and bvFill terms. It has the nice property that we
1025    -- do not introduce any additional subterm sharing.
1026    do let w    = bvWidth v
1027       let mask = BV.bit' w i
1028       pbits <- bvFill sym w p
1029       vbits <- bvAndBits sym v =<< bvLit sym w (BV.complement w mask)
1030       bvXorBits sym vbits =<< bvAndBits sym pbits =<< bvLit sym w mask
1031
1032  -- | @bvFill sym w p@ returns a bitvector @w@-bits long where every bit
1033  --   is given by the boolean value of @p@.
1034  bvFill :: forall w. (1 <= w) =>
1035    sym       {-^ symbolic interface -} ->
1036    NatRepr w {-^ output bitvector width -} ->
1037    Pred sym  {-^ predicate to fill the bitvector with -} ->
1038    IO (SymBV sym w)
1039
1040  -- | Return the bitvector of the desired width with all 0 bits;
1041  --   this is the minimum unsigned integer.
1042  minUnsignedBV :: (1 <= w) => sym -> NatRepr w -> IO (SymBV sym w)
1043  minUnsignedBV sym w = bvLit sym w (BV.zero w)
1044
1045  -- | Return the bitvector of the desired width with all bits set;
1046  --   this is the maximum unsigned integer.
1047  maxUnsignedBV :: (1 <= w) => sym -> NatRepr w -> IO (SymBV sym w)
1048  maxUnsignedBV sym w = bvLit sym w (BV.maxUnsigned w)
1049
1050  -- | Return the bitvector representing the largest 2's complement
1051  --   signed integer of the given width.  This consists of all bits
1052  --   set except the MSB.
1053  maxSignedBV :: (1 <= w) => sym -> NatRepr w -> IO (SymBV sym w)
1054  maxSignedBV sym w = bvLit sym w (BV.maxSigned w)
1055
1056  -- | Return the bitvector representing the smallest 2's complement
1057  --   signed integer of the given width. This consists of all 0 bits
1058  --   except the MSB, which is set.
1059  minSignedBV :: (1 <= w) => sym -> NatRepr w -> IO (SymBV sym w)
1060  minSignedBV sym w = bvLit sym w (BV.minSigned w)
1061
1062  -- | Return the number of 1 bits in the input.
1063  bvPopcount :: (1 <= w) => sym -> SymBV sym w -> IO (SymBV sym w)
1064
1065  -- | Return the number of consecutive 0 bits in the input, starting from
1066  --   the most significant bit position.  If the input is zero, all bits are counted
1067  --   as leading.
1068  bvCountLeadingZeros :: (1 <= w) => sym -> SymBV sym w -> IO (SymBV sym w)
1069
1070  -- | Return the number of consecutive 0 bits in the input, starting from
1071  --   the least significant bit position.  If the input is zero, all bits are counted
1072  --   as leading.
1073  bvCountTrailingZeros :: (1 <= w) => sym -> SymBV sym w -> IO (SymBV sym w)
1074
1075  -- | Unsigned add with overflow bit.
1076  addUnsignedOF :: (1 <= w)
1077                => sym
1078                -> SymBV sym w
1079                -> SymBV sym w
1080                -> IO (Pred sym, SymBV sym w)
1081  addUnsignedOF sym x y = do
1082    -- Compute result
1083    r   <- bvAdd sym x y
1084    -- Return that this overflows if r is less than either x or y
1085    ovx  <- bvUlt sym r x
1086    ovy  <- bvUlt sym r y
1087    ov   <- orPred sym ovx ovy
1088    return (ov, r)
1089
1090  -- | Signed add with overflow bit. Overflow is true if positive +
1091  -- positive = negative, or if negative + negative = positive.
1092  addSignedOF :: (1 <= w)
1093              => sym
1094              -> SymBV sym w
1095              -> SymBV sym w
1096              -> IO (Pred sym, SymBV sym w)
1097  addSignedOF sym x y = do
1098    xy  <- bvAdd sym x y
1099    sx  <- bvIsNeg sym x
1100    sy  <- bvIsNeg sym y
1101    sxy <- bvIsNeg sym xy
1102
1103    not_sx  <- notPred sym sx
1104    not_sy  <- notPred sym sy
1105    not_sxy <- notPred sym sxy
1106
1107    -- Return this overflowed if the sign bits of sx and sy are equal,
1108    -- but different from sxy.
1109    ov1 <- andPred sym not_sxy =<< andPred sym sx sy
1110    ov2 <- andPred sym sxy =<< andPred sym not_sx not_sy
1111
1112    ov  <- orPred sym ov1 ov2
1113    return (ov, xy)
1114
1115  -- | Unsigned subtract with overflow bit. Overflow is true if x < y.
1116  subUnsignedOF ::
1117    (1 <= w) =>
1118    sym ->
1119    SymBV sym w ->
1120    SymBV sym w ->
1121    IO (Pred sym, SymBV sym w)
1122  subUnsignedOF sym x y = do
1123    xy <- bvSub sym x y
1124    ov <- bvUlt sym x y
1125    return (ov, xy)
1126
1127  -- | Signed subtract with overflow bit. Overflow is true if positive
1128  -- - negative = negative, or if negative - positive = positive.
1129  subSignedOF :: (1 <= w)
1130              => sym
1131              -> SymBV sym w
1132              -> SymBV sym w
1133              -> IO (Pred sym, SymBV sym w)
1134  subSignedOF sym x y = do
1135       xy  <- bvSub sym x y
1136       sx  <- bvIsNeg sym x
1137       sy  <- bvIsNeg sym y
1138       sxy <- bvIsNeg sym xy
1139       ov  <- join (pure (andPred sym) <*> xorPred sym sx sxy <*> xorPred sym sx sy)
1140       return (ov, xy)
1141
1142
1143  -- | Compute the carry-less multiply of the two input bitvectors.
1144  --   This operation is essentially the same as a standard multiply, except that
1145  --   the partial addends are simply XOR'd together instead of using a standard
1146  --   adder.  This operation is useful for computing on GF(2^n) polynomials.
1147  carrylessMultiply ::
1148    (1 <= w) =>
1149    sym ->
1150    SymBV sym w ->
1151    SymBV sym w ->
1152    IO (SymBV sym (w+w))
1153  carrylessMultiply sym x0 y0
1154    | Just _  <- BV.asUnsigned <$> asBV x0
1155    , Nothing <- BV.asUnsigned <$> asBV y0
1156    = go y0 x0
1157    | otherwise
1158    = go x0 y0
1159   where
1160   go :: (1 <= w) => SymBV sym w -> SymBV sym w -> IO (SymBV sym (w+w))
1161   go x y =
1162    do let w = bvWidth x
1163       let w2 = addNat w w
1164       -- 1 <= w
1165       one_leq_w@LeqProof <- return (leqProof (knownNat @1) w)
1166       -- 1 <= w implies 1 <= w + w
1167       LeqProof <- return (leqAdd one_leq_w w)
1168       -- w <= w
1169       w_leq_w@LeqProof <- return (leqProof w w)
1170       -- w <= w, 1 <= w implies w + 1 <= w + w
1171       LeqProof <- return (leqAdd2 w_leq_w one_leq_w)
1172       z  <- bvLit sym w2 (BV.zero w2)
1173       x' <- bvZext sym w2 x
1174       xs <- sequence [ do p <- testBitBV sym (BV.asNatural i) y
1175                           iteM bvIte sym
1176                             p
1177                             (bvShl sym x' =<< bvLit sym w2 i)
1178                             (return z)
1179                      | i <- BV.enumFromToUnsigned (BV.zero w2) (BV.mkBV w2 (intValue w - 1))
1180                      ]
1181       foldM (bvXorBits sym) z xs
1182
  -- | @unsignedWideMultiplyBV sym x y@ multiplies two unsigned @w@-bit numbers @x@ and @y@.
  --
  -- It returns a pair containing the top @w@ bits as the first element, and the
  -- lower @w@ bits as the second element.
1187  unsignedWideMultiplyBV :: (1 <= w)
1188                         => sym
1189                         -> SymBV sym w
1190                         -> SymBV sym w
1191                         -> IO (SymBV sym w, SymBV sym w)
1192  unsignedWideMultiplyBV sym x y = do
1193       let w = bvWidth x
1194       let dbl_w = addNat w w
1195       -- 1 <= w
1196       one_leq_w@LeqProof <- return (leqProof (knownNat @1) w)
1197       -- 1 <= w implies 1 <= w + w
1198       LeqProof <- return (leqAdd one_leq_w w)
1199       -- w <= w
1200       w_leq_w@LeqProof <- return (leqProof w w)
1201       -- w <= w, 1 <= w implies w + 1 <= w + w
1202       LeqProof <- return (leqAdd2 w_leq_w one_leq_w)
1203       x'  <- bvZext sym dbl_w x
1204       y'  <- bvZext sym dbl_w y
1205       s   <- bvMul sym x' y'
1206       lo  <- bvTrunc sym w s
1207       n   <- bvLit sym dbl_w (BV.zext dbl_w (BV.width w))
1208       hi  <- bvTrunc sym w =<< bvLshr sym s n
1209       return (hi, lo)
1210
1211  -- | Compute the unsigned multiply of two values with overflow bit.
1212  mulUnsignedOF ::
1213    (1 <= w) =>
1214    sym ->
1215    SymBV sym w ->
1216    SymBV sym w ->
1217    IO (Pred sym, SymBV sym w)
1218  mulUnsignedOF sym x y =
1219    do let w = bvWidth x
1220       let dbl_w = addNat w w
1221       -- 1 <= w
1222       one_leq_w@LeqProof <- return (leqProof (knownNat @1) w)
1223       -- 1 <= w implies 1 <= w + w
1224       LeqProof <- return (leqAdd one_leq_w w)
1225       -- w <= w
1226       w_leq_w@LeqProof <- return (leqProof w w)
1227       -- w <= w, 1 <= w implies w + 1 <= w + w
1228       LeqProof <- return (leqAdd2 w_leq_w one_leq_w)
1229       x'  <- bvZext sym dbl_w x
1230       y'  <- bvZext sym dbl_w y
1231       s   <- bvMul sym x' y'
1232       lo  <- bvTrunc sym w s
1233
1234       -- overflow if the result is greater than the max representable value in w bits
1235       ov  <- bvUgt sym s =<< bvLit sym dbl_w (BV.zext dbl_w (BV.maxUnsigned w))
1236
1237       return (ov, lo)
1238
  -- | @signedWideMultiplyBV sym x y@ multiplies two signed @w@-bit numbers @x@ and @y@.
  --
  -- It returns a pair containing the top @w@ bits as the first element, and the
  -- lower @w@ bits as the second element.
1243  signedWideMultiplyBV :: (1 <= w)
1244                       => sym
1245                       -> SymBV sym w
1246                       -> SymBV sym w
1247                       -> IO (SymBV sym w, SymBV sym w)
1248  signedWideMultiplyBV sym x y = do
1249       let w = bvWidth x
1250       let dbl_w = addNat w w
1251       -- 1 <= w
1252       one_leq_w@LeqProof <- return (leqProof (knownNat @1) w)
1253       -- 1 <= w implies 1 <= w + w
1254       LeqProof <- return (leqAdd one_leq_w w)
1255       -- w <= w
1256       w_leq_w@LeqProof <- return (leqProof w w)
1257       -- w <= w, 1 <= w implies w + 1 <= w + w
1258       LeqProof <- return (leqAdd2 w_leq_w one_leq_w)
1259       x'  <- bvSext sym dbl_w x
1260       y'  <- bvSext sym dbl_w y
1261       s   <- bvMul sym x' y'
1262       lo  <- bvTrunc sym w s
1263       n   <- bvLit sym dbl_w (BV.zext dbl_w (BV.width w))
1264       hi  <- bvTrunc sym w =<< bvLshr sym s n
1265       return (hi, lo)
1266
1267  -- | Compute the signed multiply of two values with overflow bit.
1268  mulSignedOF ::
1269    (1 <= w) =>
1270    sym ->
1271    SymBV sym w ->
1272    SymBV sym w ->
1273    IO (Pred sym, SymBV sym w)
1274  mulSignedOF sym x y =
1275    do let w = bvWidth x
1276       let dbl_w = addNat w w
1277       -- 1 <= w
1278       one_leq_w@LeqProof <- return (leqProof (knownNat @1) w)
1279       -- 1 <= w implies 1 <= w + w
1280       LeqProof <- return (leqAdd one_leq_w w)
1281       -- w <= w
1282       w_leq_w@LeqProof <- return (leqProof w w)
1283       -- w <= w, 1 <= w implies w + 1 <= w + w
1284       LeqProof <- return (leqAdd2 w_leq_w one_leq_w)
1285       x'  <- bvSext sym dbl_w x
1286       y'  <- bvSext sym dbl_w y
1287       s   <- bvMul sym x' y'
1288       lo  <- bvTrunc sym w s
1289
1290       -- overflow if greater or less than max representable values
1291       ov1 <- bvSlt sym s =<< bvLit sym dbl_w (BV.sext w dbl_w (BV.minSigned w))
1292       ov2 <- bvSgt sym s =<< bvLit sym dbl_w (BV.sext w dbl_w (BV.maxSigned w))
1293       ov  <- orPred sym ov1 ov2
1294       return (ov, lo)
1295
1296  ----------------------------------------------------------------------
1297  -- Struct operations
1298
1299  -- | Create a struct from an assignment of expressions.
1300  mkStruct :: sym
1301           -> Ctx.Assignment (SymExpr sym) flds
1302           -> IO (SymStruct sym flds)
1303
1304  -- | Get the value of a specific field in a struct.
1305  structField :: sym
1306              -> SymStruct sym flds
1307              -> Ctx.Index flds tp
1308              -> IO (SymExpr sym tp)
1309
1310  -- | Check if two structs are equal.
1311  structEq  :: forall flds
1312            .  sym
1313            -> SymStruct sym flds
1314            -> SymStruct sym flds
1315            -> IO (Pred sym)
1316  structEq sym x y = do
1317    case exprType x of
1318      BaseStructRepr fld_types -> do
1319        let sz = Ctx.size fld_types
1320        -- Checks to see if the ith struct fields are equal, and all previous entries
1321        -- are as well.
1322        let f :: IO (Pred sym) -> Ctx.Index flds tp -> IO (Pred sym)
1323            f mp i = do
1324              xi <- structField sym x i
1325              yi <- structField sym y i
1326              i_eq <- isEq sym xi yi
1327              case asConstantPred i_eq of
1328                Just True -> mp
1329                Just False -> return (falsePred sym)
1330                _ ->  andPred sym i_eq =<< mp
1331        Ctx.forIndex sz f (return (truePred sym))
1332
1333  -- | Take the if-then-else of two structures.
1334  structIte :: sym
1335            -> Pred sym
1336            -> SymStruct sym flds
1337            -> SymStruct sym flds
1338            -> IO (SymStruct sym flds)
1339
1340  -----------------------------------------------------------------------
1341  -- Array operations
1342
1343  -- | Create an array where each element has the same value.
1344  constantArray :: sym -- Interface
1345                -> Ctx.Assignment BaseTypeRepr (idx::>tp) -- ^ Index type
1346                -> SymExpr sym b -- ^ Constant
1347                -> IO (SymArray sym (idx::>tp) b)
1348
1349  -- | Create an array from an arbitrary symbolic function.
1350  --
1351  -- Arrays created this way can typically not be compared
1352  -- for equality when provided to backend solvers.
1353  arrayFromFn :: sym
1354              -> SymFn sym (idx ::> itp) ret
1355              -> IO (SymArray sym (idx ::> itp) ret)
1356
1357  -- | Create an array by mapping a function over one or more existing arrays.
1358  arrayMap :: sym
1359           -> SymFn sym (ctx::>d) r
1360           -> Ctx.Assignment (ArrayResultWrapper (SymExpr sym) (idx ::> itp)) (ctx::>d)
1361           -> IO (SymArray sym (idx ::> itp) r)
1362
1363  -- | Update an array at a specific location.
1364  arrayUpdate :: sym
1365              -> SymArray sym (idx::>tp) b
1366              -> Ctx.Assignment (SymExpr sym) (idx::>tp)
1367              -> SymExpr sym b
1368              -> IO (SymArray sym (idx::>tp) b)
1369
1370  -- | Return element in array.
1371  arrayLookup :: sym
1372              -> SymArray sym (idx::>tp) b
1373              -> Ctx.Assignment (SymExpr sym) (idx::>tp)
1374              -> IO (SymExpr sym b)
1375
1376  -- | Create an array from a map of concrete indices to values.
1377  --
1378  -- This is implemented, but designed to be overridden for efficiency.
1379  arrayFromMap :: sym
1380               -> Ctx.Assignment BaseTypeRepr (idx ::> itp)
1381                  -- ^ Types for indices
1382               -> AUM.ArrayUpdateMap (SymExpr sym) (idx ::> itp) tp
1383                  -- ^ Value for known indices.
1384               -> SymExpr sym tp
1385                  -- ^ Value for other entries.
1386               -> IO (SymArray sym (idx ::> itp) tp)
1387  arrayFromMap sym idx_tps m default_value = do
1388    a0 <- constantArray sym idx_tps default_value
1389    arrayUpdateAtIdxLits sym m a0
1390
1391  -- | Update an array at specific concrete indices.
1392  --
  -- This is implemented, but designed to be overridden for efficiency.
1394  arrayUpdateAtIdxLits :: sym
1395                       -> AUM.ArrayUpdateMap (SymExpr sym) (idx ::> itp) tp
1396                       -- ^ Value for known indices.
1397                       -> SymArray sym (idx ::> itp) tp
1398                       -- ^ Value for existing array.
1399                       -> IO (SymArray sym (idx ::> itp) tp)
1400  arrayUpdateAtIdxLits sym m a0 = do
1401    let updateAt a (i,v) = do
1402          idx <-  traverseFC (indexLit sym) i
1403          arrayUpdate sym a idx v
1404    foldlM updateAt a0 (AUM.toList m)
1405
1406  -- | If-then-else applied to arrays.
1407  arrayIte :: sym
1408           -> Pred sym
1409           -> SymArray sym idx b
1410           -> SymArray sym idx b
1411           -> IO (SymArray sym idx b)
1412
1413  -- | Return true if two arrays are equal.
1414  --
1415  -- Note that in the backend, arrays do not have a fixed number of elements, so
1416  -- this equality requires that arrays are equal on all elements.
1417  arrayEq :: sym
1418          -> SymArray sym idx b
1419          -> SymArray sym idx b
1420          -> IO (Pred sym)
1421
1422  -- | Return true if all entries in the array are true.
1423  allTrueEntries :: sym -> SymArray sym idx BaseBoolType -> IO (Pred sym)
1424  allTrueEntries sym a = do
1425    case exprType a of
1426      BaseArrayRepr idx_tps _ ->
1427        arrayEq sym a =<< constantArray sym idx_tps (truePred sym)
1428
1429  -- | Return true if the array has the value true at every index satisfying the
1430  -- given predicate.
1431  arrayTrueOnEntries
1432    :: sym
1433    -> SymFn sym (idx::>itp) BaseBoolType
1434    -- ^ Predicate that indicates if array should be true.
1435    -> SymArray sym (idx ::> itp) BaseBoolType
1436    -> IO (Pred sym)
1437
1438  ----------------------------------------------------------------------
1439  -- Lossless (injective) conversions
1440
1441  -- | Convert an integer to a real number.
1442  integerToReal :: sym -> SymInteger sym -> IO (SymReal sym)
1443
1444  -- | Return the unsigned value of the given bitvector as an integer.
1445  bvToInteger :: (1 <= w) => sym -> SymBV sym w -> IO (SymInteger sym)
1446
1447  -- | Return the signed value of the given bitvector as an integer.
1448  sbvToInteger :: (1 <= w) => sym -> SymBV sym w -> IO (SymInteger sym)
1449
1450  -- | Return @1@ if the predicate is true; @0@ otherwise.
1451  predToBV :: (1 <= w) => sym -> Pred sym -> NatRepr w -> IO (SymBV sym w)
1452
1453  ----------------------------------------------------------------------
1454  -- Lossless combinators
1455
1456  -- | Convert an unsigned bitvector to a real number.
1457  uintToReal :: (1 <= w) => sym -> SymBV sym w -> IO (SymReal sym)
1458  uintToReal sym = bvToInteger sym >=> integerToReal sym
1459
1460  -- | Convert an signed bitvector to a real number.
1461  sbvToReal :: (1 <= w) => sym -> SymBV sym w -> IO (SymReal sym)
1462  sbvToReal sym = sbvToInteger sym >=> integerToReal sym
1463
1464  ----------------------------------------------------------------------
1465  -- Lossy (non-injective) conversions
1466
1467  -- | Round a real number to an integer.
1468  --
1469  -- Numbers are rounded to the nearest integer, with rounding away from
1470  -- zero when two integers are equidistant (e.g., 1.5 rounds to 2).
1471  realRound :: sym -> SymReal sym -> IO (SymInteger sym)
1472
1473  -- | Round a real number to an integer.
1474  --
1475  -- Numbers are rounded to the nearest integer, with rounding toward
1476  -- even values when two integers are equidistant (e.g., 2.5 rounds to 2).
1477  realRoundEven :: sym -> SymReal sym -> IO (SymInteger sym)
1478
1479  -- | Round down to the nearest integer that is at most this value.
1480  realFloor :: sym -> SymReal sym -> IO (SymInteger sym)
1481
1482  -- | Round up to the nearest integer that is at least this value.
1483  realCeil :: sym -> SymReal sym -> IO (SymInteger sym)
1484
  -- | Round toward zero.  This is @floor(x)@ when @x@ is non-negative
  --   and @ceiling(x)@ when @x@ is negative.
1487  realTrunc :: sym -> SymReal sym -> IO (SymInteger sym)
1488  realTrunc sym x =
1489    do pneg <- realLt sym x =<< realLit sym 0
1490       iteM intIte sym pneg (realCeil sym x) (realFloor sym x)
1491
1492  -- | Convert an integer to a bitvector.  The result is the unique bitvector
1493  --   whose value (signed or unsigned) is congruent to the input integer, modulo @2^w@.
1494  --
1495  --   This operation has the following properties:
1496  --
  --   *  @bvToInteger (integerToBV x w) == mod x (2^w)@
1498  --   *  @bvToInteger (integerToBV x w) == x@     when @0 <= x < 2^w@.
1499  --   *  @sbvToInteger (integerToBV x w) == mod (x + 2^(w-1)) (2^w) - 2^(w-1)@
1500  --   *  @sbvToInteger (integerToBV x w) == x@    when @-2^(w-1) <= x < 2^(w-1)@
1501  --   *  @integerToBV (bvToInteger y) w == y@     when @y@ is a @SymBV sym w@
1502  --   *  @integerToBV (sbvToInteger y) w == y@    when @y@ is a @SymBV sym w@
1503  integerToBV :: (1 <= w) => sym -> SymInteger sym -> NatRepr w -> IO (SymBV sym w)
1504
1505  ----------------------------------------------------------------------
1506  -- Lossy (non-injective) combinators
1507
1508  -- | Convert a real number to an integer.
1509  --
1510  -- The result is undefined if the given real number does not represent an integer.
1511  realToInteger :: sym -> SymReal sym -> IO (SymInteger sym)
1512
1513  -- | Convert a real number to an unsigned bitvector.
1514  --
1515  -- Numbers are rounded to the nearest representable number, with rounding away from
1516  -- zero when two integers are equidistant (e.g., 1.5 rounds to 2).
1517  -- When the real is negative the result is zero.
1518  realToBV :: (1 <= w) => sym -> SymReal sym -> NatRepr w -> IO (SymBV sym w)
1519  realToBV sym r w = do
1520    i <- realRound sym r
1521    clampedIntToBV sym i w
1522
1523  -- | Convert a real number to a signed bitvector.
1524  --
1525  -- Numbers are rounded to the nearest representable number, with rounding away from
1526  -- zero when two integers are equidistant (e.g., 1.5 rounds to 2).
1527  realToSBV  :: (1 <= w) => sym -> SymReal sym -> NatRepr w -> IO (SymBV sym w)
1528  realToSBV sym r w  = do
1529    i <- realRound sym r
1530    clampedIntToSBV sym i w
1531
1532  -- | Convert an integer to the nearest signed bitvector.
1533  --
1534  -- Numbers are rounded to the nearest representable number.
1535  clampedIntToSBV :: (1 <= w) => sym -> SymInteger sym -> NatRepr w -> IO (SymBV sym w)
1536  clampedIntToSBV sym i w
1537    | Just v <- asInteger i = do
1538      bvLit sym w $ BV.signedClamp w v
1539    | otherwise = do
1540      -- Handle case where i < minSigned w
1541      let min_val = minSigned w
1542          min_val_bv = BV.minSigned w
1543      min_sym <- intLit sym min_val
1544      is_lt <- intLt sym i min_sym
1545      iteM bvIte sym is_lt (bvLit sym w min_val_bv) $ do
1546        -- Handle case where i > maxSigned w
1547        let max_val = maxSigned w
1548            max_val_bv = BV.maxSigned w
1549        max_sym <- intLit sym max_val
1550        is_gt <- intLt sym max_sym i
1551        iteM bvIte sym is_gt (bvLit sym w max_val_bv) $ do
1552          -- Do unclamped conversion.
1553          integerToBV sym i w
1554
1555  -- | Convert an integer to the nearest unsigned bitvector.
1556  --
1557  -- Numbers are rounded to the nearest representable number.
1558  clampedIntToBV :: (1 <= w) => sym -> SymInteger sym -> NatRepr w -> IO (SymBV sym w)
1559  clampedIntToBV sym i w
1560    | Just v <- asInteger i = do
1561      bvLit sym w $ BV.unsignedClamp w v
1562    | otherwise = do
1563      -- Handle case where i < 0
1564      min_sym <- intLit sym 0
1565      is_lt <- intLt sym i min_sym
1566      iteM bvIte sym is_lt (bvLit sym w (BV.zero w)) $ do
1567        -- Handle case where i > maxUnsigned w
1568        let max_val = maxUnsigned w
1569            max_val_bv = BV.maxUnsigned w
1570        max_sym <- intLit sym max_val
1571        is_gt <- intLt sym max_sym i
1572        iteM bvIte sym is_gt (bvLit sym w max_val_bv) $
1573          -- Do unclamped conversion.
1574          integerToBV sym i w
1575
1576  ----------------------------------------------------------------------
1577  -- Bitvector operations.
1578
1579  -- | Convert a signed bitvector to the nearest signed bitvector with
1580  -- the given width. If the resulting width is smaller, this clamps
1581  -- the value to min-int or max-int when necessary.
1582  intSetWidth :: (1 <= m, 1 <= n) => sym -> SymBV sym m -> NatRepr n -> IO (SymBV sym n)
1583  intSetWidth sym e n = do
1584    let m = bvWidth e
1585    case n `testNatCases` m of
1586      -- Truncate when the width of e is larger than w.
1587      NatCaseLT LeqProof -> do
1588        -- Check if e underflows
1589        does_underflow <- bvSlt sym e =<< bvLit sym m (BV.sext n m (BV.minSigned n))
1590        iteM bvIte sym does_underflow (bvLit sym n (BV.minSigned n)) $ do
1591          -- Check if e overflows target signed representation.
1592          does_overflow <- bvSgt sym e =<< bvLit sym m (BV.mkBV m (maxSigned n))
1593          iteM bvIte sym does_overflow (bvLit sym n (BV.maxSigned n)) $ do
1594            -- Just do truncation.
1595            bvTrunc sym n e
1596      NatCaseEQ -> return e
1597      NatCaseGT LeqProof -> bvSext sym n e
1598
1599  -- | Convert an unsigned bitvector to the nearest unsigned bitvector with
1600  -- the given width (clamp on overflow).
1601  uintSetWidth :: (1 <= m, 1 <= n) => sym -> SymBV sym m -> NatRepr n -> IO (SymBV sym n)
1602  uintSetWidth sym e n = do
1603    let m = bvWidth e
1604    case n `testNatCases` m of
1605      NatCaseLT LeqProof -> do
1606        does_overflow <- bvUgt sym e =<< bvLit sym m (BV.mkBV m (maxUnsigned n))
1607        iteM bvIte sym does_overflow (bvLit sym n (BV.maxUnsigned n)) $ bvTrunc sym n e
1608      NatCaseEQ -> return e
1609      NatCaseGT LeqProof -> bvZext sym n e
1610
  -- | Convert a signed bitvector to the nearest unsigned bitvector with
  -- the given width (clamp on overflow; negative values become zero).
1613  intToUInt :: (1 <= m, 1 <= n) => sym -> SymBV sym m -> NatRepr n -> IO (SymBV sym n)
1614  intToUInt sym e w = do
1615    p <- bvIsNeg sym e
1616    iteM bvIte sym p (bvLit sym w (BV.zero w)) (uintSetWidth sym e w)
1617
1618  -- | Convert an unsigned bitvector to the nearest signed bitvector with
1619  -- the given width (clamp on overflow).
1620  uintToInt :: (1 <= m, 1 <= n) => sym -> SymBV sym m -> NatRepr n -> IO (SymBV sym n)
1621  uintToInt sym e n = do
1622    let m = bvWidth e
1623    case n `testNatCases` m of
1624      NatCaseLT LeqProof -> do
1625        -- Get maximum signed n-bit number.
1626        max_val <- bvLit sym m (BV.sext n m (BV.maxSigned n))
1627        -- Check if expression is less than maximum.
1628        p <- bvUle sym e max_val
1629        -- Select appropriate number then truncate.
1630        bvTrunc sym n =<< bvIte sym p e max_val
1631      NatCaseEQ -> do
1632        max_val <- maxSignedBV sym n
1633        p <- bvUle sym e max_val
1634        bvIte sym p e max_val
1635      NatCaseGT LeqProof -> do
1636        bvZext sym n e
1637
1638  ----------------------------------------------------------------------
1639  -- String operations
1640
  -- | Create an empty string literal of the string kind described by the
  --   given 'StringInfoRepr'.
  stringEmpty :: sym -> StringInfoRepr si -> IO (SymString sym si)

  -- | Create a concrete string literal.
  stringLit :: sym -> StringLiteral si -> IO (SymString sym si)

  -- | Check the equality of two strings.
  stringEq :: sym -> SymString sym si -> SymString sym si -> IO (Pred sym)

  -- | If-then-else on strings: the first string is selected when the
  --   predicate holds, the second otherwise.
  stringIte :: sym -> Pred sym -> SymString sym si -> SymString sym si -> IO (SymString sym si)

  -- | Concatenate two strings.
  stringConcat :: sym -> SymString sym si -> SymString sym si -> IO (SymString sym si)

  -- | Test if the first string contains the second string as a substring.
  stringContains :: sym -> SymString sym si -> SymString sym si -> IO (Pred sym)

  -- | Test if the first string is a prefix of the second string.
  stringIsPrefixOf :: sym -> SymString sym si -> SymString sym si -> IO (Pred sym)

  -- | Test if the first string is a suffix of the second string.
  stringIsSuffixOf :: sym -> SymString sym si -> SymString sym si -> IO (Pred sym)

  -- | Return the first position at which the second string can be found as a substring
  --   in the first string, starting from the given index.
  --   If no such position exists, return a negative value.
  stringIndexOf :: sym -> SymString sym si -> SymString sym si -> SymInteger sym -> IO (SymInteger sym)

  -- | Compute the length of a string.
  stringLength :: sym -> SymString sym si -> IO (SymInteger sym)

  -- | @stringSubstring s off len@ extracts the substring of @s@ starting at index @off@ and
  --   having length @len@.  The result of this operation is undefined if @off@ and @len@
  --   do not specify a valid substring of @s@; in particular, we must have
  --   @0 <= off@, @0 <= len@ and @off+len <= length(s)@.
  stringSubstring :: sym -> SymString sym si -> SymInteger sym -> SymInteger sym -> IO (SymString sym si)
1678
1679  ----------------------------------------------------------------------
1680  -- Real operations
1681
  -- | Return real number 0.
  --
  --   Unlike most constructors in this class, this is a pure value rather
  --   than an 'IO' action.
  realZero :: sym -> SymReal sym

  -- | Create a constant real literal from a 'Rational'.
  realLit :: sym -> Rational -> IO (SymReal sym)
1687
1688  -- | Make a real literal from a scientific value. May be overridden
1689  -- if we want to avoid the overhead of converting scientific value
1690  -- to rational.
1691  sciLit :: sym -> Scientific -> IO (SymReal sym)
1692  sciLit sym s = realLit sym (toRational s)
1693
  -- | Check equality of two real numbers.
  realEq :: sym -> SymReal sym -> SymReal sym -> IO (Pred sym)
1696
1697  -- | Check non-equality of two real numbers.
1698  realNe :: sym -> SymReal sym -> SymReal sym -> IO (Pred sym)
1699  realNe sym x y = notPred sym =<< realEq sym x y
1700
  -- | Check @<=@ on two real numbers.  The default definitions of
  --   'realLt', 'realGe' and 'realGt' are expressed in terms of this.
  realLe :: sym -> SymReal sym -> SymReal sym -> IO (Pred sym)
1703
1704  -- | Check @<@ on two real numbers.
1705  realLt :: sym -> SymReal sym -> SymReal sym -> IO (Pred sym)
1706  realLt sym x y = notPred sym =<< realLe sym y x
1707
1708  -- | Check @>=@ on two real numbers.
1709  realGe :: sym -> SymReal sym -> SymReal sym -> IO (Pred sym)
1710  realGe sym x y = realLe sym y x
1711
1712  -- | Check @>@ on two real numbers.
1713  realGt :: sym -> SymReal sym -> SymReal sym -> IO (Pred sym)
1714  realGt sym x y = realLt sym y x
1715
  -- | If-then-else on real numbers: the first value is selected when the
  --   predicate holds, the second otherwise.
  realIte :: sym -> Pred sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1718
1719  -- | Return the minimum of two real numbers.
1720  realMin :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1721  realMin sym x y =
1722    do p <- realLe sym x y
1723       realIte sym p x y
1724
1725  -- | Return the maxmimum of two real numbers.
1726  realMax :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1727  realMax sym x y =
1728    do p <- realLe sym x y
1729       realIte sym p y x
1730
  -- | Negate a real number.  Used by the default 'realSub' and 'realAbs'.
  realNeg :: sym -> SymReal sym -> IO (SymReal sym)

  -- | Add two real numbers.
  realAdd :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)

  -- | Multiply two real numbers.
  realMul :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1739
1740  -- | Subtract one real from another.
1741  realSub :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1742  realSub sym x y = realAdd sym x =<< realNeg sym y
1743
1744  -- | @realSq sym x@ returns @x * x@.
1745  realSq :: sym -> SymReal sym -> IO (SymReal sym)
1746  realSq sym x = realMul sym x x
1747
  -- | @realDiv sym x y@ returns term equivalent to @x/y@.
  --
  -- The result is undefined when @y@ is zero.
  realDiv :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1752
1753  -- | @realMod x y@ returns the value of @x - y * floor(x / y)@ when
1754  -- @y@ is not zero and @x@ when @y@ is zero.
1755  realMod :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1756  realMod sym x y = do
1757    isZero <- realEq sym y (realZero sym)
1758    iteM realIte sym isZero (return x) $ do
1759      realSub sym x =<< realMul sym y
1760                    =<< integerToReal sym
1761                    =<< realFloor sym
1762                    =<< realDiv sym x y
1763
  -- | Predicate that holds if the real number is an exact integer.
  isInteger :: sym -> SymReal sym -> IO (Pred sym)
1766
1767  -- | Return true if the real is non-negative.
1768  realIsNonNeg :: sym -> SymReal sym -> IO (Pred sym)
1769  realIsNonNeg sym x = realLe sym (realZero sym) x
1770
  -- | @realSqrt sym x@ returns sqrt(x).  Result is undefined
  -- if @x@ is negative.
  realSqrt :: sym -> SymReal sym -> IO (SymReal sym)

  -- | @realAtan2 sym y x@ returns the arctangent of @y/x@ with a range
  -- of @-pi@ to @pi@; this corresponds to the angle between the positive
  -- x-axis and the ray from the origin to the point @(x,y)@.
  --
  -- When @x@ is @0@ this returns @pi/2 * sgn y@.
  --
  -- When @x@ and @y@ are both zero, this function is undefined.
  realAtan2 :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)

  -- | Return value denoting pi.
  realPi :: sym -> IO (SymReal sym)

  -- | Natural logarithm.  @realLog x@ is undefined
  --   for @x <= 0@.
  realLog :: sym -> SymReal sym -> IO (SymReal sym)

  -- | Natural exponentiation (@e^x@).
  realExp :: sym -> SymReal sym -> IO (SymReal sym)

  -- | Sine trig function.
  realSin :: sym -> SymReal sym -> IO (SymReal sym)

  -- | Cosine trig function.
  realCos :: sym -> SymReal sym -> IO (SymReal sym)
1799
1800  -- | Tangent trig function.  @realTan x@ is undefined
1801  --   when @cos x = 0@,  i.e., when @x = pi/2 + k*pi@ for
1802  --   some integer @k@.
1803  realTan :: sym -> SymReal sym -> IO (SymReal sym)
1804  realTan sym x = do
1805    sin_x <- realSin sym x
1806    cos_x <- realCos sym x
1807    realDiv sym sin_x cos_x
1808
  -- | Hyperbolic sine.
  realSinh :: sym -> SymReal sym -> IO (SymReal sym)

  -- | Hyperbolic cosine.
  realCosh :: sym -> SymReal sym -> IO (SymReal sym)
1814
1815  -- | Hyperbolic tangent
1816  realTanh :: sym -> SymReal sym -> IO (SymReal sym)
1817  realTanh sym x = do
1818    sinh_x <- realSinh sym x
1819    cosh_x <- realCosh sym x
1820    realDiv sym sinh_x cosh_x
1821
1822  -- | Return absolute value of the real number.
1823  realAbs :: sym -> SymReal sym -> IO (SymReal sym)
1824  realAbs sym x = do
1825    c <- realGe sym x (realZero sym)
1826    realIte sym c x =<< realNeg sym x
1827
1828  -- | @realHypot x y@ returns sqrt(x^2 + y^2).
1829  realHypot :: sym -> SymReal sym -> SymReal sym -> IO (SymReal sym)
1830  realHypot sym x y = do
1831    case (asRational x, asRational y) of
1832      (Just 0, _) -> realAbs sym y
1833      (_, Just 0) -> realAbs sym x
1834      _ -> do
1835        x2 <- realSq sym x
1836        y2 <- realSq sym y
1837        realSqrt sym =<< realAdd sym x2 y2
1838
1839  ----------------------------------------------------------------------
1840  -- IEEE-754 floating-point operations
  -- | Return floating point number @+0@ at the given precision.
  floatPZero :: sym -> FloatPrecisionRepr fpp -> IO (SymFloat sym fpp)

  -- | Return floating point number @-0@ at the given precision.
  floatNZero :: sym -> FloatPrecisionRepr fpp -> IO (SymFloat sym fpp)

  -- | Return floating point NaN at the given precision.
  floatNaN :: sym -> FloatPrecisionRepr fpp -> IO (SymFloat sym fpp)

  -- | Return floating point @+infinity@ at the given precision.
  floatPInf :: sym -> FloatPrecisionRepr fpp -> IO (SymFloat sym fpp)

  -- | Return floating point @-infinity@ at the given precision.
  floatNInf :: sym -> FloatPrecisionRepr fpp -> IO (SymFloat sym fpp)
1855
1856  -- | Create a floating point literal from a rational literal.
1857  --   The rational value will be rounded if necessary using the
1858  --   "round to nearest even" rounding mode.
1859  floatLitRational
1860    :: sym -> FloatPrecisionRepr fpp -> Rational -> IO (SymFloat sym fpp)
1861  floatLitRational sym fpp x = realToFloat sym fpp RNE =<< realLit sym x
1862
  -- | Create a floating point literal from a @BigFloat@ value.
  floatLit :: sym -> FloatPrecisionRepr fpp -> BigFloat -> IO (SymFloat sym fpp)

  -- | Negate a floating point number.
  floatNeg
    :: sym
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Return the absolute value of a floating point number.
  floatAbs
    :: sym
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Compute the square root of a floating point number, rounding with
  --   the given rounding mode.
  floatSqrt
    :: sym
    -> RoundingMode
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Add two floating point numbers, rounding with the given rounding mode.
  floatAdd
    :: sym
    -> RoundingMode
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Subtract two floating point numbers (first minus second), rounding
  --   with the given rounding mode.
  floatSub
    :: sym
    -> RoundingMode
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Multiply two floating point numbers, rounding with the given
  --   rounding mode.
  floatMul
    :: sym
    -> RoundingMode
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Divide two floating point numbers (first by second), rounding with
  --   the given rounding mode.
  floatDiv
    :: sym
    -> RoundingMode
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Compute the remainder: @x - y * n@, where @n@ in Z is nearest to @x / y@
  --   (breaking ties to even values of @n@).  No rounding mode is taken.
  floatRem
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)
1924
  -- | Return the minimum of two floating point numbers.
  --   If one argument is NaN, return the other argument.
  --   If the arguments are equal when compared as floating-point values,
  --   one of the two will be returned, but it is unspecified which;
  --   this underspecification can (only) be observed with zeros of different signs.
  floatMin
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Return the maximum of two floating point numbers.
  --   If one argument is NaN, return the other argument.
  --   If the arguments are equal when compared as floating-point values,
  --   one of the two will be returned, but it is unspecified which;
  --   this underspecification can (only) be observed with zeros of different signs.
  floatMax
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Compute the fused multiplication and addition: @(x * y) + z@,
  --   rounding with the given rounding mode.
  floatFMA
    :: sym
    -> RoundingMode
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Check logical equality of two floating point numbers.
  --
  --   NOTE! This does NOT accurately represent the equality test on floating point
  --   values typically found in programming languages.  See 'floatFpEq' instead.
  floatEq
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (Pred sym)

  -- | Check logical non-equality of two floating point numbers.
  --
  --   NOTE! This does NOT accurately represent the non-equality test on floating point
  --   values typically found in programming languages; those tests are usually
  --   the logical negation of 'floatFpEq'.
  floatNe
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (Pred sym)

  -- | Check IEEE-754 equality of two floating point numbers.
  --
  --   NOTE! This test returns false if either value is @NaN@; in particular
  --   @NaN@ is not equal to itself!  Moreover, positive and negative 0 will
  --   compare equal, despite having different bit patterns.
  --
  --   This test is most appropriate for interpreting the equality tests of
  --   typical languages using floating point.  Moreover, not-equal tests
  --   are usually the negation of this test, rather than the 'floatFpApart'
  --   test.
  floatFpEq
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (Pred sym)
1991
1992  -- | Check IEEE-754 apartness of two floating point numbers.
1993  --
1994  --   NOTE! This test returns false if either value is @NaN@; in particular
1995  --   @NaN@ is not apart from any other value!  Moreover, positive and
1996  --   negative 0 will not compare apart, despite having different
1997  --   bit patterns.  Note that @x@ is apart from @y@ iff @x < y@ or @x > y@.
1998  --
1999  --   This test usually does NOT correspond to the not-equal tests found
2000  --   in programming languages.  Instead, one generally takes the logical
2001  --   negation of the `floatFpEq` test.
2002  floatFpApart
2003    :: sym
2004    -> SymFloat sym fpp
2005    -> SymFloat sym fpp
2006    -> IO (Pred sym)
2007  floatFpApart sym x y =
2008    do l <- floatLt sym x y
2009       g <- floatGt sym x y
2010       orPred sym l g
2011
2012  -- | Check if two floating point numbers are "unordered".  This happens
2013  --   precicely when one or both of the inputs is @NaN@.
2014  floatFpUnordered
2015    :: sym
2016    -> SymFloat sym fpp
2017    -> SymFloat sym fpp
2018    -> IO (Pred sym)
2019  floatFpUnordered sym x y =
2020    do xnan <- floatIsNaN sym x
2021       ynan <- floatIsNaN sym y
2022       orPred sym xnan ynan
2023
  -- | Check IEEE-754 @<=@ on two floating point numbers.
  --
  --   NOTE! This test returns false if either value is @NaN@; in particular
  --   @NaN@ is not less-than-or-equal-to any other value!  Moreover, positive
  --   and negative 0 are considered equal, despite having different bit patterns.
  floatLe
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (Pred sym)

  -- | Check IEEE-754 @<@ on two floating point numbers.
  --
  --   NOTE! This test returns false if either value is @NaN@; in particular
  --   @NaN@ is not less-than any other value! Moreover, positive
  --   and negative 0 are considered equal, despite having different bit patterns.
  floatLt
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (Pred sym)

  -- | Check IEEE-754 @>=@ on two floating point numbers.
  --
  --   NOTE! This test returns false if either value is @NaN@; in particular
  --   @NaN@ is not greater-than-or-equal-to any other value!  Moreover, positive
  --   and negative 0 are considered equal, despite having different bit patterns.
  floatGe
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (Pred sym)

  -- | Check IEEE-754 @>@ on two floating point numbers.
  --
  --   NOTE! This test returns false if either value is @NaN@; in particular
  --   @NaN@ is not greater-than any other value! Moreover, positive
  --   and negative 0 are considered equal, despite having different bit patterns.
  floatGt
    :: sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (Pred sym)

  -- | Test if a floating-point value is NaN.
  floatIsNaN :: sym -> SymFloat sym fpp -> IO (Pred sym)

  -- | Test if a floating-point value is (positive or negative) infinity.
  floatIsInf :: sym -> SymFloat sym fpp -> IO (Pred sym)

  -- | Test if a floating-point value is (positive or negative) zero.
  floatIsZero :: sym -> SymFloat sym fpp -> IO (Pred sym)

  -- | Test if a floating-point value is positive.  NOTE!
  --   NaN is considered neither positive nor negative.
  floatIsPos :: sym -> SymFloat sym fpp -> IO (Pred sym)

  -- | Test if a floating-point value is negative.  NOTE!
  --   NaN is considered neither positive nor negative.
  floatIsNeg :: sym -> SymFloat sym fpp -> IO (Pred sym)

  -- | Test if a floating-point value is subnormal.
  floatIsSubnorm :: sym -> SymFloat sym fpp -> IO (Pred sym)

  -- | Test if a floating-point value is normal.
  floatIsNorm :: sym -> SymFloat sym fpp -> IO (Pred sym)
2090
  -- | If-then-else on floating point numbers: the first value is selected
  --   when the predicate holds, the second otherwise.
  floatIte
    :: sym
    -> Pred sym
    -> SymFloat sym fpp
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)

  -- | Change the precision of a floating point number to the given
  --   precision, rounding with the given rounding mode.
  floatCast
    :: sym
    -> FloatPrecisionRepr fpp
    -> RoundingMode
    -> SymFloat sym fpp'
    -> IO (SymFloat sym fpp)
  -- | Round a floating point number to an integral value, using the given
  --   rounding mode.  The result is still a floating point value.
  floatRound
    :: sym
    -> RoundingMode
    -> SymFloat sym fpp
    -> IO (SymFloat sym fpp)
  -- | Convert from binary representation in IEEE 754-2008 format to
  --   floating point.
  floatFromBinary
    :: (2 <= eb, 2 <= sb)
    => sym
    -> FloatPrecisionRepr (FloatingPointPrecision eb sb)
    -> SymBV sym (eb + sb)
    -> IO (SymFloat sym (FloatingPointPrecision eb sb))
  -- | Convert a floating point number to its binary representation in
  --   IEEE 754-2008 format.
  --
  --   NOTE! @NaN@ has multiple representations, i.e. all bit patterns where
  --   the exponent is @0b1..1@ and the significand is not @0b0..0@.
  --   This function returns the representation of positive "quiet" @NaN@,
  --   i.e. the bit pattern where the sign is @0b0@, the exponent is @0b1..1@,
  --   and the significand is @0b10..0@.
  floatToBinary
    :: (2 <= eb, 2 <= sb)
    => sym
    -> SymFloat sym (FloatingPointPrecision eb sb)
    -> IO (SymBV sym (eb + sb))
  -- | Convert an unsigned bitvector to a floating point number.
  bvToFloat
    :: (1 <= w)
    => sym
    -> FloatPrecisionRepr fpp
    -> RoundingMode
    -> SymBV sym w
    -> IO (SymFloat sym fpp)
  -- | Convert a signed bitvector to a floating point number.
  sbvToFloat
    :: (1 <= w)
    => sym
    -> FloatPrecisionRepr fpp
    -> RoundingMode
    -> SymBV sym w
    -> IO (SymFloat sym fpp)
  -- | Convert a real number to a floating point number.
  realToFloat
    :: sym
    -> FloatPrecisionRepr fpp
    -> RoundingMode
    -> SymReal sym
    -> IO (SymFloat sym fpp)
  -- | Convert a floating point number to an unsigned bitvector.
  floatToBV
    :: (1 <= w)
    => sym
    -> NatRepr w
    -> RoundingMode
    -> SymFloat sym fpp
    -> IO (SymBV sym w)
  -- | Convert a floating point number to a signed bitvector.
  floatToSBV
    :: (1 <= w)
    => sym
    -> NatRepr w
    -> RoundingMode
    -> SymFloat sym fpp
    -> IO (SymBV sym w)
  -- | Convert a floating point number to a real number.
  floatToReal :: sym -> SymFloat sym fpp -> IO (SymReal sym)
2174
2175  ----------------------------------------------------------------------
2176  -- Cplx operations
2177
  -- | Create a complex from cartesian coordinates (real and imaginary part).
  mkComplex :: sym -> Complex (SymReal sym) -> IO (SymCplx sym)

  -- | @getRealPart x@ returns the real part of @x@.
  getRealPart :: sym -> SymCplx sym -> IO (SymReal sym)

  -- | @getImagPart x@ returns the imaginary part of @x@.
  getImagPart :: sym -> SymCplx sym -> IO (SymReal sym)

  -- | Convert a complex number into the real and imaginary part.
  --   The default complex operations below are built on this.
  cplxGetParts :: sym -> SymCplx sym -> IO (Complex (SymReal sym))
2189
2190  -- | Create a constant complex literal.
2191  mkComplexLit :: sym -> Complex Rational -> IO (SymCplx sym)
2192  mkComplexLit sym d = mkComplex sym =<< traverse (realLit sym) d
2193
2194  -- | Create a complex from a real value.
2195  cplxFromReal :: sym -> SymReal sym -> IO (SymCplx sym)
2196  cplxFromReal sym r = mkComplex sym (r :+ realZero sym)
2197
2198  -- | If-then-else on complex values.
2199  cplxIte :: sym -> Pred sym -> SymCplx sym -> SymCplx sym -> IO (SymCplx sym)
2200  cplxIte sym c x y = do
2201    case asConstantPred c of
2202      Just True -> return x
2203      Just False -> return y
2204      _ -> do
2205        xr :+ xi <- cplxGetParts sym x
2206        yr :+ yi <- cplxGetParts sym y
2207        zr <- realIte sym c xr yr
2208        zi <- realIte sym c xi yi
2209        mkComplex sym (zr :+ zi)
2210
2211  -- | Negate a complex number.
2212  cplxNeg :: sym -> SymCplx sym -> IO (SymCplx sym)
2213  cplxNeg sym x = mkComplex sym =<< traverse (realNeg sym) =<< cplxGetParts sym x
2214
2215  -- | Add two complex numbers together.
2216  cplxAdd :: sym -> SymCplx sym -> SymCplx sym -> IO (SymCplx sym)
2217  cplxAdd sym x y = do
2218    xr :+ xi <- cplxGetParts sym x
2219    yr :+ yi <- cplxGetParts sym y
2220    zr <- realAdd sym xr yr
2221    zi <- realAdd sym xi yi
2222    mkComplex sym (zr :+ zi)
2223
2224  -- | Subtract one complex number from another.
2225  cplxSub :: sym -> SymCplx sym -> SymCplx sym -> IO (SymCplx sym)
2226  cplxSub sym x y = do
2227    xr :+ xi <- cplxGetParts sym x
2228    yr :+ yi <- cplxGetParts sym y
2229    zr <- realSub sym xr yr
2230    zi <- realSub sym xi yi
2231    mkComplex sym (zr :+ zi)
2232
2233  -- | Multiply two complex numbers together.
2234  cplxMul :: sym -> SymCplx sym -> SymCplx sym -> IO (SymCplx sym)
2235  cplxMul sym x y = do
2236    xr :+ xi <- cplxGetParts sym x
2237    yr :+ yi <- cplxGetParts sym y
2238    rz0 <- realMul sym xr yr
2239    rz <- realSub sym rz0 =<< realMul sym xi yi
2240    iz0 <- realMul sym xi yr
2241    iz <- realAdd sym iz0 =<< realMul sym xr yi
2242    mkComplex sym (rz :+ iz)
2243
2244  -- | Compute the magnitude of a complex number.
2245  cplxMag :: sym -> SymCplx sym -> IO (SymReal sym)
2246  cplxMag sym x = do
2247    (xr :+ xi) <- cplxGetParts sym x
2248    realHypot sym xr xi
2249
2250  -- | Return the principal square root of a complex number.
2251  cplxSqrt :: sym -> SymCplx sym -> IO (SymCplx sym)
2252  cplxSqrt sym x = do
2253    (r_part :+ i_part) <- cplxGetParts sym x
2254    case (asRational r_part :+ asRational i_part)of
2255      (Just r :+ Just i) | Just z <- tryComplexSqrt tryRationalSqrt (r :+ i) ->
2256        mkComplexLit sym z
2257
2258      (_ :+ Just 0) -> do
2259        c <- realGe sym r_part (realZero sym)
2260        u <- iteM realIte sym c
2261          (realSqrt sym r_part)
2262          (realLit sym 0)
2263        v <- iteM realIte sym c
2264          (realLit sym 0)
2265          (realSqrt sym =<< realNeg sym r_part)
2266        mkComplex sym (u :+ v)
2267
2268      _ -> do
2269        m <- realHypot sym r_part i_part
2270        m_plus_r <- realAdd sym m r_part
2271        m_sub_r  <- realSub sym m r_part
2272        two <- realLit sym 2
2273        u <- realSqrt sym =<< realDiv sym m_plus_r two
2274        v <- realSqrt sym =<< realDiv sym m_sub_r  two
2275        neg_v <- realNeg sym v
2276        i_part_nonneg <- realIsNonNeg sym i_part
2277        v' <- realIte sym i_part_nonneg v neg_v
2278        mkComplex sym (u :+ v')
2279
2280  -- | Compute sine of a complex number.
2281  cplxSin :: sym -> SymCplx sym -> IO (SymCplx sym)
2282  cplxSin sym arg = do
2283    c@(x :+ y) <- cplxGetParts sym arg
2284    case asRational <$> c of
2285      (Just 0 :+ Just 0) -> cplxFromReal sym (realZero sym)
2286      (_ :+ Just 0) -> cplxFromReal sym =<< realSin sym x
2287      (Just 0 :+ _) -> do
2288        -- sin(0 + bi) = sin(0) cosh(b) + i*cos(0)sinh(b) = i*sinh(b)
2289        sinh_y <- realSinh sym y
2290        mkComplex sym (realZero sym :+ sinh_y)
2291      _ -> do
2292        sin_x <- realSin sym x
2293        cos_x <- realCos sym x
2294        sinh_y <- realSinh sym y
2295        cosh_y <- realCosh sym y
2296        r_part <- realMul sym sin_x cosh_y
2297        i_part <- realMul sym cos_x sinh_y
2298        mkComplex sym (r_part :+ i_part)
2299
2300  -- | Compute cosine of a complex number.
2301  cplxCos :: sym -> SymCplx sym -> IO (SymCplx sym)
2302  cplxCos sym arg = do
2303    c@(x :+ y) <- cplxGetParts sym arg
2304    case asRational <$> c of
2305      (Just 0 :+ Just 0) -> cplxFromReal sym =<< realLit sym 1
2306      (_ :+ Just 0) -> cplxFromReal sym =<< realCos sym x
2307      (Just 0 :+ _) -> do
2308        -- cos(0 + bi) = cos(0) cosh(b) - i*sin(0)sinh(b) = cosh(b)
2309        cosh_y    <- realCosh sym y
2310        cplxFromReal sym cosh_y
2311      _ -> do
2312        neg_sin_x <- realNeg sym =<< realSin sym x
2313        cos_x     <- realCos sym x
2314        sinh_y    <- realSinh sym y
2315        cosh_y    <- realCosh sym y
2316        r_part <- realMul sym cos_x cosh_y
2317        i_part <- realMul sym neg_sin_x sinh_y
2318        mkComplex sym (r_part :+ i_part)
2319
2320  -- | Compute tangent of a complex number.  @cplxTan x@ is undefined
2321  --   when @cplxCos x@ is @0@, which occurs only along the real line
2322  --   in the same conditions where @realCos x@ is @0@.
2323  cplxTan :: sym -> SymCplx sym -> IO (SymCplx sym)
2324  cplxTan sym arg = do
2325    c@(x :+ y) <- cplxGetParts sym arg
2326    case asRational <$> c of
2327      (Just 0 :+ Just 0) -> cplxFromReal sym (realZero sym)
2328      (_ :+ Just 0) -> do
2329        cplxFromReal sym =<< realTan sym x
2330      (Just 0 :+ _) -> do
2331        i_part <- realTanh sym y
2332        mkComplex sym (realZero sym :+ i_part)
2333      _ -> do
2334        sin_x <- realSin sym x
2335        cos_x <- realCos sym x
2336        sinh_y <- realSinh sym y
2337        cosh_y <- realCosh sym y
2338        u <- realMul sym cos_x cosh_y
2339        v <- realMul sym sin_x sinh_y
2340        u2 <- realMul sym u u
2341        v2 <- realMul sym v v
2342        m <- realAdd sym u2 v2
2343        sin_x_cos_x   <- realMul sym sin_x cos_x
2344        sinh_y_cosh_y <- realMul sym sinh_y cosh_y
2345        r_part <- realDiv sym sin_x_cos_x m
2346        i_part <- realDiv sym sinh_y_cosh_y m
2347        mkComplex sym (r_part :+ i_part)
2348
2349  -- | @hypotCplx x y@ returns @sqrt(abs(x)^2 + abs(y)^2)@.
2350  cplxHypot :: sym -> SymCplx sym -> SymCplx sym -> IO (SymCplx sym)
2351  cplxHypot sym x y = do
2352    (xr :+ xi) <- cplxGetParts sym x
2353    (yr :+ yi) <- cplxGetParts sym y
2354    xr2 <- realSq sym xr
2355    xi2 <- realSq sym xi
2356    yr2 <- realSq sym yr
2357    yi2 <- realSq sym yi
2358
2359    r2 <- foldM (realAdd sym) xr2 [xi2, yr2, yi2]
2360    cplxFromReal sym =<< realSqrt sym r2
2361
2362  -- | @roundCplx x@ rounds complex number to nearest integer.
2363  -- Numbers with a fractional part of 0.5 are rounded away from 0.
2364  -- Imaginary and real parts are rounded independently.
2365  cplxRound :: sym -> SymCplx sym -> IO (SymCplx sym)
2366  cplxRound sym x = do
2367    c <- cplxGetParts sym x
2368    mkComplex sym =<< traverse (integerToReal sym <=< realRound sym) c
2369
2370  -- | @cplxFloor x@ rounds to nearest integer less than or equal to x.
2371  -- Imaginary and real parts are rounded independently.
2372  cplxFloor :: sym -> SymCplx sym -> IO (SymCplx sym)
2373  cplxFloor sym x =
2374    mkComplex sym =<< traverse (integerToReal sym <=< realFloor sym)
2375                  =<< cplxGetParts sym x
2376  -- | @cplxCeil x@ rounds to nearest integer greater than or equal to x.
2377  -- Imaginary and real parts are rounded independently.
2378  cplxCeil :: sym -> SymCplx sym -> IO (SymCplx sym)
2379  cplxCeil sym x =
2380    mkComplex sym =<< traverse (integerToReal sym <=< realCeil sym)
2381                  =<< cplxGetParts sym x
2382
2383  -- | @conjReal x@ returns the complex conjugate of the input.
2384  cplxConj :: sym -> SymCplx sym -> IO (SymCplx sym)
2385  cplxConj sym x  = do
2386    r :+ i <- cplxGetParts sym x
2387    ic <- realNeg sym i
2388    mkComplex sym (r :+ ic)
2389
2390  -- | Returns exponential of a complex number.
2391  cplxExp :: sym -> SymCplx sym -> IO (SymCplx sym)
2392  cplxExp sym x = do
2393    (rx :+ i_part) <- cplxGetParts sym x
2394    expx <- realExp sym rx
2395    cosx <- realCos sym i_part
2396    sinx <- realSin sym i_part
2397    rz <- realMul sym expx cosx
2398    iz <- realMul sym expx sinx
2399    mkComplex sym (rz :+ iz)
2400
2401  -- | Check equality of two complex numbers.
2402  cplxEq :: sym -> SymCplx sym -> SymCplx sym -> IO (Pred sym)
2403  cplxEq sym x y = do
2404    xr :+ xi <- cplxGetParts sym x
2405    yr :+ yi <- cplxGetParts sym y
2406    pr <- realEq sym xr yr
2407    pj <- realEq sym xi yi
2408    andPred sym pr pj
2409
2410  -- | Check non-equality of two complex numbers.
2411  cplxNe :: sym -> SymCplx sym -> SymCplx sym -> IO (Pred sym)
2412  cplxNe sym x y = do
2413    xr :+ xi <- cplxGetParts sym x
2414    yr :+ yi <- cplxGetParts sym y
2415    pr <- realNe sym xr yr
2416    pj <- realNe sym xi yi
2417    orPred sym pr pj
2418
-- | This newtype is necessary for @bvJoinVector@ and @bvSplitVector@.
-- These both use functions from "Data.Parameterized.Vector" that
-- expect a partially applied type constructor for the element type,
-- and we can't partially apply the type synonym (e.g. @SymBV sym@),
-- whereas we can partially apply this newtype.
newtype SymBV' sym w = MkSymBV' (SymBV sym w)
2425
-- | Join a @Vector@ of @n@ bitvectors, each of width @w@, into a single
--   bitvector of width @n * w@.  The vector is interpreted in big endian
--   order; that is, with most significant bitvector first.
bvJoinVector :: forall sym n w. (1 <= w, IsExprBuilder sym)
             => sym
             -> NatRepr w
             -> Vector.Vector n (SymBV sym w)
             -> IO (SymBV sym (n * w))
bvJoinVector sym w =
  -- 'Vector.joinWithM' works over the 'SymBV'' wrapper; 'coerce' converts
  -- between the wrapped and unwrapped argument/result types.
  coerce $ Vector.joinWithM @IO @(SymBV' sym) @n bvConcat' w
  where bvConcat' :: forall l. (1 <= l)
                  => NatRepr l
                  -> SymBV' sym w
                  -> SymBV' sym l
                  -> IO (SymBV' sym (w + l))
        -- Concatenate (via 'bvConcat') inside the wrapper; the 'NatRepr'
        -- argument supplied by 'Vector.joinWithM' is not needed here.
        bvConcat' _ (MkSymBV' x) (MkSymBV' y) = MkSymBV' <$> bvConcat sym x y
2442
-- | Split a bitvector of width @n * w@ into a @Vector@ of @n@ bitvectors,
--   each of width @w@.  The returned vector is in big endian order; that
--   is, with most significant bitvector first.
bvSplitVector :: forall sym n w. (IsExprBuilder sym, 1 <= w, 1 <= n)
              => sym
              -> NatRepr n
              -> NatRepr w
              -> SymBV sym (n * w)
              -> IO (Vector.Vector n (SymBV sym w))
bvSplitVector sym n w x =
  -- 'Vector.splitWithA' works over the 'SymBV'' wrapper; 'coerce' unwraps
  -- the resulting vector elements.
  coerce $ Vector.splitWithA @IO BigEndian bvSelect' n w (MkSymBV' @sym x)
  where
    -- Extract the @w@-bit slice starting at bit index @i@ of the wrapped
    -- input using 'bvSelect'; the total-width 'NatRepr' is ignored.
    bvSelect' :: forall i. (i + w <= n * w)
              => NatRepr (n * w)
              -> NatRepr i
              -> SymBV' sym (n * w)
              -> IO (SymBV' sym w)
    bvSelect' _ i (MkSymBV' y) =
      fmap MkSymBV' $ bvSelect @_ @i @w sym i w y
2462
-- | Implement LLVM's "bswap" intrinsic: reverse the order of the bytes of
-- a bitvector that is @n@ bytes wide.
--
-- See <https://llvm.org/docs/LangRef.html#llvm-bswap-intrinsics the LLVM bswap documentation>.
--
-- This mirrors the implementation in SawCore:
--
-- > llvmBSwap :: (n :: Nat) -> bitvector (mulNat n 8) -> bitvector (mulNat n 8);
-- > llvmBSwap n x = join n 8 Bool (reverse n (bitvector 8) (split n 8 Bool x));
bvSwap :: forall sym n. (1 <= n, IsExprBuilder sym)
       => sym               -- ^ Symbolic interface
       -> NatRepr n         -- ^ Number of bytes
       -> SymBV sym (n*8)   -- ^ Bitvector to swap around
       -> IO (SymBV sym (n*8))
bvSwap sym n v =
  do bytes <- bvSplitVector sym n byteWidth v
     bvJoinVector sym byteWidth (Vector.reverse bytes)
  where
    byteWidth = knownNat @8
2480
-- | Reverse the order of the bits in a bitvector: bit @i@ of the result is
--   bit @w - 1 - i@ of the input.
--
--   Implemented by splitting the value into @w@ single-bit pieces,
--   reversing them, and concatenating them again.
bvBitreverse :: forall sym w.
  (1 <= w, IsExprBuilder sym) =>
  sym ->
  SymBV sym w ->
  IO (SymBV sym w)
bvBitreverse sym v = do
  bits <- bvSplitVector sym (bvWidth v) (knownNat @1) v
  bvJoinVector sym (knownNat @1) (Vector.reverse bits)
2490
2491
-- | Build the symbolic literal corresponding to an 'IndexLit': an integer
--   literal for 'IntIndexLit', a bitvector literal for 'BVIndexLit'.
indexLit :: IsExprBuilder sym => sym -> IndexLit idx -> IO (SymExpr sym idx)
indexLit sym lit =
  case lit of
    IntIndexLit n      -> intLit sym n
    BVIndexLit width bv -> bvLit sym width bv
2496
-- | Combine two term-building actions with an if/then/else.
--
--   When the predicate is concretely 'True' only the "then" action runs,
--   and when it is concretely 'False' only the "else" action runs.
--   Otherwise both actions run ("then" first) and their results are
--   merged with the supplied @ite@ function.
iteM :: IsExprBuilder sym =>
  (sym -> Pred sym -> v -> v -> IO v) ->
  sym -> Pred sym -> IO v -> IO v -> IO v
iteM ite sym p mx my =
  case asConstantPred p of
    Just b  -> if b then mx else my
    Nothing -> do
      thenVal <- mx
      elseVal <- my
      ite sym p thenVal elseVal
2511
-- | An iterated sequence of if/then/else operations.
--
--   Each element pairs an action producing a predicate with an action
--   producing the corresponding result.  The result is that of the first
--   branch whose predicate holds, or the default when none does.  Branches
--   are built on demand using 'iteM': once a predicate is concretely true,
--   later branches (and the default) are never constructed.
iteList :: IsExprBuilder sym =>
  (sym -> Pred sym -> v -> v -> IO v) ->
  sym ->
  [(IO (Pred sym), IO v)] ->
  (IO v) ->
  IO v
iteList ite sym branches def = foldr addBranch def branches
  where
    -- Guard the remaining chain with this branch's predicate.
    addBranch (mp, mx) rest = mp >>= \p -> iteM ite sym p mx rest
2528
-- | A function that can be applied to symbolic arguments.
--
-- This type is used by some methods in classes 'IsExprBuilder' and
-- 'IsSymExprBuilder'.
--
-- A @SymFn sym args ret@ is indexed by the context of its argument types
-- @args@ and its result type @ret@; see 'IsSymFn' for recovering these
-- type representatives from a function value.
type family SymFn sym :: Ctx BaseType -> BaseType -> Type
2534
-- | A class for extracting type representatives from symbolic functions.
--
--   The function type @fn@ is indexed by its argument types @args@ and its
--   return type @ret@.
class IsSymFn fn where
  -- | Get the argument types of a function.
  fnArgTypes :: fn args ret -> Ctx.Assignment BaseTypeRepr args

  -- | Get the return type of a function.
  fnReturnType :: fn args ret -> BaseTypeRepr ret
2542
2543
-- | Describes when the body of a defined function (see 'definedFn') is
--   unfolded at the places where the function is applied.  Use
--   'shouldUnfold' to evaluate a policy against concrete arguments.
data UnfoldPolicy
  = NeverUnfold
      -- ^ What4 will not unfold the body of functions when applied to arguments
  | AlwaysUnfold
      -- ^ The function will be unfolded into its definition whenever it is
      --   applied to arguments
  | UnfoldConcrete
      -- ^ The function will be unfolded into its definition only if all the provided
      --   arguments are concrete.
  deriving (Eq, Ord, Show)
2555
-- | Decide whether an application with the given arguments should be
--   unfolded under an 'UnfoldPolicy'.  For 'UnfoldConcrete' this holds
--   exactly when every argument satisfies 'baseIsConcrete'.
shouldUnfold :: IsExpr e => UnfoldPolicy -> Ctx.Assignment e args -> Bool
shouldUnfold policy args =
  case policy of
    NeverUnfold    -> False
    AlwaysUnfold   -> True
    UnfoldConcrete -> allFC baseIsConcrete args
2561
2562
-- | This exception is thrown if the user requests to make a bounded variable,
--   but gives incoherent or out-of-range bounds.
--
--   The bounded fresh-variable operations of 'IsSymExprBuilder'
--   ('freshBoundedBV', 'freshBoundedSBV', 'freshBoundedInt' and
--   'freshBoundedReal') are documented to throw it.  The exception records
--   the type of the requested variable together with the optional bounds
--   that were supplied.
data InvalidRange where
  InvalidRange ::
    -- Type of the variable whose bounds were rejected.
    BaseTypeRepr bt ->
    -- First bound (printed first by 'show'), if one was given.
    Maybe (ConcreteValue bt) ->
    -- Second bound (printed second by 'show'), if one was given.
    Maybe (ConcreteValue bt) ->
    InvalidRange

instance Exception InvalidRange

-- Integer, real and bitvector ranges include both bounds in the message
-- (bitvectors also include their width); for any other type only the type
-- itself is shown.
instance Show InvalidRange where
  show (InvalidRange bt mlo mhi) =
    case bt of
      BaseIntegerRepr -> unwords ["invalid integer range", show mlo, show mhi]
      BaseRealRepr    -> unwords ["invalid real range", show mlo, show mhi]
      BaseBVRepr w    -> unwords ["invalid bitvector range", show w ++ "-bit", show mlo, show mhi]
      _               -> unwords ["invalid range for type", show bt]
2580
-- | This extends the interface for building expressions with operations
--   for creating new symbolic constants and functions.
--
--   On top of 'IsExprBuilder' it provides fresh (optionally bounded)
--   constants, bound variables and quantifiers, and operations for
--   defining, declaring and applying symbolic functions ('SymFn').
class ( IsExprBuilder sym
      , IsSymFn (SymFn sym)
      , OrdF (SymExpr sym)
      ) => IsSymExprBuilder sym where

  ----------------------------------------------------------------------
  -- Fresh variables

  -- | Create a fresh top-level uninterpreted constant.
  freshConstant :: sym -> SolverSymbol -> BaseTypeRepr tp -> IO (SymExpr sym tp)

  -- | Create a fresh latch variable.
  freshLatch    :: sym -> SolverSymbol -> BaseTypeRepr tp -> IO (SymExpr sym tp)

  -- | Create a fresh bitvector value with optional lower and upper bounds (which bound the
  --   unsigned value of the bitvector). If provided, the bounds are inclusive.
  --   If inconsistent or out-of-range bounds are given, an 'InvalidRange' exception will be thrown.
  freshBoundedBV :: (1 <= w) =>
    sym ->
    SolverSymbol ->
    NatRepr w ->
    Maybe Natural {- ^ lower bound -} ->
    Maybe Natural {- ^ upper bound -} ->
    IO (SymBV sym w)

  -- | Create a fresh bitvector value with optional lower and upper bounds (which bound the
  --   signed value of the bitvector).  If provided, the bounds are inclusive.
  --   If inconsistent or out-of-range bounds are given, an 'InvalidRange' exception will be thrown.
  freshBoundedSBV :: (1 <= w) =>
    sym ->
    SolverSymbol ->
    NatRepr w ->
    Maybe Integer {- ^ lower bound -} ->
    Maybe Integer {- ^ upper bound -} ->
    IO (SymBV sym w)

  -- | Create a fresh integer constant with optional lower and upper bounds.
  --   If provided, the bounds are inclusive.
  --   If inconsistent bounds are given, an 'InvalidRange' exception will be thrown.
  freshBoundedInt ::
    sym ->
    SolverSymbol ->
    Maybe Integer {- ^ lower bound -} ->
    Maybe Integer {- ^ upper bound -} ->
    IO (SymInteger sym)

  -- | Create a fresh real constant with optional lower and upper bounds.
  --   If provided, the bounds are inclusive.
  --   If inconsistent bounds are given, an 'InvalidRange' exception will be thrown.
  freshBoundedReal ::
    sym ->
    SolverSymbol ->
    Maybe Rational {- ^ lower bound -} ->
    Maybe Rational {- ^ upper bound -} ->
    IO (SymReal sym)


  ----------------------------------------------------------------------
  -- Functions needed to support quantifiers.

  -- | Creates a bound variable.
  --
  -- This will be treated as a free constant when appearing inside asserted
  -- expressions.  These are intended to be bound using quantifiers or
  -- symbolic functions.
  freshBoundVar :: sym -> SolverSymbol -> BaseTypeRepr tp -> IO (BoundVar sym tp)

  -- | Return an expression that references the bound variable.
  varExpr :: sym -> BoundVar sym tp -> SymExpr sym tp

  -- | @forallPred sym v e@ returns an expression that represents @forall v . e@.
  -- Throws a user error if bound var has already been used in a quantifier.
  forallPred :: sym
             -> BoundVar sym tp
             -> Pred sym
             -> IO (Pred sym)

  -- | @existsPred sym v e@ returns an expression that represents @exists v . e@.
  -- Throws a user error if bound var has already been used in a quantifier.
  existsPred :: sym
             -> BoundVar sym tp
             -> Pred sym
             -> IO (Pred sym)

  ----------------------------------------------------------------------
  -- SymFn operations.

  -- | Return a function defined by an expression over bound
  -- variables. The 'UnfoldPolicy' argument allows the user to specify when
  -- an application of the function should be unfolded and evaluated,
  -- e.g. to perform constant folding.
  definedFn :: sym
            -- ^ Symbolic interface
            -> SolverSymbol
            -- ^ The name to give a function (need not be unique)
            -> Ctx.Assignment (BoundVar sym) args
            -- ^ Bound variables to use as arguments for function.
            -> SymExpr sym ret
            -- ^ Operation defining result of defined function.
            -> UnfoldPolicy
            -- ^ Policy for unfolding on applications
            -> IO (SymFn sym args ret)

  -- | Return a function defined by Haskell computation over symbolic expressions.
  --
  -- The default implementation creates one fresh bound variable (named with
  -- 'emptySymbol') per argument type, applies the Haskell function to the
  -- corresponding variable expressions, and passes the resulting body to
  -- 'definedFn'.
  inlineDefineFun :: Ctx.CurryAssignmentClass args
                  => sym
                     -- ^ Symbolic interface
                  -> SolverSymbol
                  -- ^ The name to give a function (need not be unique)
                  -> Ctx.Assignment BaseTypeRepr args
                  -- ^ Type signature for the arguments
                  -> UnfoldPolicy
                  -- ^ Policy for unfolding on applications
                  -> Ctx.CurryAssignment args (SymExpr sym) (IO (SymExpr sym ret))
                  -- ^ Operation defining result of defined function.
                  -> IO (SymFn sym args ret)
  inlineDefineFun sym nm tps policy f = do
    -- Create bound variables for function
    vars <- traverseFC (freshBoundVar sym emptySymbol) tps
    -- Call operation on expressions created from variables
    r <- Ctx.uncurryAssignment f (fmapFC (varExpr sym) vars)
    -- Define function
    definedFn sym nm vars r policy

  -- | Create a new uninterpreted function.
  freshTotalUninterpFn :: forall args ret
                        .  sym
                          -- ^ Symbolic interface
                       -> SolverSymbol
                          -- ^ The name to give a function (need not be unique)
                       -> Ctx.Assignment BaseTypeRepr args
                          -- ^ Types of arguments expected by function
                       -> BaseTypeRepr ret
                           -- ^ Return type of function
                       -> IO (SymFn sym args ret)

  -- | Apply a set of arguments to a symbolic function.
  applySymFn :: sym
                -- ^ Symbolic interface
             -> SymFn sym args ret
                -- ^ Function to call
             -> Ctx.Assignment (SymExpr sym) args
                -- ^ Arguments to function
             -> IO (SymExpr sym ret)
2727
-- | Test whether an expression denotes a concrete value.
--
--   Scalars are concrete when the matching @as*@ recognizer succeeds.
--   Floating-point values are never reported as concrete.  Structs are
--   concrete when every field is, and arrays only when they are recognized
--   by 'asConstantArray' with a concrete element value.
baseIsConcrete :: forall e bt
                . IsExpr e
               => e bt
               -> Bool
baseIsConcrete x =
  case exprType x of
    BaseBoolRepr      -> isJust (asConstantPred x)
    BaseIntegerRepr   -> isJust (asInteger x)
    BaseRealRepr      -> isJust (asRational x)
    BaseComplexRepr   -> isJust (asComplex x)
    BaseBVRepr _      -> isJust (asBV x)
    BaseStringRepr _  -> isJust (asString x)
    BaseFloatRepr _   -> False
    BaseStructRepr _  -> maybe False (allFC baseIsConcrete) (asStruct x)
    BaseArrayRepr _ _ -> maybe False baseIsConcrete (asConstantArray x)
2749
-- | Produce a default value of the given base type.
--
--   Numeric types default to zero (positive zero for floats), booleans to
--   false, and strings to the empty string.  A struct gets the default of
--   each of its fields, and an array is the constant array holding the
--   default of its element type.
baseDefaultValue :: forall sym bt
                  . IsExprBuilder sym
                 => sym
                 -> BaseTypeRepr bt
                 -> IO (SymExpr sym bt)
baseDefaultValue sym bt =
  case bt of
    BaseBoolRepr             -> pure $! falsePred sym
    BaseIntegerRepr          -> intLit sym 0
    BaseRealRepr             -> pure $! realZero sym
    BaseComplexRepr          -> mkComplexLit sym (0 :+ 0)
    BaseBVRepr w             -> bvLit sym w (BV.zero w)
    BaseFloatRepr fpp        -> floatPZero sym fpp
    BaseStringRepr si        -> stringEmpty sym si
    BaseStructRepr flds      ->
      traverseFC (baseDefaultValue sym) flds >>= mkStruct sym
    BaseArrayRepr idx elemTy ->
      baseDefaultValue sym elemTy >>= constantArray sym idx
2776
-- | Lift a Haskell 'Bool' to the corresponding constant predicate
--   ('truePred' or 'falsePred').
backendPred :: IsExprBuilder sym => sym -> Bool -> Pred sym
backendPred sym b = if b then truePred sym else falsePred sym
2781
-- | Create a complex literal whose real part is the given rational and
--   whose imaginary part is zero.
mkRational :: IsExprBuilder sym => sym -> Rational -> IO (SymCplx sym)
mkRational sym = mkComplexLit sym . (:+ 0)
2785
-- | Create a complex literal from any 'Real' number (for example an
--   'Integer' or a 'Rational'), with zero imaginary part.
mkReal  :: (IsExprBuilder sym, Real a) => sym -> a -> IO (SymCplx sym)
mkReal sym = mkRational sym . toRational
2789
-- | Convert a predicate to a real number: @1@ when it holds and @0@
--   otherwise, built with 'realIte'.
predToReal :: IsExprBuilder sym => sym -> Pred sym -> IO (SymReal sym)
predToReal sym p =
  realLit sym 1 >>= \one -> realIte sym p one (realZero sym)
2795
-- | Read off the value of a real expression, calling 'fail' in the
--   ambient monad when the expression is not a constant.
realExprAsRational :: (MonadFail m, IsExpr e) => e BaseRealType -> m Rational
realExprAsRational x =
  maybe (fail "Value is not a constant expression.") return (asRational x)
2803
-- | Read off the value of a complex expression that is expected to be a
--   constant real number.  Calls 'fail' when the expression is not a
--   constant, or when its imaginary part is nonzero.
cplxExprAsRational :: (MonadFail m, IsExpr e) => e BaseComplexType -> m Rational
cplxExprAsRational x =
  case asComplex x of
    Nothing ->
      fail "Complex value is not a constant expression."
    Just (re :+ im)
      | im /= 0   -> fail "Complex value has an imaginary part."
      | otherwise -> return re
2816
-- | Read off a complex expression as a constant integer.  Calls 'fail'
--   when it is not a constant, has a nonzero imaginary part, or its real
--   part is not integral.
cplxExprAsInteger :: (MonadFail m, IsExpr e) => e BaseComplexType -> m Integer
cplxExprAsInteger x = cplxExprAsRational x >>= rationalAsInteger
2820
-- | Convert a rational to an integer when its denominator is 1; otherwise
--   call 'fail' in the ambient monad.
rationalAsInteger :: MonadFail m => Rational -> m Integer
rationalAsInteger r
  | denominator r == 1 = return (numerator r)
  | otherwise          = fail "Value is not an integer."
2827
-- | Read off a real expression as a constant integer.  Calls 'fail' when
--   it is not a constant or its value is not integral.
realExprAsInteger :: (IsExpr e, MonadFail m) => e BaseRealType -> m Integer
realExprAsInteger x = realExprAsRational x >>= rationalAsInteger
2832
-- | Conjoin every predicate reached by a 'Fold', combining left to right
--   with 'andPred' and starting from 'truePred'.
andAllOf :: IsExprBuilder sym
         => sym
         -> Fold s (Pred sym)
         -> s
         -> IO (Pred sym)
andAllOf sym fld = foldlMOf fld (andPred sym) (truePred sym)
2840
-- | Disjoin every predicate reached by a 'Fold', combining left to right
--   with 'orPred' and starting from 'falsePred'.
orOneOf :: IsExprBuilder sym
        => sym
        -> Fold s (Pred sym)
        -> s
        -> IO (Pred sym)
orOneOf sym fld = foldlMOf fld (orPred sym) (falsePred sym)
2848
-- | Build a predicate that holds exactly when the complex value differs
--   from zero.
isNonZero :: IsExprBuilder sym => sym -> SymCplx sym -> IO (Pred sym)
isNonZero sym v = do
  zero <- mkRational sym 0
  cplxNe sym v zero
2852
-- | Build a predicate that holds exactly when the imaginary part of the
--   complex value is zero.
isReal :: IsExprBuilder sym => sym -> SymCplx sym -> IO (Pred sym)
isReal sym v =
  getImagPart sym v >>= \im -> realEq sym im (realZero sym)
2858
-- | Divide one complex number by another.
--
--   When the divisor is concretely known to be purely real (imaginary part
--   @0@) or purely imaginary (real part @0@), a cheaper specialised formula
--   is used.  Otherwise the quotient is computed as
--   @(x * conjugate y) / (yr^2 + yi^2)@.
--
--   @cplxDiv x y@ is undefined when @y@ is @0@.
cplxDiv :: IsExprBuilder sym
        => sym
        -> SymCplx sym
        -> SymCplx sym
        -> IO (SymCplx sym)
cplxDiv sym x y = do
  xr :+ xi <- cplxGetParts sym x
  yc@(yr :+ yi) <- cplxGetParts sym y
  case asRational <$> yc of
    -- Purely real divisor:
    --   (xr + xi*i) / yr = xr/yr + (xi/yr)*i
    (_ :+ Just 0) -> do
      zc <- (:+) <$> realDiv sym xr yr <*> realDiv sym xi yr
      mkComplex sym zc
    -- Purely imaginary divisor:
    --   (xr + xi*i) / (yi*i) = xi/yi - (xr/yi)*i
    -- The imaginary part is negated because 1/i = -i.
    (Just 0 :+ _) -> do
      zc <- (:+) <$> realDiv sym xi yi
                 <*> (realNeg sym =<< realDiv sym xr yi)
      mkComplex sym zc
    _ -> do
      -- Squared magnitude of the divisor: yr^2 + yi^2.
      yr_sq <- realMul sym yr yr
      yi_sq <- realMul sym yi yi
      y_norm_sq <- realAdd sym yr_sq yi_sq

      -- Real part of x * conjugate y: xr*yr + xi*yi.
      zr_1 <- realMul sym xr yr
      zr_2 <- realMul sym xi yi
      zr <- realAdd sym zr_1 zr_2

      -- Imaginary part of x * conjugate y: xi*yr - xr*yi.
      zi_1 <- realMul sym xi yr
      zi_2 <- realMul sym xr yi
      zi <- realSub sym zi_1 zi_2

      zc <- (:+) <$> realDiv sym zr y_norm_sq <*> realDiv sym zi y_norm_sq
      mkComplex sym zc
2892
-- | Helper computing the principal logarithm of a complex value as a pair
--   of real expressions: the real part is the log of the magnitude
--   ('realHypot' of the parts) and the imaginary part is the angle
--   ('realAtan2' of imaginary over real part).
cplxLog' :: IsExprBuilder sym
         => sym -> SymCplx sym -> IO (Complex (SymReal sym))
cplxLog' sym x = do
  re :+ im <- cplxGetParts sym x
  magnitude <- realHypot sym re im
  angle <- realAtan2 sym im re
  logMagnitude <- realLog sym magnitude
  pure $! logMagnitude :+ angle
2905
-- | The principal logarithm of a complex value.
--
--   @cplxLog x@ is undefined when @x@ is @0@, and has a branch cut
--   discontinuity along the negative real line.
cplxLog :: IsExprBuilder sym
        => sym -> SymCplx sym -> IO (SymCplx sym)
cplxLog sym x = cplxLog' sym x >>= mkComplex sym
2913
-- | Logarithm of a complex value in the given (concrete) base, computed
--   as the principal logarithm divided by the natural log of the base.
--
--   @cplxLogBase b x@ is undefined when @x@ is @0@, and also when @b@ is
--   @1@ (the divisor @log b@ is then zero).
cplxLogBase :: IsExprBuilder sym
            => Rational {- ^ Base for the logarithm -}
            -> sym
            -> SymCplx sym
            -> IO (SymCplx sym)
cplxLogBase base sym x = do
  lnBase <- realLit sym base >>= realLog sym
  lnX <- cplxLog' sym x
  scaled <- traverse (\part -> realDiv sym part lnBase) lnX
  mkComplex sym scaled
2926
2927--------------------------------------------------------------------------
2928-- Relationship to concrete values
2929
-- | Convert an expression to a 'ConcreteVal' when it is concrete, and
--   return 'Nothing' otherwise.
--
--   Floating-point values are never converted.  Structs are converted
--   field by field.  Arrays are converted only when recognized by
--   'asConstantArray'; the result has no explicit updates.
asConcrete :: IsExpr e => e tp -> Maybe (ConcreteVal tp)
asConcrete x =
  case exprType x of
    BaseBoolRepr     -> fmap ConcreteBool (asConstantPred x)
    BaseIntegerRepr  -> fmap ConcreteInteger (asInteger x)
    BaseRealRepr     -> fmap ConcreteReal (asRational x)
    BaseStringRepr _ -> fmap ConcreteString (asString x)
    BaseComplexRepr  -> fmap ConcreteComplex (asComplex x)
    BaseBVRepr w     -> fmap (ConcreteBV w) (asBV x)
    BaseFloatRepr _  -> Nothing
    BaseStructRepr _ -> do
      fields <- asStruct x
      concreteFields <- traverseFC asConcrete fields
      pure (ConcreteStruct concreteFields)
    BaseArrayRepr idx _ -> do
      defaultElem <- asConstantArray x >>= asConcrete
      -- TODO: what about cases where there are updates to the array?
      -- Passing Map.empty is probably wrong.
      pure (ConcreteArray idx defaultElem Map.empty)
2949
-- | Build the symbolic literal corresponding to a 'ConcreteVal'.
--
--   Structs are built field by field.  Arrays start as the constant array
--   of their default element, and each explicit entry is then written with
--   'arrayUpdate', in descending key order.
concreteToSym :: IsExprBuilder sym => sym -> ConcreteVal tp -> IO (SymExpr sym tp)
concreteToSym sym val =
  case val of
    ConcreteBool b      -> return (if b then truePred sym else falsePred sym)
    ConcreteInteger n   -> intLit sym n
    ConcreteReal r      -> realLit sym r
    ConcreteString s    -> stringLit sym s
    ConcreteComplex c   -> mkComplexLit sym c
    ConcreteBV w bv     -> bvLit sym w bv
    ConcreteStruct flds -> traverseFC (concreteToSym sym) flds >>= mkStruct sym
    ConcreteArray idxTy def entries -> do
      defExpr <- concreteToSym sym def
      base <- constantArray sym idxTy defExpr
      applyUpdates base (Map.toDescList entries)
      where
        -- Write each (index, value) entry into the array in list order.
        applyUpdates arr [] = return arr
        applyUpdates arr ((i, v) : rest) = do
          i' <- traverseFC (concreteToSym sym) i
          v' <- concreteToSym sym v
          arr' <- arrayUpdate sym arr i' v'
          applyUpdates arr' rest
2969
2970------------------------------------------------------------------------
-- muxRange
2972
{-# INLINABLE muxRange #-}
{- | Select a value from a range of candidates using symbolic predicates.

@muxRange p ite f l h@ denotes @f i@ for the smallest @i@ in @[l..h-1]@
such that @p i@ holds, and @f h@ if there is no such @i@ (@p h@ itself is
never consulted).  Predicates that are concretely true or false are
resolved immediately; otherwise the candidates are combined with @ite@. -}
muxRange :: (IsExpr e, Monad m)
         => (Natural -> m (e BaseBoolType))
            -- ^ Returns predicate that holds if we have found the value we are
            --   looking for.  It is assumed that the predicate must hold for a
            --   unique integer in the range.
         -> (e BaseBoolType -> a -> a -> m a)
            -- ^ Ite function
         -> (Natural -> m a)
            -- ^ Function for concrete values
         -> Natural
            -- ^ Lower bound (inclusive)
         -> Natural
            -- ^ Upper bound (inclusive)
         -> m a
muxRange predFn iteFn f l h = go l
  where
    go i
      | i < h = do
          c <- predFn i
          case asConstantPred c of
            Just True  -> f i
            Just False -> go (succ i)
            Nothing    -> do
              hit  <- f i
              miss <- go (succ i)
              iteFn c hit miss
      | otherwise = f h
3003
-- | This provides an interface for converting between Haskell values and a
-- solver representation.
--
-- A @SymEncoder sym v tp@ relates Haskell values of type @v@ to symbolic
-- expressions of base type @tp@.
data SymEncoder sym v tp
   = SymEncoder { symEncoderType :: !(BaseTypeRepr tp)
                  -- ^ Base type of the symbolic representation.
                , symFromExpr :: !(sym -> SymExpr sym tp -> IO v)
                  -- ^ Convert a symbolic expression into a Haskell value.
                , symToExpr   :: !(sym -> v -> IO (SymExpr sym tp))
                  -- ^ Build a symbolic expression from a Haskell value.
                }
3011
3012----------------------------------------------------------------------
3013-- Statistics
3014
-- | Statistics gathered on a running expression builder.  See
-- 'getStatistics'.  'zeroStatistics' is the value with every counter at 0.
data Statistics
  = Statistics { statAllocs :: !Integer
                 -- ^ The number of times an expression node has been
                 -- allocated.
               , statNonLinearOps :: !Integer
                 -- ^ The number of non-linear operations, such as
                 -- multiplications, that have occurred.
               }
  deriving ( Show )
3026
-- | A 'Statistics' value in which every counter is zero.
zeroStatistics :: Statistics
zeroStatistics =
  Statistics
    { statAllocs       = 0
    , statNonLinearOps = 0
    }
3030