1 /*
2
3 This file is part of the Maude 2 interpreter.
4
5 Copyright 1997-2003 SRI International, Menlo Park, CA 94025, USA.
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 This program is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with this program; if not, write to the Free Software
19 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
20
21 */
22
23 //
24 // Implementation for class FloatOpSymbol.
25 //
26
27 // utility stuff
28 #include "macros.hh"
29 #include "vector.hh"
30 #include "mathStuff.hh"
31
32 // forward declarations
33 #include "interface.hh"
34 #include "core.hh"
35 #include "freeTheory.hh"
36 #include "NA_Theory.hh"
37 #include "builtIn.hh"
38
39 // interface class definitions
40 #include "symbol.hh"
41 #include "dagNode.hh"
42 #include "term.hh"
43
44 // core class definitions
45 #include "rewritingContext.hh"
46 #include "symbolMap.hh"
47
48 // free theory class definitions
49 #include "freeNet.hh"
50 #include "freeDagNode.hh"
51
52 // built in class definitions
53 #include "succSymbol.hh"
54 #include "minusSymbol.hh"
55 #include "divisionSymbol.hh"
56 #include "floatSymbol.hh"
57 #include "floatDagNode.hh"
58 #include "floatOpSymbol.hh"
59 #include "bindingMacros.hh"
60
FloatOpSymbol(int id,int arity)61 FloatOpSymbol::FloatOpSymbol(int id, int arity)
62 : FreeSymbol(id, arity)
63 {
64 op = NONE;
65 floatSymbol = 0;
66 succSymbol = 0;
67 minusSymbol = 0;
68 divisionSymbol = 0;
69 }
70
//
//  Data-attachment hook: BIND_OP (from bindingMacros.hh) is given this
//  class name and the op member so it can bind the operation code from
//  data; everything else is passed on to FreeSymbol::attachData().
//  NOTE(review): the exact matching/early-return behaviour is defined by
//  BIND_OP, which is not visible here — confirm in bindingMacros.hh.
//
bool
FloatOpSymbol::attachData(const Vector<Sort*>& opDeclaration,
                          const char* purpose,
                          const Vector<const char*>& data)
{
  BIND_OP(purpose, FloatOpSymbol, op, data);
  return FreeSymbol::attachData(opDeclaration, purpose, data);
}
79
//
//  Symbol-attachment hook: offers symbol to each helper-symbol slot
//  (floatSymbol, succSymbol, minusSymbol, divisionSymbol) through
//  BIND_SYMBOL, then falls back to FreeSymbol::attachSymbol().
//  NOTE(review): presumably BIND_SYMBOL matches purpose against the slot
//  name and returns early on a match — confirm in bindingMacros.hh.
//
bool
FloatOpSymbol::attachSymbol(const char* purpose, Symbol* symbol)
{
  BIND_SYMBOL(purpose, symbol, floatSymbol, FloatSymbol*);
  BIND_SYMBOL(purpose, symbol, succSymbol, SuccSymbol*);
  BIND_SYMBOL(purpose, symbol, minusSymbol, MinusSymbol*);
  BIND_SYMBOL(purpose, symbol, divisionSymbol, DivisionSymbol*);
  return FreeSymbol::attachSymbol(purpose, symbol);
}
89
//
//  Term-attachment hook: offers term to the trueTerm and falseTerm slots
//  (the results eqRewrite() uses for the relational operators) through
//  BIND_TERM, then falls back to FreeSymbol::attachTerm().
//
bool
FloatOpSymbol::attachTerm(const char* purpose, Term* term)
{
  BIND_TERM(purpose, term, trueTerm);
  BIND_TERM(purpose, term, falseTerm);
  return FreeSymbol::attachTerm(purpose, term);
}
97
//
//  Copy all attachments from original (which must be a FloatOpSymbol;
//  safeCast is used for the downcast): the operation code verbatim, the
//  helper symbols and true/false terms translated through map, and
//  finally whatever FreeSymbol itself copies.
//
void
FloatOpSymbol::copyAttachments(Symbol* original, SymbolMap* map)
{
  FloatOpSymbol* orig = safeCast(FloatOpSymbol*, original);
  op = orig->op;
  COPY_SYMBOL(orig, floatSymbol, map, FloatSymbol*);
  COPY_SYMBOL(orig, succSymbol, map, SuccSymbol*);
  COPY_SYMBOL(orig, minusSymbol, map, MinusSymbol*);
  COPY_SYMBOL(orig, divisionSymbol, map, DivisionSymbol*);
  COPY_TERM(orig, trueTerm, map);
  COPY_TERM(orig, falseTerm, map);
  FreeSymbol::copyAttachments(original, map);
}
111
112 void
getDataAttachments(const Vector<Sort * > & opDeclaration,Vector<const char * > & purposes,Vector<Vector<const char * >> & data)113 FloatOpSymbol::getDataAttachments(const Vector<Sort*>& opDeclaration,
114 Vector<const char*>& purposes,
115 Vector<Vector<const char*> >& data)
116 {
117 int nrDataAttachments = purposes.length();
118 purposes.resize(nrDataAttachments + 1);
119 purposes[nrDataAttachments] = "FloatOpSymbol";
120 data.resize(nrDataAttachments + 1);
121 data[nrDataAttachments].resize(1);
122 const char*& d = data[nrDataAttachments][0];
123 switch (op)
124 {
125 CODE_CASE(d, '-', 0, "-")
126 CODE_CASE(d, 'a', 'b', "abs")
127 CODE_CASE(d, 'c', 'e', "ceiling")
128 CODE_CASE(d, 's', 'q', "sqrt")
129 CODE_CASE(d, 'e', 'x', "exp")
130 CODE_CASE(d, 'l', 'o', "log")
131 CODE_CASE(d, 's', 'i', "sin")
132 CODE_CASE(d, 'c', 'o', "cos")
133 CODE_CASE(d, 't', 'a', "tan")
134 CODE_CASE(d, 'a', 's', "asin")
135 CODE_CASE(d, 'a', 'c', "acos")
136 CODE_CASE(d, 'a', 't', "atan")
137 CODE_CASE(d, 'r', 'a', "rat")
138 CODE_CASE(d, '+', 0, "+")
139 CODE_CASE(d, '*', 0, "*")
140 CODE_CASE(d, '/', 0, "/")
141 CODE_CASE(d, 'r', 'e', "rem")
142 CODE_CASE(d, '^', 0, "^")
143 CODE_CASE(d, '<', 0, "<")
144 CODE_CASE(d, '<', '=', "<=")
145 CODE_CASE(d, '>', 0, ">")
146 CODE_CASE(d, '>', '=', ">=")
147 CODE_CASE(d, 'm', 'i', "min")
148 CODE_CASE(d, 'm', 'a', "max")
149 case CODE('f', 'l'):
150 {
151 d = (succSymbol == 0) ? "floor" : "float"; // HACK
152 break;
153 }
154 default:
155 CantHappen("bad float op");
156 }
157 FreeSymbol::getDataAttachments(opDeclaration, purposes, data);
158 }
159
//
//  Report helper-symbol attachments by appending (purpose, symbol) entries
//  via APPEND_SYMBOL in the fixed order floatSymbol, succSymbol,
//  minusSymbol, divisionSymbol, then let FreeSymbol append its own.
//  NOTE(review): whether null slots are skipped is decided inside
//  APPEND_SYMBOL (bindingMacros.hh) — confirm there.
//
void
FloatOpSymbol::getSymbolAttachments(Vector<const char*>& purposes,
                                    Vector<Symbol*>& symbols)
{
  APPEND_SYMBOL(purposes, symbols, floatSymbol);
  APPEND_SYMBOL(purposes, symbols, succSymbol);
  APPEND_SYMBOL(purposes, symbols, minusSymbol);
  APPEND_SYMBOL(purposes, symbols, divisionSymbol);
  FreeSymbol::getSymbolAttachments(purposes, symbols);
}
170
//
//  Report term attachments by appending (purpose, term) entries for
//  trueTerm then falseTerm via APPEND_TERM, then let FreeSymbol append
//  its own.
//
void
FloatOpSymbol::getTermAttachments(Vector<const char*>& purposes,
                                  Vector<Term*>& terms)
{
  APPEND_TERM(purposes, terms, trueTerm);
  APPEND_TERM(purposes, terms, falseTerm);
  FreeSymbol::getTermAttachments(purposes, terms);
}
179
//
//  Prepare the attached true/false terms (via PREPARE_TERM) so that
//  eqRewrite() can later obtain their dags for relational results.
//  NOTE(review): unlike the other attachment hooks in this class, this
//  does not chain to FreeSymbol::postInterSymbolPass() — confirm the base
//  implementation has nothing that needs to run here.
//
void
FloatOpSymbol::postInterSymbolPass()
{
  PREPARE_TERM(trueTerm);
  PREPARE_TERM(falseTerm);
}
186
187 void
reset()188 FloatOpSymbol::reset()
189 {
190 trueTerm.reset(); // so true dag can be garbage collected
191 falseTerm.reset(); // so false dag can be garbage collected
192 FreeSymbol::reset(); // parents reset() tasks
193 }
194
//
//  Builtin evaluation of an application of this float operator.
//
//  All arguments are reduced first.  If every argument's top symbol is
//  floatSymbol, the operation named by op is computed in double precision:
//    - numeric results are returned as a Float unless they are NaN;
//    - rat() turns a finite Float into a Nat (succSymbol), a negative
//      integer (minusSymbol) or a fraction (divisionSymbol) via mpq_set_d;
//    - relational operators rewrite to the attached trueTerm/falseTerm.
//  Otherwise, for the unary float() operator (op "fl" with succSymbol
//  attached), a Nat, negative integer or fraction argument is converted to
//  a Float.
//  Whenever no builtin result is produced -- non-Float arguments that are
//  not convertible, division/rem by zero, log of a negative, asin/acos
//  outside [-1, 1], rat of a non-finite value, an undefined power, or a
//  NaN result -- control reaches fail and FreeSymbol::eqRewrite() is used.
//
bool
FloatOpSymbol::eqRewrite(DagNode* subject, RewritingContext& context)
{
  Assert(this == subject->symbol(), "bad symbol");
  int nrArgs = arity();
  FreeDagNode* d = static_cast<FreeDagNode*>(subject);
  bool floatEval = true;
  //
  //  Evaluate our arguments and check that they are all floats.
  //
  for (int i = 0; i < nrArgs; i++)
    {
      DagNode* a = d->getArgument(i);
      a->reduce(context);
      if (a->symbol() != floatSymbol)
        floatEval = false;
    }
  if (floatEval)
    {
      //
      //  All arguments are Floats.  Each case either returns directly,
      //  jumps to fail, or leaves its result in r for the NaN check below.
      //
      double a1 = static_cast<FloatDagNode*>(d->getArgument(0))->getValue();
      double r;
      if (nrArgs == 1)
        {
          //
          //  Unary operations.
          //
          switch (op)
            {
            case '-':
              r = -a1;
              break;
            case CODE('a', 'b'):
              r = fabs(a1);
              break;
            case CODE('f', 'l'):
              // floor() and float() share this code; with a Float argument
              // both arrive here and compute floor(a1).
              r = floor(a1);
              break;
            case CODE('c', 'e'):
              r = ceil(a1);
              break;
            case CODE('s', 'q'):
              // presumably NaN for negative a1 (IEEE), rejected by isnan() below
              r = sqrt(a1);
              break;
            case CODE('e', 'x'):
              r = exp(a1);
              break;
            case CODE('l', 'o'):
              {
                if (a1 < 0)
                  goto fail;  // some platforms return NaN, some -Infinity
                r = log(a1);
                break;
              }
            case CODE('s', 'i'):
              r = sin(a1);
              break;
            case CODE('c', 'o'):
              r = cos(a1);
              break;
            case CODE('t', 'a'):
              r = tan(a1);
              break;
            case CODE('a', 's'):
              {
                // asin is only accepted on [-1, 1]
                if (a1 < -1.0 || a1 > 1.0)
                  goto fail;
                r = asin(a1);
                break;
              }
            case CODE('a', 'c'):
              {
                // acos is only accepted on [-1, 1]
                if (a1 < -1.0 || a1 > 1.0)
                  goto fail;
                r = acos(a1);
                break;
              }
            case CODE('a', 't'):
              r = atan(a1);
              break;
            case CODE('r', 'a'):
              {
                //
                //  rat(): convert a finite Float to a rational and build the
                //  corresponding Nat, negative integer or fraction dag.
                //  Non-finite values are rejected.
                //
                if (!(finite(a1)))
                  goto fail;
                mpq_class t;
                mpq_set_d(t.get_mpq_t(), a1);
                const mpz_class& numerator = t.get_num();
                const mpz_class& denominator = t.get_den();
                //
                //  NOTE(review): this DagNode* r shadows the double r declared
                //  above; it is only used inside this case.
                //
                DagNode* r;
                if (denominator == 1)
                  {
                    if (numerator >= 0)
                      return succSymbol->rewriteToNat(subject, context, numerator);
                    r = minusSymbol->makeNegDag(numerator);
                  }
                else
                  r = divisionSymbol->makeRatDag(numerator, denominator);
                return context.builtInReplace(subject, r);
              }
            default:
              CantHappen("bad float op");
              r = 0.0;
            }
        }
      else
        {
          //
          //  Binary operations; only arguments 0 and 1 are read.
          //
          double a2 = static_cast<FloatDagNode*>(d->getArgument(1))->getValue();
          switch (op)
            {
            case '+':
              r = a1 + a2;
              break;
            case '-':
              r = a1 - a2;
              break;
            case '*':
              r = a1 * a2;
              break;
            case '/':
              {
                if (a2 == 0)
                  goto fail;  // division by zero is left undefined
                r = a1 / a2;
                break;
              }
            case CODE('r', 'e'):
              {
                if (a2 == 0)
                  goto fail;  // remainder by zero is left undefined
                r = fmod(a1, a2);
                break;
              }
            case '^':
              {
                // safePow() applies Maude's own special-case rules
                bool defined;
                r = safePow(a1, a2, defined);
                if (!defined)
                  goto fail;
                break;
              }
            case CODE('a', 't'):
              {
                if (!finite(a1) && !finite(a2))
                  {
                    //
                    //  Double infinity case: make args finite
                    //
                    a1 = (a1 < 0) ? -1 : 1;
                    a2 = (a2 < 0) ? -1 : 1;
                  }
                r = atan2(a1, a2);
                break;
              }
            case CODE('m', 'i'):
              {
                r = (a1 < a2) ? a1 : a2;
                break;
              }
            case CODE('m', 'a'):
              {
                r = (a1 < a2) ? a2 : a1;
                break;
              }
            default:
              {
                //
                //  Relational operators: the comparison result is held in r
                //  (1.0 or 0.0) and mapped to the attached true/false terms
                //  instead of being returned as a Float.
                //
                switch (op)
                  {
                  case '<':
                    r = a1 < a2;
                    break;
                  case CODE('<', '='):
                    r = a1 <= a2;
                    break;
                  case '>':
                    r = a1 > a2;
                    break;
                  case CODE('>', '='):
                    r = a1 >= a2;
                    break;
                  default:
                    CantHappen("bad float op");
                    r = 0.0;  // avoid compiler warning
                  }
                Assert(trueTerm.getTerm() != 0 && falseTerm.getTerm() != 0,
                       "null true/false for relational op");
                return context.builtInReplace(subject, r ? trueTerm.getDag() : falseTerm.getDag());
              }
            }
        }
      //
      //  A NaN result is not turned into a Float; control falls through to
      //  fail instead.
      //
      if (!isnan(r))
        return floatSymbol->rewriteToFloat(subject, context, r);
    }
  else if (nrArgs == 1)
    {
      //
      //  The argument is not a Float.  The only remaining builtin case is
      //  float() applied to a Nat, a negative integer or a fraction.
      //
      DagNode* a0 = d->getArgument(0);
      if (op == CODE('f', 'l') && succSymbol != 0)  // check we're float() and not floor()
        {
          //
          //  NOTE(review): GMP documents mpq_get_d() as truncating toward zero
          //  rather than rounding to nearest, so the Float produced may not be
          //  the double nearest to the exact rational -- confirm this is
          //  intended.
          //
          if (succSymbol->isNat(a0))
            {
              mpq_class tq(succSymbol->getNat(a0), 1);
              return floatSymbol->rewriteToFloat(subject, context, mpq_get_d(tq.get_mpq_t()));
            }
          else if (a0->symbol() == minusSymbol)
            {
              if (minusSymbol->isNeg(a0))
                {
                  mpz_class result;
                  mpq_class tq(minusSymbol->getNeg(a0, result), 1);
                  return floatSymbol->rewriteToFloat(subject, context, mpq_get_d(tq.get_mpq_t()));
                }
            }
          else if (a0->symbol() == divisionSymbol)
            {
              if (divisionSymbol->isRat(a0))
                {
                  mpz_class numerator;
                  const mpz_class& denomenator = divisionSymbol->getRat(a0, numerator);
                  mpq_class tq(numerator, denomenator);
                  return floatSymbol->rewriteToFloat(subject, context, mpq_get_d(tq.get_mpq_t()));
                }
            }
        }
    }
  //
  //  No builtin result: defer to the free-theory rewriting in FreeSymbol.
  //
 fail:
  return FreeSymbol::eqRewrite(subject, context);
}
417
418 int
isOdd(double n)419 FloatOpSymbol::isOdd(double n)
420 {
421 //
422 // Decide if a floating point number is odd or even;
423 // return -1 if neither or can't decide.
424 //
425 if (n != floor(n))
426 return -1; // fractional
427 if (n < 0)
428 n = -n;
429 if (n > INT_DOUBLE_MAX) // oddness is essentially random
430 return -1;
431 return static_cast<Int64>(n) & 1;
432 }
433
//
//  Compute a1 ^ a2 using explicit rules for the special cases rather than
//  relying on the platform's pow() for them.
//
//  On return, defined is false when the result must be treated as
//  undefined (eqRewrite() then declines to rewrite); the returned value
//  should not be used in that case.  Undefined cases handled below:
//    - either argument is NaN;
//    - a1 is -Infinity and a2 > 0 has undecidable parity (isOdd() == -1);
//    - a2 is +Infinity and a1 <= -1, or a2 is -Infinity and -1 <= a1 <= 0;
//    - a1 == 0 and a2 < 0;
//    - pow() yields NaN, or a1 < 0 with a nonzero result and a2 of
//      undecidable parity.
//
double
FloatOpSymbol::safePow(double a1, double a2, bool& defined)
{
  defined = true;
  //
  //  A NaN operand makes the result undefined.
  //
  if (isnan(a1))
    {
      defined = false;
      return a1;
    }
  if (isnan(a2))
    {
      defined = false;
      return a2;
    }
  if (!finite(a1))
    {
      //
      //  a1 is +/-Infinity.
      //
      if (a2 == 0.0)
        return 1.0;
      if (a2 < 0)
        return 0.0;
      if (a1 > 0)
        return a1;
      //
      //  (-Infinity)^a2 with a2 > 0: the sign follows the parity of a2.
      //
      int odd = isOdd(a2);
      if (odd == -1)
        {
          defined = false;
          return 0.0;
        }
      return odd ? a1 : -a1;
    }
  if (!finite(a2))
    {
      //
      //  a1 finite, a2 is +/-Infinity: the result is a limit that depends
      //  on how a1 compares with 1.
      //
      if (a1 > 1.0)
        return a2 > 0 ? a2 : 0;
      if (a1 == 1.0)
        return 1.0;
      if (a1 > 0.0)
        return a2 < 0 ? -a2 : 0;
      //
      //  a1 <= 0: the limit is 0 when |a1| < 1 (for +Infinity) or |a1| > 1
      //  (for -Infinity); the remaining cases are left undefined.
      //
      if (a2 > 0)
        {
          if (a1 <= -1.0)
            defined = false;
        }
      else
        {
          if (a1 >= -1.0)
            defined = false;
        }
      return 0;
    }
  if (a1 == 0.0 && a2 < 0.0)
    {
      //
      //  Some platforms return Infinity.
      //
      defined = false;
      return 0.0;
    }
  double r = pow(a1, a2);
  if (isnan(r))
    defined = false;
  else if (a1 < 0.0 && r != 0.0)
    {
      //
      //  Some platforms get this badly wrong.  For a negative base the
      //  result must be negative exactly when a2 is odd.
      //
      int odd = isOdd(a2);
      if (odd == -1)
        defined = false;
      else if ((odd == 1) != (r < 0))
        r = -r;  // fix sign if pow() got it wrong
    }
  return r;
}
508