1 /*
2 
3     This file is part of the Maude 2 interpreter.
4 
5     Copyright 1997-2003 SRI International, Menlo Park, CA 94025, USA.
6 
7     This program is free software; you can redistribute it and/or modify
8     it under the terms of the GNU General Public License as published by
9     the Free Software Foundation; either version 2 of the License, or
10     (at your option) any later version.
11 
12     This program is distributed in the hope that it will be useful,
13     but WITHOUT ANY WARRANTY; without even the implied warranty of
14     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15     GNU General Public License for more details.
16 
17     You should have received a copy of the GNU General Public License
18     along with this program; if not, write to the Free Software
19     Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
20 
21 */
22 
23 //
24 //      Implementation for class FloatOpSymbol.
25 //
26 
27 //      utility stuff
28 #include "macros.hh"
29 #include "vector.hh"
30 #include "mathStuff.hh"
31 
32 //      forward declarations
33 #include "interface.hh"
34 #include "core.hh"
35 #include "freeTheory.hh"
36 #include "NA_Theory.hh"
37 #include "builtIn.hh"
38 
39 //      interface class definitions
40 #include "symbol.hh"
41 #include "dagNode.hh"
42 #include "term.hh"
43 
44 //      core class definitions
45 #include "rewritingContext.hh"
46 #include "symbolMap.hh"
47 
48 //      free theory class definitions
49 #include "freeNet.hh"
50 #include "freeDagNode.hh"
51 
52 //      built in class definitions
53 #include "succSymbol.hh"
54 #include "minusSymbol.hh"
55 #include "divisionSymbol.hh"
56 #include "floatSymbol.hh"
57 #include "floatDagNode.hh"
58 #include "floatOpSymbol.hh"
59 #include "bindingMacros.hh"
60 
FloatOpSymbol(int id,int arity)61 FloatOpSymbol::FloatOpSymbol(int id, int arity)
62   : FreeSymbol(id, arity)
63 {
64   op = NONE;
65   floatSymbol = 0;
66   succSymbol = 0;
67   minusSymbol = 0;
68   divisionSymbol = 0;
69 }
70 
//
//	Bind the id-hook data for this symbol.  BIND_OP presumably matches the
//	"FloatOpSymbol" purpose and decodes data into op, the operation code
//	that eqRewrite() switches on (verify against bindingMacros.hh).  Any
//	other purpose is handed on to FreeSymbol.
//
bool
FloatOpSymbol::attachData(const Vector<Sort*>& opDeclaration,
			  const char* purpose,
			  const Vector<const char*>& data)
{
  BIND_OP(purpose, FloatOpSymbol, op, data);
  return FreeSymbol::attachData(opDeclaration, purpose, data);
}
79 
80 bool
attachSymbol(const char * purpose,Symbol * symbol)81 FloatOpSymbol::attachSymbol(const char* purpose, Symbol* symbol)
82 {
83   BIND_SYMBOL(purpose, symbol, floatSymbol, FloatSymbol*);
84   BIND_SYMBOL(purpose, symbol, succSymbol, SuccSymbol*);
85   BIND_SYMBOL(purpose, symbol, minusSymbol, MinusSymbol*);
86   BIND_SYMBOL(purpose, symbol, divisionSymbol, DivisionSymbol*);
87   return FreeSymbol::attachSymbol(purpose, symbol);
88 }
89 
90 bool
attachTerm(const char * purpose,Term * term)91 FloatOpSymbol::attachTerm(const char* purpose, Term* term)
92 {
93   BIND_TERM(purpose, term, trueTerm);
94   BIND_TERM(purpose, term, falseTerm);
95   return FreeSymbol::attachTerm(purpose, term);
96 }
97 
98 void
copyAttachments(Symbol * original,SymbolMap * map)99 FloatOpSymbol::copyAttachments(Symbol* original, SymbolMap* map)
100 {
101   FloatOpSymbol* orig = safeCast(FloatOpSymbol*, original);
102   op = orig->op;
103   COPY_SYMBOL(orig, floatSymbol, map, FloatSymbol*);
104   COPY_SYMBOL(orig, succSymbol, map, SuccSymbol*);
105   COPY_SYMBOL(orig, minusSymbol, map, MinusSymbol*);
106   COPY_SYMBOL(orig, divisionSymbol, map, DivisionSymbol*);
107   COPY_TERM(orig, trueTerm, map);
108   COPY_TERM(orig, falseTerm, map);
109   FreeSymbol::copyAttachments(original, map);
110 }
111 
112 void
getDataAttachments(const Vector<Sort * > & opDeclaration,Vector<const char * > & purposes,Vector<Vector<const char * >> & data)113 FloatOpSymbol::getDataAttachments(const Vector<Sort*>& opDeclaration,
114 				  Vector<const char*>& purposes,
115 				  Vector<Vector<const char*> >& data)
116 {
117   int nrDataAttachments = purposes.length();
118   purposes.resize(nrDataAttachments + 1);
119   purposes[nrDataAttachments] = "FloatOpSymbol";
120   data.resize(nrDataAttachments + 1);
121   data[nrDataAttachments].resize(1);
122   const char*& d = data[nrDataAttachments][0];
123   switch (op)
124     {
125     CODE_CASE(d, '-', 0, "-")
126     CODE_CASE(d, 'a', 'b', "abs")
127     CODE_CASE(d, 'c', 'e', "ceiling")
128     CODE_CASE(d, 's', 'q', "sqrt")
129     CODE_CASE(d, 'e', 'x', "exp")
130     CODE_CASE(d, 'l', 'o', "log")
131     CODE_CASE(d, 's', 'i', "sin")
132     CODE_CASE(d, 'c', 'o', "cos")
133     CODE_CASE(d, 't', 'a', "tan")
134     CODE_CASE(d, 'a', 's', "asin")
135     CODE_CASE(d, 'a', 'c', "acos")
136     CODE_CASE(d, 'a', 't', "atan")
137     CODE_CASE(d, 'r', 'a', "rat")
138     CODE_CASE(d, '+', 0, "+")
139     CODE_CASE(d, '*', 0, "*")
140     CODE_CASE(d, '/', 0, "/")
141     CODE_CASE(d, 'r', 'e', "rem")
142     CODE_CASE(d, '^', 0, "^")
143     CODE_CASE(d, '<', 0, "<")
144     CODE_CASE(d, '<', '=', "<=")
145     CODE_CASE(d, '>', 0, ">")
146     CODE_CASE(d, '>', '=', ">=")
147     CODE_CASE(d, 'm', 'i', "min")
148     CODE_CASE(d, 'm', 'a', "max")
149     case CODE('f', 'l'):
150       {
151 	d = (succSymbol == 0) ? "floor" : "float";  // HACK
152 	break;
153       }
154     default:
155       CantHappen("bad float op");
156     }
157   FreeSymbol::getDataAttachments(opDeclaration, purposes, data);
158 }
159 
//
//	Report our op-hooks by appending them, in the fixed order floatSymbol,
//	succSymbol, minusSymbol, divisionSymbol, followed by FreeSymbol's.
//	The order is part of the output, so keep it stable.  NOTE(review):
//	APPEND_SYMBOL presumably skips unbound (null) symbols; verify.
//
void
FloatOpSymbol::getSymbolAttachments(Vector<const char*>& purposes,
				    Vector<Symbol*>& symbols)
{
  APPEND_SYMBOL(purposes, symbols, floatSymbol);
  APPEND_SYMBOL(purposes, symbols, succSymbol);
  APPEND_SYMBOL(purposes, symbols, minusSymbol);
  APPEND_SYMBOL(purposes, symbols, divisionSymbol);
  FreeSymbol::getSymbolAttachments(purposes, symbols);
}
170 
//
//	Report our term-hooks (trueTerm then falseTerm), followed by
//	FreeSymbol's.  The order is part of the output, so keep it stable.
//
void
FloatOpSymbol::getTermAttachments(Vector<const char*>& purposes,
				  Vector<Term*>& terms)
{
  APPEND_TERM(purposes, terms, trueTerm);
  APPEND_TERM(purposes, terms, falseTerm);
  FreeSymbol::getTermAttachments(purposes, terms);
}
179 
//
//	Prepare the cached true/false terms that eqRewrite() returns for the
//	relational operators (PREPARE_TERM semantics live in bindingMacros.hh).
//	NOTE(review): unlike the other hooks in this file, this does not chain
//	to FreeSymbol; presumably the base version has nothing to do.  Confirm.
//
void
FloatOpSymbol::postInterSymbolPass()
{
  PREPARE_TERM(trueTerm);
  PREPARE_TERM(falseTerm);
}
186 
//
//	Drop the cached true/false dags, then perform FreeSymbol's reset work.
//
void
FloatOpSymbol::reset()
{
  trueTerm.reset();  // so true dag can be garbage collected
  falseTerm.reset();  // so false dag can be garbage collected
  FreeSymbol::reset();  // parents reset() tasks
}
194 
//
//	Built-in equational rewrite for float operations.
//
//	Every argument is reduced first.  If all arguments are floats (their
//	top symbol is floatSymbol), the operation selected by op is evaluated
//	directly.  For the unary float()/floor() code with a non-float
//	argument, a Nat, negative integer or rational argument is converted
//	to a float.
//
//	Whenever no built-in result is produced, the subject is handed to
//	FreeSymbol::eqRewrite() so that ordinary equations still get a chance.
//	That covers non-float arguments, a domain error that jumps to fail,
//	and a NaN result.
//
//	Returns whatever the rewrite call that was used returns:
//	rewriteToFloat(), rewriteToNat(), builtInReplace() or
//	FreeSymbol::eqRewrite().  Presumably this is true iff subject was
//	rewritten; confirm against those APIs.
//
bool
FloatOpSymbol::eqRewrite(DagNode* subject, RewritingContext& context)
{
  Assert(this == subject->symbol(), "bad symbol");
  int nrArgs = arity();
  FreeDagNode* d = static_cast<FreeDagNode*>(subject);
  bool floatEval = true;
  //
  //	Evaluate our arguments and check that they are all floats.
  //	(Every argument is reduced, even after a non-float has been seen.)
  //
  for (int i = 0; i < nrArgs; i++)
    {
      DagNode* a = d->getArgument(i);
      a->reduce(context);
      if (a->symbol() != floatSymbol)
	floatEval = false;
    }
  if (floatEval)
    {
      //
      //	All arguments are floats: compute the result r directly.
      //	Domain errors jump to fail so that equations can handle them.
      //
      double a1 = static_cast<FloatDagNode*>(d->getArgument(0))->getValue();
      double r;
      if (nrArgs == 1)
	{
	  //
	  //	Unary operations.  '-' here is negation; the same code
	  //	means subtraction in the binary switch below.
	  //
	  switch (op)
	    {
	    case '-':
	      r = -a1;
	      break;
	    case CODE('a', 'b'):
	      r = fabs(a1);
	      break;
	    case CODE('f', 'l'):
	      r = floor(a1);
	      break;
	    case CODE('c', 'e'):
	      r = ceil(a1);
	      break;
	    case CODE('s', 'q'):
	      r = sqrt(a1);
	      break;
	    case CODE('e', 'x'):
	      r = exp(a1);
	      break;
	    case CODE('l', 'o'):
	      {
		if (a1 < 0)
		  goto fail;  // some platforms return NaN, some -Infinity
		r = log(a1);
		break;
	      }
	    case CODE('s', 'i'):
	      r = sin(a1);
	      break;
	    case CODE('c', 'o'):
	      r = cos(a1);
	      break;
	    case CODE('t', 'a'):
	      r = tan(a1);
	      break;
	    case CODE('a', 's'):
	      {
		//	asin() is only defined on [-1, 1].
		if (a1 < -1.0 || a1 > 1.0)
		  goto fail;
		r = asin(a1);
		break;
	      }
	    case CODE('a', 'c'):
	      {
		//	acos() is only defined on [-1, 1].
		if (a1 < -1.0 || a1 > 1.0)
		  goto fail;
		r = acos(a1);
		break;
	      }
	    case CODE('a', 't'):
	      r = atan(a1);
	      break;
	    case CODE('r', 'a'):
	      {
		//
		//	rat(): convert a finite float to the exact rational it
		//	represents (mpq_set_d() does not round).  Integral values
		//	become a Nat or a negative integer; anything else becomes
		//	numerator/denominator.
		//
		//	NOTE(review): succSymbol, minusSymbol and divisionSymbol
		//	are dereferenced without null checks; this assumes all
		//	three op-hooks are bound for rat.
		//	NOTE(review): the DagNode* r below shadows the double r
		//	declared above.  It is harmless because every path here
		//	returns, but renaming it would avoid -Wshadow noise.
		//
		if (!(finite(a1)))
		  goto fail;
		mpq_class t;
		mpq_set_d(t.get_mpq_t(), a1);
		const mpz_class& numerator = t.get_num();
		const mpz_class& denominator = t.get_den();
		DagNode* r;
		if (denominator == 1)
		  {
		    if (numerator >= 0)
		      return succSymbol->rewriteToNat(subject, context, numerator);
		    r = minusSymbol->makeNegDag(numerator);
		  }
		else
		  r = divisionSymbol->makeRatDag(numerator, denominator);
		return context.builtInReplace(subject, r);
	      }
	    default:
	      CantHappen("bad float op");
	      r = 0.0;
	    }
	}
      else
	{
	  //
	  //	Binary operations.
	  //
	  double a2 = static_cast<FloatDagNode*>(d->getArgument(1))->getValue();
	  switch (op)
	    {
	    case '+':
	      r = a1 + a2;
	      break;
	    case '-':
	      r = a1 - a2;
	      break;
	    case '*':
	      r = a1 * a2;
	      break;
	    case '/':
	      {
		//	Division by zero is left to equations.
		if (a2 == 0)
		  goto fail;
		r = a1 / a2;
		break;
	      }
	    case CODE('r', 'e'):
	      {
		//	Remainder by zero is left to equations.
		if (a2 == 0)
		  goto fail;
		r = fmod(a1, a2);
		break;
	      }
	    case '^':
	      {
		//	safePow() decides the special cases; see its comments.
		bool defined;
		r = safePow(a1, a2, defined);
		if (!defined)
		  goto fail;
		break;
	      }
	    case CODE('a', 't'):
	      {
		//
		//	Two-argument atan is atan2(a1, a2).  When both arguments
		//	are infinite only their signs are kept, so that e.g.
		//	(+inf, +inf) is evaluated as atan2(1, 1).
		//
		if (!finite(a1) && !finite(a2))
		  {
		    //
		    //	Double infinity case: make args finite
		    //
		    a1 = (a1 < 0) ? -1 : 1;
		    a2 = (a2 < 0) ? -1 : 1;
		  }
		r = atan2(a1, a2);
		break;
	      }
	    case CODE('m', 'i'):
	      {
		r = (a1 < a2) ? a1 : a2;
		break;
	      }
	    case CODE('m', 'a'):
	      {
		r = (a1 < a2) ? a2 : a1;
		break;
	      }
	    default:
	      {
		//
		//	Relational operators: r holds 0 or 1 and selects the
		//	cached false/true term as the result.
		//
		switch (op)
		  {
		  case '<':
		    r = a1 < a2;
		    break;
		  case CODE('<', '='):
		    r = a1 <= a2;
		    break;
		  case '>':
		    r = a1 > a2;
		    break;
		  case CODE('>', '='):
		    r = a1 >= a2;
		    break;
		  default:
		    CantHappen("bad float op");
		    r = 0.0;  // avoid compiler warning
		  }
		Assert(trueTerm.getTerm() != 0 && falseTerm.getTerm() != 0,
		       "null true/false for relational op");
		return context.builtInReplace(subject, r ? trueTerm.getDag() : falseTerm.getDag());
	      }
	    }
	}
      //
      //	A NaN result (e.g. sqrt() of a negative number) is never
      //	returned as a float; control falls through to fail instead.
      //
      if (!isnan(r))
	return floatSymbol->rewriteToFloat(subject, context, r);
    }
  else if (nrArgs == 1)
    {
      //
      //	Some argument is not a float.  The only built-in case is float()
      //	(the 'fl' code with succSymbol bound) applied to a Nat, a
      //	negative integer or a rational.
      //
      //	NOTE(review): the GMP manual documents mpq_get_d() as truncating
      //	toward zero, so the result may not be the nearest double
      //	(e.g. for 1/10, or for integers above 2^53).  Confirm whether
      //	round-to-nearest is wanted.
      //
      DagNode* a0 = d->getArgument(0);
      if (op == CODE('f', 'l') && succSymbol != 0)  // check we're float() and not floor()
	{
	  if (succSymbol->isNat(a0))
	    {
	      mpq_class tq(succSymbol->getNat(a0), 1);
	      return floatSymbol->rewriteToFloat(subject, context, mpq_get_d(tq.get_mpq_t()));
	    }
	  else if (a0->symbol() == minusSymbol)
	    {
	      if (minusSymbol->isNeg(a0))
		{
		  mpz_class result;
		  mpq_class tq(minusSymbol->getNeg(a0, result), 1);
		  return floatSymbol->rewriteToFloat(subject, context, mpq_get_d(tq.get_mpq_t()));
		}
	    }
	  else if (a0->symbol() == divisionSymbol)
	    {
	      if (divisionSymbol->isRat(a0))
		{
		  mpz_class numerator;
		  const mpz_class& denomenator = divisionSymbol->getRat(a0, numerator);
		  mpq_class tq(numerator, denomenator);
		  return floatSymbol->rewriteToFloat(subject, context, mpq_get_d(tq.get_mpq_t()));
		}
	    }
	}
    }
  //
  //	No built-in result: fall back to the free theory's equations.
  //
 fail:
  return FreeSymbol::eqRewrite(subject, context);
}
417 
418 int
isOdd(double n)419 FloatOpSymbol::isOdd(double n)
420 {
421   //
422   //	Decide if a floating point number is odd or even;
423   //	return -1 if neither or can't decide.
424   //
425   if (n != floor(n))
426     return -1;  // fractional
427   if (n < 0)
428     n = -n;
429   if (n > INT_DOUBLE_MAX)  // oddness is essentially random
430     return -1;
431   return static_cast<Int64>(n) & 1;
432 }
433 
//
//	Compute a1 ^ a2 without relying on pow() for the special cases that
//	some platforms get wrong.  On return, defined is false when the result
//	is to be treated as undefined; eqRewrite() then discards the returned
//	value and falls back to equations.
//
double
FloatOpSymbol::safePow(double a1, double a2, bool& defined)
{
  defined = true;
  //
  //	NaN in either argument: undefined (the NaN itself is returned).
  //
  if (isnan(a1))
    {
      defined = false;
      return a1;
    }
  if (isnan(a2))
    {
      defined = false;
      return a2;
    }
  //
  //	Infinite base.
  //
  if (!finite(a1))
    {
      if (a2 == 0.0)
	return 1.0;  // inf ^ 0 = 1
      if (a2 < 0)
	return 0.0;  // (+/-inf) ^ negative = 0
      if (a1 > 0)
	return a1;  // +inf ^ positive = +inf
      //
      //	-inf ^ positive: the sign follows the parity of an integral
      //	exponent.  A non-integral (or undecidable) exponent is undefined.
      //
      int odd = isOdd(a2);
      if (odd == -1)
	{
	  defined = false;
	  return 0.0;
	}
      return odd ? a1 : -a1;
    }
  //
  //	Finite base, infinite exponent.
  //
  if (!finite(a2))
    {
      if (a1 > 1.0)
	return a2 > 0 ? a2 : 0;  // +inf for a2 = +inf, 0 for a2 = -inf
      if (a1 == 1.0)
	return 1.0;
      if (a1 > 0.0)
	return a2 < 0 ? -a2 : 0;  // 0 < a1 < 1: +inf for a2 = -inf, 0 for a2 = +inf
      //
      //	a1 <= 0: the result is 0 when |a1| ^ a2 tends to 0.  It is
      //	undefined when |a1| >= 1 with a2 = +inf, or when |a1| <= 1 with
      //	a2 = -inf.
      //
      if (a2 > 0)
	{
	  if (a1 <= -1.0)
	    defined = false;
	}
      else
	{
	  if (a1 >= -1.0)
	    defined = false;
	}
      return 0;
    }
  if (a1 == 0.0 && a2 < 0.0)
    {
      //
      //	Some platforms return Infinity.
      //	0 ^ negative (including -0.0, which compares equal to 0.0) is
      //	treated as undefined.
      //
      defined = false;
      return 0.0;
    }
  //
  //	General case: trust pow() for the magnitude, but reject NaN and
  //	check the sign for a negative base against the exponent's parity.
  //
  double r = pow(a1, a2);
  if (isnan(r))
    defined = false;
  else if (a1 < 0.0 && r != 0.0)
    {
      //
      //	Some platforms get this badly wrong.
      //	An exponent that isOdd() cannot classify (non-integral or too
      //	large) makes the result undefined.
      //
      int odd = isOdd(a2);
      if (odd == -1)
	defined = false;
      else if ((odd == 1) != (r < 0))
	r = -r;  // fix sign if pow() got it wrong
    }
  return r;
}
508