1-----------------------------------------------------------------------------
2-- |
3-- Module    : Data.SBV.Compilers.C
4-- Copyright : (c) Levent Erkok
5-- License   : BSD3
6-- Maintainer: erkokl@gmail.com
7-- Stability : experimental
8--
9-- Compilation of symbolic programs to C
10-----------------------------------------------------------------------------
11
12{-# LANGUAGE CPP           #-}
13{-# LANGUAGE PatternGuards #-}
14
15{-# OPTIONS_GHC -Wall -Werror #-}
16
17module Data.SBV.Compilers.C(compileToC, compileToCLib, compileToC', compileToCLib') where
18
19import Control.DeepSeq                (rnf)
20import Data.Char                      (isSpace)
21import Data.List                      (nub, intercalate, intersperse)
22import Data.Maybe                     (isJust, isNothing, fromJust)
23import qualified Data.Foldable as F   (toList)
24import qualified Data.Set      as Set (member, union, unions, empty, toList, singleton, fromList)
25import System.FilePath                (takeBaseName, replaceExtension)
26import System.Random
27
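-- Work around the fact that, as of GHC 8.4.1, the Prelude exports (<>), which clashes with the
-- pretty printer's own (<>): we always refer to the pretty printer's version qualified, as P.<>.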
29import Text.PrettyPrint.HughesPJ
30import qualified Text.PrettyPrint.HughesPJ as P ((<>))
31
32import Data.SBV.Core.Data
33import Data.SBV.Compilers.CodeGen
34
35import Data.SBV.Utils.PrettyNum   (chex, showCFloat, showCDouble)
36
37import GHC.Stack
38
39---------------------------------------------------------------------------
40-- * API
41---------------------------------------------------------------------------
42
43-- | Given a symbolic computation, render it as an equivalent collection of files
44-- that make up a C program:
45--
46--   * The first argument is the directory name under which the files will be saved. To save
47--     files in the current directory pass @'Just' \".\"@. Use 'Nothing' for printing to stdout.
48--
49--   * The second argument is the name of the C function to generate.
50--
51--   * The final argument is the function to be compiled.
52--
53-- Compilation will also generate a @Makefile@,  a header file, and a driver (test) program, etc. As a
54-- result, we return whatever the code-gen function returns. Most uses should simply have @()@ as
55-- the return type here, but the value can be useful if you want to chain the result of
56-- one compilation act to the next.
compileToC :: Maybe FilePath -> String -> SBVCodeGen a -> IO a
compileToC mbDirName nm f =
  -- Generate the bundle, write it out (or print it), and hand back the user's result.
  compileToC' nm f >>= \(retVal, cfg, bundle) ->
    renderCgPgmBundle mbDirName (cfg, bundle) >> return retVal
61
62-- | Lower level version of 'compileToC', producing a 'CgPgmBundle'
compileToC' :: String -> SBVCodeGen a -> IO (a, CgConfig, CgPgmBundle)
compileToC' nm f = do
  -- Fresh random values feed the example driver program's inputs.
  rands <- fmap randoms newStdGen
  let cfg = defaultCgConfig { cgDriverVals = rands }
  codeGen SBVToC cfg nm f
66
67-- | Create code to generate a library archive (.a) from given symbolic functions. Useful when generating code
68-- from multiple functions that work together as a library.
69--
70--   * The first argument is the directory name under which the files will be saved. To save
71--     files in the current directory pass @'Just' \".\"@. Use 'Nothing' for printing to stdout.
72--
73--   * The second argument is the name of the archive to generate.
74--
75--   * The third argument is the list of functions to include, in the form of function-name/code pairs, similar
76--     to the second and third arguments of 'compileToC', except in a list.
compileToCLib :: Maybe FilePath -> String -> [(String, SBVCodeGen a)] -> IO [a]
compileToCLib mbDirName libName comps = do
  -- Build the merged library bundle, render it, then return each function's result.
  (retVal, cfg, pgm) <- compileToCLib' libName comps
  renderCgPgmBundle mbDirName (cfg, pgm)
  pure retVal
81
82-- | Lower level version of 'compileToCLib', producing a 'CgPgmBundle'
compileToCLib' :: String -> [(String, SBVCodeGen a)] -> IO ([a], CgConfig, CgPgmBundle)
compileToCLib' libName comps = do
  -- Compile every function separately, then fold the per-function
  -- configurations and bundles into a single library bundle.
  (rets, cfgs, bundles) <- fmap unzip3 (mapM (uncurry compileToC') comps)
  let (finalCfg, finalPgm) = mergeToLib libName (zip cfgs bundles)
  return (rets, finalCfg, finalPgm)
87
88---------------------------------------------------------------------------
89-- * Implementation
90---------------------------------------------------------------------------
91
-- | Marker value selecting C as the code-generation target.
data SBVToC = SBVToC

-- | The C back end: reports its name as "C" and translates via 'cgen'.
instance CgTarget SBVToC where
  translate  _ = cgen
  targetName _ = "C"
98
-- | Abort on unexpected input, or on things we will probably never support.
-- The message is prefixed with @SBV->C: Unexpected: @.
die :: String -> a
die = error . ("SBV->C: Unexpected: " ++)
102
-- | Abort on features that are not supported (yet). The message is prefixed
-- with @SBV->C: Not yet supported: @.
tbd :: String -> a
tbd = error . ("SBV->C: Not yet supported: " ++)
106
-- | The C back end proper: render the code-generation state and the symbolic
-- program as a bundle of named files, namely a @Makefile@, a header @nm.h@,
-- an example driver @nm_driver.c@ and the implementation @nm.c@. The Makefile
-- and the driver are kept only when 'cgGenMakefile' and 'cgGenDriver',
-- respectively, are set; the header and the source are always produced.
cgen :: CgConfig -> String -> CgState -> Result -> CgPgmBundle
cgen cfg nm st sbvProg
   -- We fully evaluate (rnf) the rendered signature and program body before returning the
   -- bundle, so that any exception raised during type conversion pops out early enough.
   -- This is purely cosmetic, of course.
   = rnf (render sig) `seq` rnf (render (vcat body)) `seq` result
  where result = CgPgmBundle bundleKind
                        $ filt [ ("Makefile",  (CgMakefile flags, [genMake (cgGenDriver cfg) nm nmd flags]))
                               , (nm  ++ ".h", (CgHeader [sig],   [genHeader bundleKind nm [sig] extProtos]))
                               , (nmd ++ ".c", (CgDriver,         genDriver cfg randVals nm ins outs mbRet))
                               , (nm  ++ ".c", (CgSource,         body))
                               ]

        -- Contents of the implementation file, plus the linker flags the generated code needs.
        (body, flagsNeeded) = genCProg cfg nm sig sbvProg ins outs mbRet extDecls

        -- The user-chosen C mappings (if any) for SInteger and SReal.
        bundleKind = (cgInteger cfg, cgReal cfg)

        -- Values used to populate the inputs of the example driver.
        randVals = cgDriverVals cfg

        -- Drop the driver and/or the Makefile unless the configuration asks for them;
        -- every other entry is kept.
        filt xs  = [c | c@(_, (k, _)) <- xs, need k]
          where need k | isCgDriver   k = cgGenDriver cfg
                       | isCgMakefile k = cgGenMakefile cfg
                       | True           = True

        nmd      = nm ++ "_driver"
        sig      = pprCFunHeader nm ins outs mbRet
        ins      = cgInputs st
        outs     = cgOutputs st
        -- At most one return value is supported, and it must be atomic.
        mbRet    = case cgReturns st of
                     []           -> Nothing
                     [CgAtomic o] -> Just o
                     [CgArray _]  -> tbd "Non-atomic return values"
                     _            -> tbd "Multiple return values"
        -- Optional user-supplied prototypes (for the header) and declarations (for the
        -- source), each introduced by a banner comment.
        extProtos = case cgPrototypes st of
                     [] -> empty
                     xs -> vcat $ text "/* User given prototypes: */" : map text xs
        extDecls  = case cgDecls st of
                     [] -> empty
                     xs -> vcat $ text "/* User given declarations: */" : map text xs
        -- Linker flags: those required by the generated code, followed by user-specified ones.
        flags    = flagsNeeded ++ cgLDFlags st
146
-- | Pretty-print the prototype of the generated C function. If @mbRet@ holds a
-- return value, the function returns a value of that value's C type;
-- otherwise it is declared @void@. Inputs become const parameters (see
-- 'mkParam'), and outputs become pointer parameters through which results are
-- written (see 'mkPParam').
pprCFunHeader :: String -> [(String, CgVal)] -> [(String, CgVal)] -> Maybe SV -> Doc
pprCFunHeader fn ins outs mbRet = retType <+> text fn P.<> parens args
  where retType = maybe (text "void") (pprCWord False) mbRet
        args    = fsep (punctuate comma (map mkParam ins ++ map mkPParam outs))
155
-- | Render one formal parameter of the generated C function.
--
--   * 'mkParam' renders an input. Atomic inputs are passed by const value;
--     array inputs are passed as a pointer to const elements.
--
--   * 'mkPParam' renders an output. Both atomic and array outputs are passed
--     as non-const pointers, so the function can write its results through them.
--
-- An array's C type is taken from its first element, so empty arrays are rejected.
mkParam, mkPParam :: (String, CgVal) -> Doc
mkParam  (n, CgAtomic sv)     = pprCWord True  sv <+> text n
mkParam  (_, CgArray  [])     = die "mkParam: CgArray with no elements!"
mkParam  (n, CgArray  (sv:_)) = pprCWord True  sv <+> text "*" P.<> text n
mkPParam (n, CgAtomic sv)     = pprCWord False sv <+> text "*" P.<> text n
mkPParam (_, CgArray  [])     = die "mkPParam: CgArray with no elements!"
mkPParam (n, CgArray  (sv:_)) = pprCWord False sv <+> text "*" P.<> text n
163
-- | Render a const local declaration such as @const SWord8 s0@. The first
-- argument is the width the type name is right-padded to with spaces, so
-- that consecutive declarations line up. A longer type name is not truncated.
declSV :: Int -> SV -> Doc
declSV w sv = hsep [text "const", text padded, text (show sv)]
  where s      = showCType sv
        padded = take (max w (length s)) (s ++ repeat ' ')
168
-- | Like 'declSV', but without the @const@ qualifier. Returns the padded type
-- name and the variable name as a pair, rather than a single declaration. The
-- type is preceded by five spaces, the width of @const@, so it stays aligned
-- with const declarations.
declSVNoConst :: Int -> SV -> (Doc, Doc)
declSVNoConst w sv = (typeDoc, text (show sv))
  where s       = showCType sv
        typeDoc = text "     " <+> text (take (max w (length s)) (s ++ repeat ' '))
173
-- | Render a reference to a variable. The distinguished false/true SVs become
-- the C literals @false@/@true@. SVs listed in the constant table are printed
-- as literals via 'mkConst'. Anything else is printed by name, e.g. @s0@.
showSV :: CgConfig -> [(SV, CV)] -> SV -> Doc
showSV cfg consts sv
  | sv == falseSV = text "false"
  | sv == trueSV  = text "true"
  | True          = maybe (text (show sv)) (mkConst cfg) (lookup sv consts)
181
-- | The C type of a value (see 'showCType'), prefixed with @const@ when the
-- flag is set.
pprCWord :: HasKind a => Bool -> a -> Doc
pprCWord cnst v
  | cnst = text "const" <+> t
  | True = t
  where t = text (showCType v)
185
-- | The C type name for a value's kind. This is the kind's 'show', except that
-- one-bit unsigned words (used when extracting single bits) are mapped to
-- @SBool@.
showCType :: HasKind a => a -> String
showCType i
  | KBounded False 1 <- k = "SBool"
  | True                  = show k
  where k = kindOf i
192
-- | The @printf@ conversion used by the generated driver to print a value of
-- the given variable's kind.
--
-- * Unsigned 16/32/64-bit values print as zero-padded hex with a C literal
--   suffix.
-- * Signed values print in decimal, with a literal suffix for 32/64 bits.
-- * Unsigned 8-bit values print in hex only when 'cgShowU8InHex' is set.
-- * SInteger and SReal use their user-chosen mappings.
-- * Kinds the C back end cannot print are rejected via 'die'.
specifier :: CgConfig -> SV -> Doc
specifier cfg sv = case kindOf sv of
    KBool         -> spec False 1
    KBounded b i  -> spec b i
    KUnbounded    -> spec True (fromJust (cgInteger cfg))
    KReal         -> specF (fromJust (cgReal cfg))
    KFloat        -> specF CgFloat
    KDouble       -> specF CgDouble
    KString       -> text "%s"
    KChar         -> text "%c"
    KFP{}         -> die   "arbitrary float sort"
    KList k       -> die $ "list sort: " ++ show k
    KSet  k       -> die $ "set sort: " ++ show k
    KUserSort s _ -> die $ "user sort: " ++ s
    KTuple k      -> die $ "tuple sort: " ++ show k
    KMaybe  k     -> die $ "maybe sort: "  ++ show k
    KEither k1 k2 -> die $ "either sort: " ++ show (k1, k2)
  where
    -- Integral kinds, keyed by (signedness, bit width).
    spec :: Bool -> Int -> Doc
    spec s sz = case (s, sz) of
      (False,  1) -> text "%d"
      (False,  8) | cgShowU8InHex cfg -> text "0x%02\"PRIx8\""
                  | True              -> text "%\"PRIu8\""
      (True,   8) -> text "%\"PRId8\""
      (False, 16) -> text "0x%04\"PRIx16\"U"
      (True,  16) -> text "%\"PRId16\""
      (False, 32) -> text "0x%08\"PRIx32\"UL"
      (True,  32) -> text "%\"PRId32\"L"
      (False, 64) -> text "0x%016\"PRIx64\"ULL"
      (True,  64) -> text "%\"PRId64\"LL"
      _           -> die $ "Format specifier at type " ++ (if s then "SInt" else "SWord") ++ show sz

    -- Floating kinds: hex-float for float/double, %Lf for long double.
    specF :: CgSRealType -> Doc
    specF CgLongDouble = text "%Lf"
    specF _            = text "%a"
231
-- | Render a concrete value as a C constant. No bounds checking is done here,
--   as it should not be needed.
--
--   * Reals are printed as a 'Double' literal, with a suffix matching the
--     chosen 'CgSRealType': @F@ for float, @L@ for long double, none for double.
--   * Integers, including mapped SIntegers, and booleans go through
--     'showSizedConst'. That yields @true@/@false@ for booleans and decimal for
--     8-bit values (unsigned ones in hex when 'cgShowU8InHex' is set). Wider
--     integers come out in hex.
--   * Any other value is rejected via 'die'.
mkConst :: CgConfig -> CV -> Doc
mkConst cfg cv = case cv of
    CV KReal (CAlgReal (AlgRational _ r)) -> double (fromRational r :: Double) P.<> sRealSuffix (fromJust (cgReal cfg))
    CV KUnbounded       (CInteger i)      -> showSizedConst u8h i (True, fromJust (cgInteger cfg))
    CV (KBounded sg sz) (CInteger i)      -> showSizedConst u8h i (sg, sz)
    CV KBool            (CInteger i)      -> showSizedConst u8h i (False, 1)
    CV KFloat           (CFloat f)        -> text (showCFloat f)
    CV KDouble          (CDouble d)       -> text (showCDouble d)
    CV KString          (CString s)       -> text (show s)
    CV KChar            (CChar c)         -> text (show c)
    _                                     -> die $ "mkConst: " ++ show cv
  where u8h = cgShowU8InHex cfg

        sRealSuffix CgFloat      = text "F"
        sRealSuffix CgDouble     = empty
        sRealSuffix CgLongDouble = text "L"
248
-- | Render an integer constant at the given (signedness, bit width).
--
-- * Booleans (unsigned, 1 bit) become @false@/@true@.
-- * 8-bit values print in decimal, except unsigned ones when the first
--   argument asks for hex.
-- * 16/32/64-bit values are printed via 'chex'.
-- * Any other size is rejected via 'die'.
showSizedConst :: Bool -> Integer -> (Bool, Int) -> Doc
showSizedConst u8h i t = case t of
    (False, 1)                      -> text (if i == 0 then "false" else "true")
    (False, 8) | u8h                -> hexLit
               | True               -> integer i
    (True,  8)                      -> integer i
    (True,  1)                      -> die $ "Signed 1-bit value " ++ show i
    (_,    sz) | sz `elem` [16, 32, 64] -> hexLit
    (s,    sz)                      -> die $ "Constant " ++ show i ++ " at type " ++ (if s then "SInt" else "SWord") ++ show sz
  where hexLit = text (chex False True t i)
263
-- | Generate the Makefile for the generated code.
--
--   * @ifdr@ is True if a driver program is also generated. In that case the
--     Makefile gets an @all@ target building the driver executable, rules for
--     the driver object, and a @veryclean@ target that also removes the
--     executable.
--
--   * @fn@ and @dn@ are the base names of the function and driver sources.
--
--   * @ldFlags@ are extra linker flags. When non-empty they are recorded in
--     @LDFLAGS@ and passed when linking the driver.
genMake :: Bool -> String -> String -> [String] -> Doc
genMake ifdr fn dn ldFlags = foldr1 ($$) [l | (True, l) <- lns]
 where ifld = not (null ldFlags)
       ld | ifld = text "${LDFLAGS}"
          | True = empty
       -- Each entry is (emit?, line); only the lines whose flag is True are emitted, in order.
       lns = [ (True, text "# Makefile for" <+> nm P.<> text ". Automatically generated by SBV. Do not edit!")
             , (True, text "")
             , (True, text "# include any user-defined .mk file in the current directory.")
             , (True, text "-include *.mk")
             , (True, text "")
             -- Defaults set with ?= (only applied if not already set).
             , (True, text "CC?=gcc")
             , (True, text "CCFLAGS?=-Wall -O3 -DNDEBUG -fomit-frame-pointer")
             , (ifld, text "LDFLAGS?=" P.<> text (unwords ldFlags))
             , (True, text "")
             , (ifdr, text "all:" <+> nmd)
             , (ifdr, text "")
             -- Object file for the generated function.
             , (True, nmo P.<> text (": " ++ ppSameLine (hsep [nmc, nmh])))
             , (True, text "\t${CC} ${CCFLAGS}" <+> text "-c $< -o $@")
             , (True, text "")
             -- Driver object and executable (only when a driver is generated).
             , (ifdr, nmdo P.<> text ":" <+> nmdc)
             , (ifdr, text "\t${CC} ${CCFLAGS}" <+> text "-c $< -o $@")
             , (ifdr, text "")
             , (ifdr, nmd P.<> text (": " ++ ppSameLine (hsep [nmo, nmdo])))
             , (ifdr, text "\t${CC} ${CCFLAGS}" <+> text "$^ -o $@" <+> ld)
             , (ifdr, text "")
             , (True, text "clean:")
             , (True, text "\trm -f *.o")
             , (True, text "")
             , (ifdr, text "veryclean: clean")
             , (ifdr, text "\trm -f" <+> nmd)
             , (ifdr, text "")
             ]
       nm   = text fn
       nmd  = text dn
       nmh  = nm P.<> text ".h"
       nmc  = nm P.<> text ".c"
       nmo  = nm P.<> text ".o"
       nmdc = nmd P.<> text ".c"
       nmdo = nmd P.<> text ".o"
304
-- | Generate the header file for @fn@. It contains:
--
--   * an include guard and the standard C includes;
--   * typedefs mapping SBV's type names (SBool, SFloat, SDouble, SWordN, SIntN)
--     to C types;
--   * the optional user-requested mappings for SInteger (to @SInt<ik>@) and
--     SReal (to the chosen real type);
--   * the entry-point prototypes in @sigs@;
--   * any user-given prototypes in @protos@.
genHeader :: (Maybe Int, Maybe CgSRealType) -> String -> [Doc] -> Doc -> Doc
genHeader (ik, rk) fn sigs protos =
     text "/* Header file for" <+> nm P.<> text ". Automatically generated by SBV. Do not edit! */"
  $$ text ""
  -- Include guard
  $$ text "#ifndef" <+> tag
  $$ text "#define" <+> tag
  $$ text ""
  $$ text "#include <stdio.h>"
  $$ text "#include <stdlib.h>"
  $$ text "#include <inttypes.h>"
  $$ text "#include <stdint.h>"
  $$ text "#include <stdbool.h>"
  $$ text "#include <string.h>"
  $$ text "#include <math.h>"
  $$ text ""
  -- C spellings of SBV's type names
  $$ text "/* The boolean type */"
  $$ text "typedef bool SBool;"
  $$ text ""
  $$ text "/* The float type */"
  $$ text "typedef float SFloat;"
  $$ text ""
  $$ text "/* The double type */"
  $$ text "typedef double SDouble;"
  $$ text ""
  $$ text "/* Unsigned bit-vectors */"
  $$ text "typedef uint8_t  SWord8;"
  $$ text "typedef uint16_t SWord16;"
  $$ text "typedef uint32_t SWord32;"
  $$ text "typedef uint64_t SWord64;"
  $$ text ""
  $$ text "/* Signed bit-vectors */"
  $$ text "typedef int8_t  SInt8;"
  $$ text "typedef int16_t SInt16;"
  $$ text "typedef int32_t SInt32;"
  $$ text "typedef int64_t SInt64;"
  $$ text ""
  $$ imapping
  $$ rmapping
  $$ text ("/* Entry point prototype" ++ plu ++ ": */")
  $$ vcat (map (P.<> semi) sigs)
  $$ text ""
  $$ protos
  $$ text "#endif /*" <+> tag <+> text "*/"
  $$ text ""
 where nm  = text fn
       -- Include-guard macro name, derived from the function name.
       tag = text "__" P.<> nm P.<> text "__HEADER_INCLUDED__"
       -- Pluralise "prototype" unless there is exactly one signature.
       plu = if length sigs /= 1 then "s" else ""
       -- Only present when the user chose a fixed width for SInteger.
       imapping = case ik of
                    Nothing -> empty
                    Just i  ->    text "/* User requested mapping for SInteger.                                 */"
                               $$ text "/* NB. Loss of precision: Target type is subject to modular arithmetic. */"
                               $$ text ("typedef SInt" ++ show i ++ " SInteger;")
                               $$ text ""
       -- Only present when the user chose a C type for SReal.
       rmapping = case rk of
                    Nothing -> empty
                    Just t  ->    text "/* User requested mapping for SReal.                          */"
                               $$ text "/* NB. Loss of precision: Target type is subject to rounding. */"
                               $$ text ("typedef " ++ show t ++ " SReal;")
                               $$ text ""
365
-- | A blank separator line when the flag is set, and nothing otherwise.
sepIf :: Bool -> Doc
sepIf b
  | b    = text ""
  | True = empty
368
-- | Generate an example driver program: a @main@ that calls the generated
-- function @fn@ on the supplied driver values and prints the results. It is
-- returned as four chunks: the preamble, the include plus the opening of
-- @main@, the body, and the closing @return 0@.
--
-- Driver values are consumed in order from @randVals@: one per atomic input
-- and one per array element. Running out of values is an error.
--
--   * Atomic inputs are passed as literals directly in the call.
--   * Array inputs are declared as initialised const arrays, and their
--     contents are printed before the call.
--   * Outputs are declared as locals and passed by address (atomic) or by
--     name (array). They are printed after the call, together with the
--     return value if there is one.
genDriver :: CgConfig -> [Integer] -> String -> [(String, CgVal)] -> [(String, CgVal)] -> Maybe SV -> [Doc]
genDriver cfg randVals fn inps outs mbRet = [pre, header, body, post]
 where pre    =  text "/* Example driver program for" <+> nm P.<> text ". */"
              $$ text "/* Automatically generated by SBV. Edit as you see fit! */"
              $$ text ""
              $$ text "#include <stdio.h>"
       header =  text "#include" <+> doubleQuotes (nm P.<> text ".h")
              $$ text ""
              $$ text "int main(void)"
              $$ text "{"
       -- Declarations, the call itself, and printing of the results.
       body   =  text ""
              $$ nest 2 (   vcat (map mkInp pairedInputs)
                      $$ vcat (map mkOut outs)
                      $$ sepIf (not (null [() | (_, _, CgArray{}) <- pairedInputs]) || not (null outs))
                      $$ call
                      $$ text ""
                      $$ (case mbRet of
                              Just sv -> text "printf" P.<> parens (printQuotes (fcall <+> text "=" <+> specifier cfg sv P.<> text "\\n")
                                                                              P.<> comma <+> resultVar) P.<> semi
                              Nothing -> text "printf" P.<> parens (printQuotes (fcall <+> text "->\\n")) P.<> semi)
                      $$ vcat (map display outs)
                      )
       post   =   text ""
              $+$ nest 2 (text "return 0" P.<> semi)
              $$  text "}"
              $$  text ""
       nm = text fn
       -- Pair each input with its rendered driver value(s): one for an atomic
       -- input, one per element for an array. Running out of values is an error.
       pairedInputs = matchRands randVals inps
       matchRands _      []                                 = []
       matchRands []     _                                  = die "Run out of driver values!"
       matchRands (r:rs) ((n, CgAtomic sv)            : cs) = ([mkRVal sv r], n, CgAtomic sv) : matchRands rs cs
       matchRands _      ((n, CgArray [])             : _ ) = die $ "Unsupported empty array input " ++ show n
       matchRands rs     ((n, a@(CgArray sws@(sv:_))) : cs)
          | length frs /= l                                 = die "Run out of driver values!"
          | True                                            = (map (mkRVal sv) frs, n, a) : matchRands srs cs
          where l          = length sws
                (frs, srs) = splitAt l rs
       -- A driver value rendered as a C constant of the variable's kind.
       mkRVal sv r = mkConst cfg $ mkConstCV (kindOf sv) r
       -- Array inputs are declared, initialised and printed; atomic inputs need no declaration.
       mkInp (_,  _, CgAtomic{})         = empty  -- constant, no need to declare
       mkInp (_,  n, CgArray [])         = die $ "Unsupported empty array value for " ++ show n
       mkInp (vs, n, CgArray sws@(sv:_)) =  pprCWord True sv <+> text n P.<> brackets (int (length sws)) <+> text "= {"
                                                      $$ nest 4 (fsep (punctuate comma (align vs)))
                                                      $$ text "};"
                                         $$ text ""
                                         $$ text "printf" P.<> parens (printQuotes (text "Contents of input array" <+> text n P.<> text ":\\n")) P.<> semi
                                         $$ display (n, CgArray sws)
                                         $$ text ""
       -- Outputs are declared as (uninitialised) locals the function writes into.
       mkOut (v, CgAtomic sv)            = pprCWord False sv <+> text v P.<> semi
       mkOut (v, CgArray [])             = die $ "Unsupported empty array value for " ++ show v
       mkOut (v, CgArray sws@(sv:_))     = pprCWord False sv <+> text v P.<> brackets (int (length sws)) P.<> semi
       resultVar = text "__result"
       -- The call, binding the return value to __result if there is one.
       call = case mbRet of
                Nothing -> fcall P.<> semi
                Just sv -> pprCWord True sv <+> resultVar <+> text "=" <+> fcall P.<> semi
       fcall = nm P.<> parens (fsep (punctuate comma (map mkCVal pairedInputs ++ map mkOVal outs)))
       -- Actual arguments: literals for atomic inputs, names for arrays, addresses for atomic outputs.
       mkCVal ([v], _, CgAtomic{}) = v
       mkCVal (vs,  n, CgAtomic{}) = die $ "Unexpected driver value computed for " ++ show n ++ render (hcat vs)
       mkCVal (_,   n, CgArray{})  = text n
       mkOVal (n, CgAtomic{})      = text "&" P.<> text n
       mkOVal (n, CgArray{})       = text n
       -- Print a variable: atomic values on one line; arrays element by element
       -- in a counted loop, with the index padded to the width of the largest index.
       display (n, CgAtomic sv)         = text "printf" P.<> parens (printQuotes (text " " <+> text n <+> text "=" <+> specifier cfg sv
                                                                                P.<> text "\\n") P.<> comma <+> text n) P.<> semi
       display (n, CgArray [])         =  die $ "Unsupported empty array value for " ++ show n
       display (n, CgArray sws@(sv:_)) =   text "int" <+> nctr P.<> semi
                                        $$ text "for(" P.<> nctr <+> text "= 0;" <+> nctr <+> text "<" <+> int len <+> text "; ++" P.<> nctr P.<> text ")"
                                        $$ nest 2 (text "printf" P.<> parens (printQuotes (text " " <+> entrySpec <+> text "=" <+> spec P.<> text "\\n")
                                                                 P.<> comma <+> nctr <+> comma P.<> entry) P.<> semi)
                  where nctr      = text n P.<> text "_ctr"
                        entry     = text n P.<> text "[" P.<> nctr P.<> text "]"
                        entrySpec = text n P.<> text "[%" P.<> int tab P.<> text "d]"
                        spec      = specifier cfg sv
                        len       = length sws
                        tab       = length $ show (len - 1)
443
-- | Generate the C implementation file for function @fn@. The result has two parts:
--
--   * the file's contents, as three chunks: the banner, the @#include@ of the
--     header, and the function definition;
--   * the linker flags the generated code needs.
--
-- The program is rejected up front (via 'error' or 'tbd') if it uses anything
-- the C back end cannot handle:
--
--   * SInteger or SReal without a user-chosen C mapping;
--   * strings, characters, sets, lists, tuples, or Maybe/Either values;
--   * uninterpreted sorts other than @RoundingMode@;
--   * explicit constraints or user-specified arrays;
--   * existentially quantified inputs.
genCProg :: CgConfig -> String -> Doc -> Result -> [(String, CgVal)] -> [(String, CgVal)] -> Maybe SV -> Doc -> ([Doc], [String])
genCProg cfg fn proto (Result kindInfo _tvals _ovals cgs ins (_, preConsts) tbls arrs _uis _axioms (SBVPgm asgns) cstrs origAsserts _) inVars outVars mbRet extDecls
  | isNothing (cgInteger cfg) && KUnbounded `Set.member` kindInfo
  = error $ "SBV->C: Unbounded integers are not supported by the C compiler."
          ++ "\nUse 'cgIntegerSize' to specify a fixed size for SInteger representation."
  | KString `Set.member` kindInfo
  = error "SBV->C: Strings are currently not supported by the C compiler. Please get in touch if you'd like support for this feature!"
  | KChar `Set.member` kindInfo
  = error "SBV->C: Characters are currently not supported by the C compiler. Please get in touch if you'd like support for this feature!"
  | any isSet kindInfo
  = error "SBV->C: Sets (SSet) are currently not supported by the C compiler. Please get in touch if you'd like support for this feature!"
  | any isList kindInfo
  = error "SBV->C: Lists (SList) are currently not supported by the C compiler. Please get in touch if you'd like support for this feature!"
  | any isTuple kindInfo
  = error "SBV->C: Tuples (STupleN) are currently not supported by the C compiler. Please get in touch if you'd like support for this feature!"
  | any isMaybe kindInfo
  = error "SBV->C: Optional (SMaybe) values are currently not supported by the C compiler. Please get in touch if you'd like support for this feature!"
  | any isEither kindInfo
  = error "SBV->C: Either (SEither) values are currently not supported by the C compiler. Please get in touch if you'd like support for this feature!"
  | isNothing (cgReal cfg) && KReal `Set.member` kindInfo
  = error $ "SBV->C: SReal values are not supported by the C compiler."
          ++ "\nUse 'cgSRealType' to specify a custom type for SReal representation."
  | not (null usorts)
  = error $ "SBV->C: Cannot compile functions with uninterpreted sorts: " ++ intercalate ", " usorts
  | not (null cstrs)
  = tbd "Explicit constraints"
  | not (null arrs)
  = tbd "User specified arrays"
  | needsExistentials (map fst (fst ins))
  = error "SBV->C: Cannot compile functions with existentially quantified variables."
  | True
  = ([pre, header, post], flagsNeeded)
 where asserts | cgIgnoreAsserts cfg = []   -- assertions are dropped entirely when cgIgnoreAsserts is set
               | True                = origAsserts

       usorts = [s | KUserSort s _ <- Set.toList kindInfo, s /= "RoundingMode"] -- No support for any sorts other than RoundingMode!

       pre    =  text "/* File:" <+> doubleQuotes (nm P.<> text ".c") P.<> text ". Automatically generated by SBV. Do not edit! */"
              $$ text ""

       header = text "#include" <+> doubleQuotes (nm P.<> text ".h")

       -- The function definition: user code segments, user declarations and the
       -- prototype. Then a body that:
       --   (1) copies live inputs into const locals;
       --   (2) emits tables, assignments and assertions in dependency order;
       --   (3) writes the outputs;
       --   (4) returns the result, if any.
       post   = text ""
             $$ vcat (map codeSeg cgs)
             $$ extDecls
             $$ proto
             $$ text "{"
             $$ text ""
             $$ nest 2 (   vcat (concatMap (genIO True . (\v -> (isAlive v, v))) inVars)
                        $$ vcat (merge (map genTbl tbls) (map genAsgn assignments) (map genAssert asserts))
                        $$ sepIf (not (null assignments) || not (null tbls))
                        $$ vcat (concatMap (genIO False) (zip (repeat True) outVars))
                        $$ maybe empty mkRet mbRet
                       )
             $$ text "}"
             $$ text ""

       nm = text fn

       assignments = F.toList asgns

       -- Do we need any linker flags for C? Collected from each assignment's operator and result kind.
       flagsNeeded = nub $ concatMap (getLDFlag . opRes) assignments
          where opRes (sv, SBVApp o _) = (o, kindOf sv)

       -- A user-specified custom code segment, introduced by a banner comment.
       codeSeg (fnm, ls) =  text "/* User specified custom code for" <+> doubleQuotes (text fnm) <+> text "*/"
                         $$ vcat (map text ls)
                         $$ text ""

       -- Width that type names are padded to in declarations, so they line up: the length of
       -- the longest C type name among the assigned and input variables.
       typeWidth = getMax 0 $ [len (kindOf s) | (s, _) <- assignments] ++ [len (kindOf s) | (_, NamedSymVar s _) <- fst ins]
                where len KReal{}            = 5
                      len KFloat{}           = 6 -- SFloat
                      len KDouble{}          = 7 -- SDouble
                      len KString{}          = 7 -- SString
                      len KChar{}            = 5 -- SChar
                      len KUnbounded{}       = 8
                      len KBool              = 5 -- SBool
                      len (KBounded False n) = 5 + length (show n) -- SWordN
                      len (KBounded True  n) = 4 + length (show n) -- SIntN
                      len KFP{}              = die   "Arbitrary float."
                      len (KList s)          = die $ "List sort: " ++ show s
                      len (KSet  s)          = die $ "Set sort: " ++ show s
                      len (KTuple s)         = die $ "Tuple sort: " ++ show s
                      len (KMaybe k)         = die $ "Maybe sort: " ++ show k
                      len (KEither k1 k2)    = die $ "Either sort: " ++ show (k1, k2)
                      len (KUserSort s _)    = die $ "Uninterpreted sort: " ++ s

                      getMax 8 _      = 8  -- 8 is the max we can get with SInteger, so don't bother looking any further
                      getMax m []     = m
                      getMax m (x:xs) = getMax (m `max` x) xs

       -- Known constants, including the distinguished false/true SVs. References to these are
       -- printed as literals by showSV.
       consts = (falseSV, falseCV) : (trueSV, trueCV) : preConsts

       isConst s = isJust (lookup s consts)

       -- TODO: The following is brittle. We should really have a function elsewhere
       -- that walks the SBVExprs and collects the SWs together.
       usedVariables = Set.unions (retSWs : map usedCgVal outVars ++ map usedAsgn assignments)
         where retSWs = maybe Set.empty Set.singleton mbRet

               usedCgVal (_, CgAtomic s)  = Set.singleton s
               usedCgVal (_, CgArray ss)  = Set.fromList ss
               usedAsgn  (_, SBVApp o ss) = Set.union (opSWs o) (Set.fromList ss)

               opSWs (LkUp _ a b)             = Set.fromList [a, b]
               opSWs (IEEEFP (FP_Cast _ _ s)) = Set.singleton s
               opSWs _                        = Set.empty

       -- An atomic input is copied into a local only if it is used somewhere; array inputs always are.
       isAlive :: (String, CgVal) -> Bool
       isAlive (_, CgAtomic sv) = sv `Set.member` usedVariables
       isAlive (_, _)           = True

       -- Inputs (True) are copied into const locals. Outputs (False) are written through the
       -- output pointer, or element by element into the output array.
       genIO :: Bool -> (Bool, (String, CgVal)) -> [Doc]
       genIO True  (alive, (cNm, CgAtomic sv)) = [declSV typeWidth sv  <+> text "=" <+> text cNm P.<> semi               | alive]
       genIO False (alive, (cNm, CgAtomic sv)) = [text "*" P.<> text cNm <+> text "=" <+> showSV cfg consts sv P.<> semi | alive]
       genIO isInp (_,     (cNm, CgArray sws)) = zipWith genElt sws [(0::Int)..]
         where genElt sv i
                 | isInp = declSV typeWidth sv <+> text "=" <+> text entry       P.<> semi
                 | True  = text entry          <+> text "=" <+> showSV cfg consts sv P.<> semi
                 where entry = cNm ++ "[" ++ show i ++ "]"

       mkRet sv = text "return" <+> showSV cfg consts sv P.<> semi

       -- A lookup table as a C array named "table<i>". Its location (used for ordering) is the
       -- largest node id among its elements, with constants counting as -1. A table made only
       -- of constants therefore gets location -1 and is declared static.
       genTbl :: ((Int, Kind, Kind), [SV]) -> (Int, Doc)
       genTbl ((i, _, k), elts) =  (location, static <+> text "const" <+> text (show k) <+> text ("table" ++ show i) P.<> text "[] = {"
                                              $$ nest 4 (fsep (punctuate comma (align (map (showSV cfg consts) elts))))
                                              $$ text "};")
         where static   = if location == -1 then text "static" else empty
               location = maximum (-1 : map getNodeId elts)

       -- Ordering key of an SV: its node id, or -1 for a known constant.
       getNodeId s@(SV _ (NodeId n)) | isConst s = -1
                                     | True      = n

       genAsgn :: (SV, SBVExpr) -> (Int, Doc)
       genAsgn (sv, n) = (getNodeId sv, ppExpr cfg consts n (declSV typeWidth sv) (declSVNoConst typeWidth sv) P.<> semi)

       -- Merge tables, assignments and assertions into one statement list ordered by key, so
       -- that each table comes right after the last assignment it depends on, and each
       -- assertion right after the assignment defining its condition. On equal keys merge2
       -- puts the element of its second argument first, so the defining assignment precedes
       -- its user. Note that the assignment list (second argument) is sorted on its order.
       -- NOTE(review): merge2 assumes both inputs are in ascending key order; the table and
       -- assertion lists are not visibly sorted here. Confirm that callers supply them in order.
       merge :: [(Int, Doc)] -> [(Int, Doc)] -> [(Int, Doc)] -> [Doc]
       merge tables asgnments asrts = map snd $ merge2 asrts (merge2 tables asgnments)
         where merge2 []               as                  = as
               merge2 ts               []                  = ts
               merge2 ts@((i, t):trest) as@((i', a):arest)
                 | i < i'                                 = (i,  t)  : merge2 trest as
                 | True                                   = (i', a) : merge2 ts arest

       -- An assertion check. It emits a C comment with the message and, when available, the
       -- call-stack source locations, followed by code that prints the message to stderr and
       -- exits with -1 if sv is true.
       -- NOTE(review): presumably sv encodes the *violation* of the asserted property, since the
       -- check fires when it is true. Confirm at the producer.
       -- NOTE(review): msg is spliced verbatim into a C comment and into a C string literal that
       -- is also a printf format. A message containing "*/", a double quote, a backslash or a
       -- percent sign would yield broken or misbehaving C code.
       genAssert (msg, cs, sv) = (getNodeId sv, doc)
         where doc =     text "/* ASSERTION:" <+> text msg
                     $$  maybe empty (vcat . map text) (locInfo (getCallStack `fmap` cs))
                     $$  text " */"
                     $$  text "if" P.<> parens (showSV cfg consts sv)
                     $$  text "{"
                     $+$ nest 2 (vcat [errOut, text "exit(-1);"])
                     $$  text "}"
                     $$  text ""
               errOut = text $ "fprintf(stderr, \"%s:%d:ASSERTION FAILED: " ++ msg ++ "\\n\", __FILE__, __LINE__);"
               locInfo (Just ps) = let loc (f, sl) = concat [srcLocFile sl, ":", show (srcLocStartLine sl), ":", show (srcLocStartCol sl), ":", f ]
                                   in case map loc ps of
                                         []     -> Nothing
                                         (f:rs) -> Just $ (" * SOURCE   : " ++ f) : map (" *            " ++)  rs
               locInfo _         = Nothing
606
-- | Render a pseudo-boolean constraint as a C expression. Each argument is a C condition
-- that contributes its coefficient to a sum when true (and 0 otherwise); that sum is then
-- compared against the constant. The AtMost/AtLeast/Exactly forms weigh every argument by 1.
handlePB :: PBOp -> [Doc] -> Doc
handlePB o args = weightedSum <+> text relation <+> int rhsConst
 where (coeffs, relation, rhsConst) = case o of
          PB_AtMost  k -> (repeat 1, "<=", k)
          PB_AtLeast k -> (repeat 1, ">=", k)
          PB_Exactly k -> (repeat 1, "==", k)
          PB_Le cs   k -> (cs,       "<=", k)
          PB_Ge cs   k -> (cs,       ">=", k)
          PB_Eq cs   k -> (cs,       "==", k)

       -- (c1 ? w1 : 0) + (c2 ? w2 : 0) + ...
       weightedSum :: Doc
       weightedSum = parens $ fsep $ intersperse (text "+") (zipWith term args coeffs)

       term :: Doc -> Int -> Doc
       term a c = parens (a <+> text "?" <+> int c <+> text ":" <+> int 0)
618
-- | Render an IEEE floating-point operation as a C expression.
--
--   * @w@      : the floating-point operation to translate
--   * @consts@ : known constant bindings; rounding-mode arguments are looked up here
--   * @as@     : the operands, each paired with its already-rendered C text. If the first
--                operand is of the @RoundingMode@ sort it is checked and then dropped.
--   * @var@    : the target variable; only used by 'FP_Reinterpret', which renders as a
--                @memcpy@ into it rather than as an expression.
--
-- The only supported rounding mode is a constant @RoundNearestTiesToEven@. Other constant
-- modes and non-constant modes are routed to 'tbd'; ill-formed rounding-mode values are
-- routed to 'die'.
handleIEEE :: FPOp -> [(SV, CV)] -> [(SV, Doc)] -> Doc -> Doc
handleIEEE w consts as var = cvt w
  where -- build the (float-renderer, double-renderer) pair: 'same' uses one renderer for both,
        -- 'named' instantiates it with the float- and double-specific C function names
        same f                   = (f, f)
        named fnm dnm f          = (f fnm, f dnm)

        cvt (FP_Cast from to m)     = case checkRM (m `lookup` consts) of
                                        Nothing          -> cast $ \[a] -> parens (text (show to)) <+> rnd a
                                        Just (Left  msg) -> die msg
                                        Just (Right msg) -> tbd msg
                                      where -- if we're converting from float to some integral like; first use rint/rintf to do the internal conversion and then cast.
                                            rnd a
                                             | (isFloat from || isDouble from) && (isBounded to || isUnbounded to)
                                             = let f = if isFloat from then "rintf" else "rint"
                                               in text f P.<> parens a
                                             | True
                                             = a

        -- bit-level reinterpretation between same-sized float and unsigned word types: memcpy the operand into var
        cvt (FP_Reinterpret f t) = case (f, t) of
                                     (KBounded False 32, KFloat)  -> cast $ cpy "sizeof(SFloat)"
                                     (KBounded False 64, KDouble) -> cast $ cpy "sizeof(SDouble)"
                                     (KFloat,  KBounded False 32) -> cast $ cpy "sizeof(SWord32)"
                                     (KDouble, KBounded False 64) -> cast $ cpy "sizeof(SWord64)"
                                     _                            -> die $ "Reinterpretation from : " ++ show f ++ " to " ++ show t
                                    where cpy sz = \[a] -> let alhs = text "&" P.<> var
                                                               arhs = text "&" P.<> a
                                                           in text "memcpy" P.<> parens (fsep (punctuate comma [alhs, arhs, text sz]))
        cvt FP_Abs               = dispatch $ named "fabsf" "fabs" $ \nm _ [a] -> text nm P.<> parens a
        cvt FP_Neg               = dispatch $ same $ \_ [a] -> text "-" P.<> a
        cvt FP_Add               = dispatch $ same $ \_ [a, b] -> a <+> text "+" <+> b
        cvt FP_Sub               = dispatch $ same $ \_ [a, b] -> a <+> text "-" <+> b
        cvt FP_Mul               = dispatch $ same $ \_ [a, b] -> a <+> text "*" <+> b
        cvt FP_Div               = dispatch $ same $ \_ [a, b] -> a <+> text "/" <+> b
        cvt FP_FMA               = dispatch $ named "fmaf"  "fma"  $ \nm _ [a, b, c] -> text nm P.<> parens (fsep (punctuate comma [a, b, c]))
        cvt FP_Sqrt              = dispatch $ named "sqrtf" "sqrt" $ \nm _ [a]       -> text nm P.<> parens a
        -- SMT-Lib's fp.rem is the IEEE-754 remainder: x - y*n, where n is x/y rounded to the nearest
        -- integer (ties to even). In C that is remainder/remainderf; fmod/fmodf truncate x/y toward
        -- zero instead and produce different results (e.g., fmod(5, 3) is 2, while the remainder is -1).
        cvt FP_Rem               = dispatch $ named "remainderf" "remainder" $ \nm _ [a, b] -> text nm P.<> parens (fsep (punctuate comma [a, b]))
        cvt FP_RoundToIntegral   = dispatch $ named "rintf" "rint" $ \nm _ [a]       -> text nm P.<> parens a
        cvt FP_Min               = dispatch $ named "fminf" "fmin" $ \nm k [a, b]    -> wrapMinMax k a b (text nm P.<> parens (fsep (punctuate comma [a, b])))
        cvt FP_Max               = dispatch $ named "fmaxf" "fmax" $ \nm k [a, b]    -> wrapMinMax k a b (text nm P.<> parens (fsep (punctuate comma [a, b])))
        -- object equality: NaN equals NaN, and -0 is distinguished from +0; otherwise plain ==
        cvt FP_ObjEqual          = let mkIte   x y z = x <+> text "?" <+> y <+> text ":" <+> z
                                       chkNaN  x     = text "isnan"   P.<> parens x
                                       signbit x     = text "signbit" P.<> parens x
                                       eq      x y   = parens (x <+> text "==" <+> y)
                                       eqZero  x     = eq x (text "0")
                                       negZero x     = parens (signbit x <+> text "&&" <+> eqZero x)
                                   in dispatch $ same $ \_ [a, b] -> mkIte (chkNaN a) (chkNaN b) (mkIte (negZero a) (negZero b) (mkIte (negZero b) (negZero a) (eq a b)))
        cvt FP_IsNormal          = dispatch $ same $ \_ [a] -> text "isnormal" P.<> parens a
        cvt FP_IsSubnormal       = dispatch $ same $ \_ [a] -> text "FP_SUBNORMAL == fpclassify" P.<> parens a
        cvt FP_IsZero            = dispatch $ same $ \_ [a] -> text "FP_ZERO == fpclassify" P.<> parens a
        cvt FP_IsInfinite        = dispatch $ same $ \_ [a] -> text "isinf" P.<> parens a
        cvt FP_IsNaN             = dispatch $ same $ \_ [a] -> text "isnan" P.<> parens a
        cvt FP_IsNegative        = dispatch $ same $ \_ [a] -> text "!isnan" P.<> parens a <+> text "&&" <+> text "signbit"  P.<> parens a
        cvt FP_IsPositive        = dispatch $ same $ \_ [a] -> text "!isnan" P.<> parens a <+> text "&&" <+> text "!signbit" P.<> parens a

        -- grab the rounding-mode, if present, and make sure it's RoundNearestTiesToEven. Otherwise skip.
        fpArgs = case as of
                   []            -> []
                   ((m, _):args) -> case kindOf m of
                                      KUserSort "RoundingMode" _ -> case checkRM (m `lookup` consts) of
                                                                      Nothing          -> args
                                                                      Just (Left  msg) -> die msg
                                                                      Just (Right msg) -> tbd msg
                                      _                          -> as

        -- Check that the RM is RoundNearestTiesToEven.
        -- Nothing means OK; Left is an unexpected value (fatal); Right is a valid but unsupported mode.
        -- If we start supporting other rounding-modes, this would be the point where we'd insert the rounding-mode set/reset code
        -- instead of merely returning OK or not
        checkRM (Just cv@(CV (KUserSort "RoundingMode" _) v)) =
              case v of
                CUserSort (_, "RoundNearestTiesToEven") -> Nothing
                CUserSort (_, s)                        -> Just (Right $ "handleIEEE: Unsupported rounding-mode: " ++ show s ++ " for: " ++ show w)
                _                                       -> Just (Left  $ "handleIEEE: Unexpected value for rounding-mode: " ++ show cv ++ " for: " ++ show w)
        checkRM (Just cv) = Just (Left  $ "handleIEEE: Expected rounding-mode, but got: " ++ show cv ++ " for: " ++ show w)
        checkRM Nothing   = Just (Right $ "handleIEEE: Non-constant rounding-mode for: " ++ show w)

        -- choose the float or double renderer based on the kind of the first (non-rounding-mode) operand
        pickOp _          []             = die $ "Cannot determine float/double kind for op: " ++ show w
        pickOp (fOp, dOp) args@((a,_):_) = case kindOf a of
                                             KFloat  -> fOp KFloat  (map snd args)
                                             KDouble -> dOp KDouble (map snd args)
                                             k       -> die $ "handleIEEE: Expected double/float args, but got: " ++ show k ++ " for: " ++ show w

        dispatch (fOp, dOp) = pickOp (fOp, dOp) fpArgs
        cast f              = f (map snd fpArgs)

        -- In SMT-Lib, fpMin/fpMax is underspecified when given +0/-0 as the two arguments. (In any order.)
        -- In C, the second argument is returned. (I think, might depend on the architecture, optimizations etc.).
        -- We'll translate it so that we deterministically return +0.
        -- There's really no good choice here.
        wrapMinMax k a b s = parens cond <+> text "?" <+> zero <+> text ":" <+> s
          where zero = text $ if k == KFloat then showCFloat 0 else showCDouble 0
                cond =                   parens (text "FP_ZERO == fpclassify" P.<> parens a)                                      -- a is zero
                       <+> text "&&" <+> parens (text "FP_ZERO == fpclassify" P.<> parens b)                                      -- b is zero
                       <+> text "&&" <+> parens (text "signbit" P.<> parens a <+> text "!=" <+> text "signbit" P.<> parens b)       -- a and b differ in sign
711
-- | Render a single SBV expression as a C statement.
--
--   * @cfg@        : code-generation configuration; consulted for run-time checks ('cgRTC')
--                    and for the SInteger/SReal mappings ('cgInteger', 'cgReal')
--   * @consts@     : known constant bindings; used to render operands and to spot constant
--                    shift amounts (which makes the output more readable)
--   * the expression: an operator applied to its operand 'SV's
--   * @lhs@        : the target declaration for the usual @lhs = rhs@ form
--   * @(typ, var)@ : the target's type and name, used when the operation is not a simple
--                    assignment ('FP_Reinterpret': a declaration followed by a @memcpy@ into @var@)
--
-- Operations that cannot be translated are routed to 'tbd' or 'die'.
ppExpr :: CgConfig -> [(SV, CV)] -> SBVExpr -> Doc -> (Doc, Doc) -> Doc
ppExpr cfg consts (SBVApp op opArgs) lhs (typ, var)
  | doNotAssign op
  = typ <+> var P.<> semi <+> rhs
  | True
  = lhs <+> text "=" <+> rhs
  where doNotAssign (IEEEFP FP_Reinterpret{}) = True   -- generates a memcpy instead; no simple assignment
        doNotAssign _                         = False  -- generates simple assignment

        -- the operator applied to the rendered operands
        rhs = p op (map (showSV cfg consts) opArgs)

        -- whether run-time checks were requested; consulted for table look-ups below
        rtc = cgRTC cfg

        -- operators that map directly onto a C infix operator
        cBinOps = [ (Plus, "+"),  (Times, "*"), (Minus, "-")
                  , (Equal, "=="), (NotEqual, "!="), (LessThan, "<"), (GreaterThan, ">"), (LessEq, "<="), (GreaterEq, ">=")
                  , (And, "&"), (Or, "|"), (XOr, "^")
                  ]

        -- see if we can find a constant shift; makes the output way more readable
        getShiftAmnt def [_, sv] = case sv `lookup` consts of
                                    Just (CV _  (CInteger i)) -> integer i
                                    _                         -> def
        getShiftAmnt def _       = def

        p :: Op -> [Doc] -> Doc
        p (ArrRead _)       _  = tbd "User specified arrays (ArrRead)"
        p (ArrEq _ _)       _  = tbd "User specified arrays (ArrEq)"
        p (Label s)        [a] = a <+> text "/*" <+> text s <+> text "*/"
        p (IEEEFP w)         as = handleIEEE w  consts (zip opArgs as) var
        p (PseudoBoolean pb) as = handlePB pb as
        p (OverflowOp o) _      = tbd $ "Overflow operations: " ++ show o
        p (KindCast _ to)   [a] = parens (text (show to)) <+> a
        p (Uninterpreted s) [] = text "/* Uninterpreted constant */" <+> text s
        p (Uninterpreted s) as = text "/* Uninterpreted function */" <+> text s P.<> parens (fsep (punctuate comma as))
        p (Extract i j) [a]    = extract i j (head opArgs) a
        p Join [a, b]          = join (let (s1 : s2 : _) = opArgs in (s1, s2, a, b))
        p (Rol i) [a]          = rotate True  i a (head opArgs)
        p (Ror i) [a]          = rotate False i a (head opArgs)
        p Shl     [a, i]       = shift  True  (getShiftAmnt i opArgs) a -- The order of i/a being reversed here is
        p Shr     [a, i]       = shift  False (getShiftAmnt i opArgs) a -- intentional and historical (from the days when Shl/Shr had a constant parameter.)
        p Not [a]              = case kindOf (head opArgs) of
                                   -- be careful about booleans, bitwise complement is not correct for them!
                                   KBool -> text "!" P.<> a
                                   _     -> text "~" P.<> a
        p Ite [a, b, c] = a <+> text "?" <+> b <+> text ":" <+> c
        p (LkUp (t, k, _, len) ind def) []
          | not rtc                    = lkUp -- ignore run-time-checks per user request
          | needsCheckL && needsCheckR = cndLkUp checkBoth
          | needsCheckL                = cndLkUp checkLeft
          | needsCheckR                = cndLkUp checkRight
          | True                       = lkUp
          where [index, defVal] = map (showSV cfg consts) [ind, def]

                lkUp = text "table" P.<> int t P.<> brackets (showSV cfg consts ind)
                cndLkUp cnd = cnd <+> text "?" <+> defVal <+> text ":" <+> lkUp

                checkLeft  = index <+> text "< 0"
                checkRight = index <+> text ">=" <+> int len
                checkBoth  = parens (checkLeft <+> text "||" <+> checkRight)

                -- can an index of this signedness and size reach past the end of the table?
                canOverflow True  sz = (2::Integer)^(sz-1)-1 >= fromIntegral len
                canOverflow False sz = (2::Integer)^sz    -1 >= fromIntegral len

                (needsCheckL, needsCheckR) = case k of
                                               KBool           -> (False, canOverflow False (1::Int))
                                               KBounded sg sz  -> (sg, canOverflow sg sz)
                                               KReal           -> die "array index with real value"
                                               KFloat          -> die "array index with float value"
                                               KDouble         -> die "array index with double value"
                                               KFP{}           -> die "array index with arbitrary float value"
                                               KString         -> die "array index with string value"
                                               KChar           -> die "array index with character value"
                                               KUnbounded      -> case cgInteger cfg of
                                                                    Nothing -> (True, True) -- won't matter, it'll be rejected later
                                                                    Just i  -> (True, canOverflow True i)
                                               KList     s     -> die $ "List sort " ++ show s
                                               KSet      s     -> die $ "Set sort " ++ show s
                                               KTuple    s     -> die $ "Tuple sort " ++ show s
                                               KMaybe    ek    -> die $ "Maybe sort " ++ show ek
                                               KEither   k1 k2 -> die $ "Either sort " ++ show (k1, k2)
                                               KUserSort s _   -> die $ "Uninterpreted sort: " ++ s

        -- Div/Rem should be careful on 0, in the SBV world x `div` 0 is 0, x `rem` 0 is x
        -- NB: Quot is supposed to truncate toward 0; Not clear to me if C guarantees this behavior.
        -- Brief googling suggests C99 does indeed truncate toward 0, but other C compilers might differ.
        p Quot [a, b] = let k = kindOf (head opArgs)
                            z = mkConst cfg $ mkConstCV k (0::Integer)
                        in protectDiv0 k "/" z a b
        p Rem  [a, b] = protectDiv0 (kindOf (head opArgs)) "%" a a b
        p UNeg [a]    = parens (text "-" <+> a)
        p Abs  [a]    = let f KFloat             = text "fabsf" P.<> parens a
                            f KDouble            = text "fabs"  P.<> parens a
                            f (KBounded False _) = text "/* unsigned, skipping call to abs */" <+> a
                            f (KBounded True 32) = text "labs"  P.<> parens a
                            f (KBounded True 64) = text "llabs" P.<> parens a
                            f KUnbounded         = case cgInteger cfg of
                                                     Nothing -> f $ KBounded True 32 -- won't matter, it'll be rejected later
                                                     Just i  -> f $ KBounded True i
                            f KReal              = case cgReal cfg of
                                                     Nothing           -> f KDouble -- won't matter, it'll be rejected later
                                                     Just CgFloat      -> f KFloat
                                                     Just CgDouble     -> f KDouble
                                                     Just CgLongDouble -> text "fabsl" P.<> parens a
                            f _                  = text "abs" P.<> parens a
                        in f (kindOf (head opArgs))
        -- for And/Or, translate to boolean versions if on boolean kind
        p And [a, b] | kindOf (head opArgs) == KBool = a <+> text "&&" <+> b
        p Or  [a, b] | kindOf (head opArgs) == KBool = a <+> text "||" <+> b
        p o [a, b]
          | Just co <- lookup o cBinOps
          = a <+> text co <+> b
        p NotEqual xs = mkDistinct xs
        p o args = die $ "Received operator " ++ show o ++ " applied to " ++ show args

        -- generate a pairwise inequality check
        mkDistinct args = fsep $ andAll $ walk args
          where walk []     = []
                walk (e:es) = map (pair e) es ++ walk es

                pair e1 e2  = parens (e1 <+> text "!=" <+> e2)

                -- like punctuate, but more spacing
                andAll []     = []
                andAll (d:ds) = go d ds
                     where go d' [] = [d']
                           go d' (e:es) = (d' <+> text "&&") : go e es

        -- Div0 needs to protect, but only when the arguments are not float/double. (Div by 0 for those are well defined to be Inf/NaN etc.)
        protectDiv0 k divOp def a b = case k of
                                        KFloat  -> res
                                        KDouble -> res
                                        _       -> wrap
           where res  = a <+> text divOp <+> b
                 wrap = parens (b <+> text "== 0") <+> text "?" <+> def <+> text ":" <+> parens res

        shift toLeft i a = a <+> text cop <+> i
          where cop | toLeft = "<<"
                    | True   = ">>"

        -- rotate by a constant amount; only unsigned bounded quantities get a real rotation
        rotate toLeft i a s
          | i < 0   = rotate (not toLeft) (-i) a s
          | i == 0  = a
          | True    = case kindOf s of
                        KBounded True _             -> tbd $ "Rotation of signed quantities: " ++ show (toLeft, i, s)
                        KBounded False sz | i >= sz -> rotate toLeft (i `mod` sz) a s
                        KBounded False sz           ->     parens (a <+> text cop  <+> int i)
                                                      <+> text "|"
                                                      <+> parens (a <+> text cop' <+> int (sz - i))
                        KUnbounded                  -> shift toLeft (int i) a -- For SInteger, rotate is the same as shift in Haskell
                        _                           -> tbd $ "Rotation for unsupported kind: " ++ show (toLeft, i, s)
          where (cop, cop') | toLeft = ("<<", ">>")
                            | True   = (">>", "<<")

        -- TBD: below we only support the values for extract that are "easy" to implement. These should cover
        -- almost all instances actually generated by SBV, however.
        extract hi lo i a  -- Isolate the bit-extraction case
          | hi == lo, KBounded _ sz <- kindOf i, hi < sz, hi >= 0
          = if hi == 0
            then text "(SBool)" <+> parens (a <+> text "& 1")
            else text "(SBool)" <+> parens (parens (a <+> text ">>" <+> int hi) <+> text "& 1")
        extract hi lo i a
          | srcSize `notElem` [64, 32, 16]
          = bad "Unsupported source size"
          | (hi + 1) `mod` 8 /= 0 || lo `mod` 8 /= 0
          = bad "Unsupported non-byte-aligned extraction"
          | tgtSize < 8 || tgtSize `mod` 8 /= 0
          = bad "Unsupported target size"
          | True
          = text cast <+> shifted
          where bad why    = tbd $ "extract with " ++ show (hi, lo, k, i) ++ " (Reason: " ++ why ++ ".)"

                k          = kindOf i
                srcSize    = intSizeOf k
                tgtSize    = hi - lo + 1
                signChange = srcSize == tgtSize

                cast
                  | signChange && hasSign k = "(SWord" ++ show srcSize ++ ")"
                  | signChange              = "(SInt"  ++ show srcSize ++ ")"
                  | True                    = "(SWord" ++ show tgtSize ++ ")"

                shifted
                  | lo == 0 = a
                  | True    = parens (a <+> text ">>" <+> int lo)

        -- TBD: ditto here for join, just like extract above
        join (i, j, a, b) = case (kindOf i, kindOf j) of
                              (KBounded False  8, KBounded False  8) -> parens (parens (text "(SWord16)" <+> a) <+> text "<< 8")  <+> text "|" <+> parens (text "(SWord16)" <+> b)
                              (KBounded False 16, KBounded False 16) -> parens (parens (text "(SWord32)" <+> a) <+> text "<< 16") <+> text "|" <+> parens (text "(SWord32)" <+> b)
                              (KBounded False 32, KBounded False 32) -> parens (parens (text "(SWord64)" <+> a) <+> text "<< 32") <+> text "|" <+> parens (text "(SWord64)" <+> b)
                              (k1,                k2)                -> tbd $ "join with " ++ show ((k1, i), (k2, j))
903
-- Like doubleQuotes, but the contents are first flattened onto a single line
-- (via ppSameLine), since a line break inside the quotes would break the generated code.
printQuotes :: Doc -> Doc
printQuotes d = text (concat ["\"", ppSameLine d, "\""])
908
-- Render a document onto a single line: every newline, together with any whitespace
-- that follows it, collapses into a single space. Useful when generating Makefile and such.
ppSameLine :: Doc -> String
ppSameLine = joinLines . render
 where joinLines s = case break (== '\n') s of
                       (chunk, [])       -> chunk
                       (chunk, _ : rest) -> chunk ++ ' ' : joinLines (dropWhile isSpace rest)
915
-- Right-align a bunch of docs: render each one, then left-pad it with spaces so that
-- every result is as wide as the widest rendering. Simple rather than efficient.
align :: [Doc] -> [Doc]
align ds = [text (replicate (widest - length s) ' ' ++ s) | s <- rendered]
  where rendered = map render ds
        widest   = foldr (max . length) 0 rendered
923
-- | Merge a bunch of bundles to generate code for a library. The resulting config is the
-- config of the first bundle received. All bundles must agree on their SInteger/SReal
-- mappings (their bundle kinds). Calling this with no bundles, or with bundles that
-- disagree, is an 'error'.
mergeToLib :: String -> [(CgConfig, CgPgmBundle)] -> (CgConfig, CgPgmBundle)
mergeToLib libName cfgBundles
  | null cfgBundles
  = error $ "Cannot create library " ++ show libName ++ ": no programs were given to merge."
  | [bundleKind] <- nubKinds, ((finalCfg, _) : _) <- cfgBundles
  = (finalCfg, CgPgmBundle bundleKind $ sources ++ libHeader bundleKind : [libDriver | anyDriver] ++ [libMake | anyMake])
  | True
  = error $  "Cannot merge programs with differing SInteger/SReal mappings. Received the following kinds:\n"
          ++ unlines (map show nubKinds)
  where bundles     = map snd cfgBundles
        kinds       = [k | CgPgmBundle k _ <- bundles]
        nubKinds    = nub kinds
        files       = concat [fs | CgPgmBundle _ fs <- bundles]
        sigs        = concat [ss | (_, (CgHeader ss, _)) <- files]
        anyMake     = not (null [() | (_, (CgMakefile{}, _)) <- files])
        drivers     = [ds | (_, (CgDriver, ds)) <- files]
        anyDriver   = not (null drivers)
        mkFlags     = nub (concat [xs | (_, (CgMakefile xs, _)) <- files])
        -- each source keeps its first and last parts; its middle part is replaced by the library header include
        sources     = [(f, (CgSource, [pre, libHInclude, post])) | (f, (CgSource, [pre, _, post])) <- files]
        sourceNms   = map fst sources
        libHeader k = (libName ++ ".h", (CgHeader sigs, [genHeader k libName sigs empty]))
        libHInclude = text "#include" <+> text (show (libName ++ ".h"))
        libMake     = ("Makefile", (CgMakefile mkFlags, [genLibMake anyDriver libName sourceNms mkFlags]))
        libDriver   = (libName ++ "_driver.c", (CgDriver, mergeDrivers libName libHInclude (zip (map takeBaseName sourceNms) drivers)))
952
-- | Create a Makefile for the library.
--
--   * @ifdr@    : whether a driver program is also built (adds its build/clean rules)
--   * @libName@ : base name; yields @libName.a@, @libName.h@ and @libName_driver@
--   * @fs@      : the C sources; one object rule is emitted per file
--   * @ldFlags@ : linker flags; when non-empty, @LDFLAGS@ is defined and used to link the driver
genLibMake :: Bool -> String -> [String] -> [String] -> Doc
genLibMake ifdr libName fs ldFlags = foldr1 ($$) (map snd (filter fst mkRows))
 where hasLdFlags = not (null ldFlags)
       ldRef      = if hasLdFlags then text "${LDFLAGS}" else empty
       libA       = libName ++ ".a"
       libH       = libName ++ ".h"
       libDrv     = libName ++ "_driver"
       objFiles   = map (`replaceExtension` ".o") fs

       -- everything "all" builds and "veryclean" removes
       products   = unwords (libA : [libDrv | ifdr])

       yes d      = (True, d)
       blankLine  = yes (text "")

       objRule o src =  text o P.<> text (": " ++ unwords [src, libH])
                     $$ text "\t${CC} ${CCFLAGS} -c $< -o $@"
                     $$ text ""

       mkRows = [ yes (text "# Makefile for" <+> text libName P.<> text ". Automatically generated by SBV. Do not edit!")
                , blankLine
                , yes (text "# include any user-defined .mk file in the current directory.")
                , yes (text "-include *.mk")
                , blankLine
                , yes (text "CC?=gcc")
                , yes (text "CCFLAGS?=-Wall -O3 -DNDEBUG -fomit-frame-pointer")
                , (hasLdFlags, text "LDFLAGS?=" P.<> text (unwords ldFlags))
                , yes (text "AR?=ar")
                , yes (text "ARFLAGS?=cr")
                , blankLine
                , yes (text ("all: " ++ products))
                , blankLine
                , yes (text libA P.<> text (": " ++ unwords objFiles))
                , yes (text "\t${AR} ${ARFLAGS} $@ $^")
                , blankLine
                , (ifdr, text libDrv P.<> text (": " ++ unwords [libDrv ++ ".c", libH]))
                , (ifdr, text ("\t${CC} ${CCFLAGS} $< -o $@ " ++ libA) <+> ldRef)
                , (ifdr, text "")
                , yes (vcat (zipWith objRule objFiles fs))
                , yes (text "clean:")
                , yes (text "\trm -f *.o")
                , blankLine
                , yes (text "veryclean: clean")
                , yes (text "\trm -f" <+> text products)
                , blankLine
                ]
996
-- | Create a driver for a library. The result is:
--
--   * a preamble that includes @stdio.h@ and the given library include,
--   * one @void <name>_driver(void)@ function per input driver, reusing only that driver's body,
--   * a @main@ that calls every driver in order, printing a banner before each call.
mergeDrivers :: String -> Doc -> [(FilePath, [Doc])] -> [Doc]
mergeDrivers libName inc ds = preamble : (concatMap asFunction ds ++ [mainFun])
  where preamble =  text "/* Example driver program for" <+> text libName P.<> text ". */"
                 $$ text "/* Automatically generated by SBV. Edit as you see fit! */"
                 $$ text ""
                 $$ text "#include <stdio.h>"
                 $$ inc

        -- an input driver must consist of exactly four parts; only the third one (the body) is kept
        asFunction (nm, [_, _, body, _]) = [ text "" $$ text ("void " ++ nm ++ "_driver(void)") $$ text "{"
                                           , body
                                           , text "}"
                                           ]
        asFunction (nm, _)               = die $ "mergeDrivers: non-conforming driver program for " ++ show nm

        mainFun =   text ""
                $$  text "int main(void)"
                $$  text "{"
                $+$ nest 2 (vcat [invoke nm | (nm, _) <- ds])
                $$  nest 2 (text "return 0;")
                $$  text "}"

        invoke nm = let banner       = "** Driver run for " ++ nm ++ ":"
                        printfLine s = text ("printf(\"" ++ s ++ "\\n\");")
                        sepLine      = printfLine (replicate (length banner) '=')
                    in     sepLine
                        $$ printfLine banner
                        $$ sepLine
                        $$ text (nm ++ "_driver();")
                        $$ text ""
1026
-- Linker flags required by an operation whose result has the given kind: @-lm@ when
-- the operation is rendered through a C math-library routine, nothing otherwise.
getLDFlag :: (Op, Kind) -> [String]
getLDFlag (o, k)
  | usesLibM o = ["-lm"]
  | True       = []
  where usesLibM (IEEEFP FP_Cast{}) = True   -- float-to-integral casts go through rint/rintf
        usesLibM (IEEEFP fop)       = fop `elem` libmOps
        usesLibM Abs                = k `elem` [KFloat, KDouble, KReal]
        usesLibM _                  = False

        libmOps = [ FP_Abs, FP_FMA, FP_Sqrt, FP_Rem, FP_Min, FP_Max
                  , FP_RoundToIntegral, FP_ObjEqual
                  , FP_IsSubnormal, FP_IsInfinite, FP_IsNaN, FP_IsNegative
                  , FP_IsPositive, FP_IsNormal, FP_IsZero
                  ]
1053
1054{-# ANN module ("HLint: ignore Redundant lambda" :: String) #-}
1055