1 /*
2 
3     This file is part of the Maude 2 interpreter.
4 
5     Copyright 1997-2003 SRI International, Menlo Park, CA 94025, USA.
6 
7     This program is free software; you can redistribute it and/or modify
8     it under the terms of the GNU General Public License as published by
9     the Free Software Foundation; either version 2 of the License, or
10     (at your option) any later version.
11 
12     This program is distributed in the hope that it will be useful,
13     but WITHOUT ANY WARRANTY; without even the implied warranty of
14     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15     GNU General Public License for more details.
16 
17     You should have received a copy of the GNU General Public License
18     along with this program; if not, write to the Free Software
19     Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
20 
21 */
22 
23 //
24 //      Implementation for class StringOpSymbol.
25 //
26 
27 //      utility stuff
28 #include "macros.hh"
29 #include "vector.hh"
30 
31 //      forward declarations
32 #include "interface.hh"
33 #include "core.hh"
34 #include "freeTheory.hh"
35 #include "NA_Theory.hh"
36 #include "builtIn.hh"
37 
38 //      interface class definitions
39 #include "symbol.hh"
40 #include "dagNode.hh"
41 #include "term.hh"
42 
43 //      core class definitions
44 #include "rewritingContext.hh"
45 #include "symbolMap.hh"
46 
47 //      free theory class definitions
48 #include "freeNet.hh"
49 #include "freeDagNode.hh"
50 
51 //      built in class definitions
52 #include "floatSymbol.hh"
53 #include "floatDagNode.hh"
54 #include "succSymbol.hh"
55 #include "minusSymbol.hh"
56 #include "divisionSymbol.hh"
57 #include "stringSymbol.hh"
58 #include "stringDagNode.hh"
59 #include "stringOpSymbol.hh"
60 #include "bindingMacros.hh"
61 
StringOpSymbol(int id,int arity)62 StringOpSymbol::StringOpSymbol(int id, int arity)
63   : FreeSymbol(id, arity)
64 {
65   op = NONE;
66   stringSymbol = 0;
67   succSymbol = 0;
68   minusSymbol = 0;
69   divisionSymbol = 0;
70   floatSymbol = 0;
71   decFloatSymbol = 0;
72 }
73 
bool
StringOpSymbol::attachData(const Vector<Sort*>& opDeclaration,
			   const char* purpose,
			   const Vector<const char*>& data)
{
  //
  //	Bind the op code for this symbol from a data attachment.  BIND_OP
  //	handles the StringOpSymbol purpose (presumably returning from here
  //	when it matches -- see bindingMacros.hh); any other purpose is
  //	passed on to FreeSymbol.
  //
  BIND_OP(purpose, StringOpSymbol, op, data);
  return FreeSymbol::attachData(opDeclaration, purpose, data);
}
82 
83 bool
attachSymbol(const char * purpose,Symbol * symbol)84 StringOpSymbol::attachSymbol(const char* purpose, Symbol* symbol)
85 {
86   BIND_SYMBOL(purpose, symbol, stringSymbol, StringSymbol*);
87   BIND_SYMBOL(purpose, symbol, succSymbol, SuccSymbol*);
88   BIND_SYMBOL(purpose, symbol, minusSymbol, MinusSymbol*);
89   BIND_SYMBOL(purpose, symbol, divisionSymbol, DivisionSymbol*);
90   BIND_SYMBOL(purpose, symbol, floatSymbol, FloatSymbol*);
91   BIND_SYMBOL(purpose, symbol, decFloatSymbol, Symbol*);
92   return FreeSymbol::attachSymbol(purpose, symbol);
93 }
94 
95 bool
attachTerm(const char * purpose,Term * term)96 StringOpSymbol::attachTerm(const char* purpose, Term* term)
97 {
98   BIND_TERM(purpose, term, trueTerm);
99   BIND_TERM(purpose, term, falseTerm);
100   BIND_TERM(purpose, term, notFoundTerm);
101   return FreeSymbol::attachTerm(purpose, term);
102 }
103 
104 void
copyAttachments(Symbol * original,SymbolMap * map)105 StringOpSymbol::copyAttachments(Symbol* original, SymbolMap* map)
106 {
107   StringOpSymbol* orig = safeCast(StringOpSymbol*, original);
108   op = orig->op;
109   COPY_SYMBOL(orig, stringSymbol, map, StringSymbol*);
110   COPY_SYMBOL(orig, succSymbol, map, SuccSymbol*);
111   COPY_SYMBOL(orig, minusSymbol, map, MinusSymbol*);
112   COPY_SYMBOL(orig, divisionSymbol, map, DivisionSymbol*);
113   COPY_SYMBOL(orig, floatSymbol, map, FloatSymbol*);
114   COPY_SYMBOL(orig, decFloatSymbol, map, Symbol*);
115   COPY_TERM(orig, trueTerm, map);
116   COPY_TERM(orig, falseTerm, map);
117   COPY_TERM(orig, notFoundTerm, map);
118   FreeSymbol::copyAttachments(original, map);
119 }
120 
121 void
getDataAttachments(const Vector<Sort * > & opDeclaration,Vector<const char * > & purposes,Vector<Vector<const char * >> & data)122 StringOpSymbol::getDataAttachments(const Vector<Sort*>& opDeclaration,
123 				   Vector<const char*>& purposes,
124 				   Vector<Vector<const char*> >& data)
125 {
126   int nrDataAttachments = purposes.length();
127   purposes.resize(nrDataAttachments + 1);
128   purposes[nrDataAttachments] = "StringOpSymbol";
129   data.resize(nrDataAttachments + 1);
130   data[nrDataAttachments].resize(1);
131   const char*& d = data[nrDataAttachments][0];
132   switch (op)
133     {
134     CODE_CASE(d, 'f', 'l', "float")
135     CODE_CASE(d, 'l', 'e', "length")
136     CODE_CASE(d, 'a', 's', "ascii")
137     CODE_CASE(d, '+', 0, "+")
138     CODE_CASE(d, '<', 0, "<")
139     CODE_CASE(d, '<', '=', "<=")
140     CODE_CASE(d, '>', 0, ">")
141     CODE_CASE(d, '>', '=', ">=")
142     CODE_CASE(d, 'r', 'a', "rat")
143     CODE_CASE(d, 's', 'u', "substr")
144     CODE_CASE(d, 'f', 'i', "find")
145     CODE_CASE(d, 'r', 'f', "rfind")
146     CODE_CASE(d, 's', 't', "string")
147     CODE_CASE(d, 'd', 'e', "decFloat")
148     CODE_CASE(d, 'c', 'h', "char")
149     default:
150       CantHappen("bad string op");
151     }
152   FreeSymbol::getDataAttachments(opDeclaration, purposes, data);
153 }
154 
void
StringOpSymbol::getSymbolAttachments(Vector<const char*>& purposes,
				     Vector<Symbol*>& symbols)
{
  //
  //	Report our symbol attachments (the inverse of attachSymbol()),
  //	then those of FreeSymbol.  Entries are presumably appended in call
  //	order, so the order of these lines fixes the output order; whether
  //	unbound (null) symbols are skipped depends on APPEND_SYMBOL --
  //	confirm in bindingMacros.hh.
  //
  APPEND_SYMBOL(purposes, symbols, stringSymbol);
  APPEND_SYMBOL(purposes, symbols, succSymbol);
  APPEND_SYMBOL(purposes, symbols, minusSymbol);
  APPEND_SYMBOL(purposes, symbols, divisionSymbol);
  APPEND_SYMBOL(purposes, symbols, floatSymbol);
  APPEND_SYMBOL(purposes, symbols, decFloatSymbol);
  FreeSymbol::getSymbolAttachments(purposes, symbols);
}
167 
void
StringOpSymbol::getTermAttachments(Vector<const char*>& purposes,
				   Vector<Term*>& terms)
{
  //
  //	Report our term attachments (the inverse of attachTerm()), then
  //	those of FreeSymbol.  The order of these lines presumably fixes
  //	the order of the reported entries.
  //
  APPEND_TERM(purposes, terms, trueTerm);
  APPEND_TERM(purposes, terms, falseTerm);
  APPEND_TERM(purposes, terms, notFoundTerm);
  FreeSymbol::getTermAttachments(purposes, terms);
}
177 
178 void
postInterSymbolPass()179 StringOpSymbol::postInterSymbolPass()
180 {
181   PREPARE_TERM(trueTerm);
182   PREPARE_TERM(falseTerm);
183   PREPARE_TERM(notFoundTerm);
184 }
185 
186 void
reset()187 StringOpSymbol::reset()
188 {
189   trueTerm.reset();  // so true dag can be garbage collected
190   falseTerm.reset();  // so false dag can be garbage collected
191   notFoundTerm.reset();  // so notFound dag can be garbage collected
192   FreeSymbol::reset();  // parents reset() tasks
193 }
194 
195 bool
eqRewrite(DagNode * subject,RewritingContext & context)196 StringOpSymbol::eqRewrite(DagNode* subject, RewritingContext& context)
197 {
198   Assert(this == subject->symbol(), "bad symbol");
199   DebugAdvisory("StringOpSymbol::eqRewrite() called on " << subject);
200   int nrArgs = arity();
201   FreeDagNode* d = safeCast(FreeDagNode*, subject);
202   //
203   //	Evaluate our arguments.
204   //
205   for (int i = 0; i < nrArgs; i++)
206     {
207       DagNode* a = d->getArgument(i);
208       a->reduce(context);
209     }
210   DagNode* a0 = d->getArgument(0);
211   if (a0->symbol() == stringSymbol)
212     {
213       const Rope& left = safeCast(StringDagNode*, a0)->getValue();
214       switch (nrArgs)
215 	{
216 	case 1:
217 	  {
218 	    mpz_class r;
219 	    switch (op)
220 	      {
221 	      case CODE('f', 'l'):
222 		{
223 		  bool error;
224 		  char* flStr = left.makeZeroTerminatedString();
225 		  double fl = stringToDouble(flStr, error);
226 		  delete [] flStr;
227 
228 		  if (error)
229 		    goto fail;
230 		  return floatSymbol->rewriteToFloat(subject, context, fl);
231 		}
232 	      case CODE('l', 'e'):  // length
233 		{
234 		  r = left.length();
235 		  break;
236 		}
237 	      case CODE('a', 's'):  // acsii
238 		{
239 		  if (left.length() != 1)
240 		    goto fail;
241 		  r = static_cast<unsigned char>(left[0]);
242 		  break;
243 		}
244 	      default:
245 		CantHappen("bad string op");
246 	      }
247 	    return succSymbol->rewriteToNat(subject, context, r);
248 	  }
249 	case 2:
250 	  {
251 	    DagNode* a1 = d->getArgument(1);
252 	    if (a1->symbol() == stringSymbol)
253 	      {
254 		const Rope& right = safeCast(StringDagNode*, a1)->getValue();
255 		bool r;
256 		switch (op)
257 		  {
258 		  case '+':
259 		    {
260 		      Rope t(left);
261 		      t += right;
262 		      return rewriteToString(subject, context, t);
263 		    }
264 		  case '<':
265 		    r = left < right;
266 		    break;
267 		  case '>':
268 		    r = left > right;
269 		    break;
270 		  case CODE('<', '='):
271 		    r = left <= right;
272 		    break;
273 		  case CODE('>', '='):
274  		    r = left >= right;
275 		    break;
276 		  default:
277 		    CantHappen("bad string op");
278 		    r = false;  // avoid compiler warning
279 		  }
280 		Assert(trueTerm.getTerm() != 0 && falseTerm.getTerm() != 0,
281 		       "null true/false for relational op");
282 		return context.builtInReplace(subject, r ? trueTerm.getDag() : falseTerm.getDag());
283 	      }
284 	    else if (op == CODE('r', 'a'))
285 	      {
286 		DebugAdvisory("StringOpSymbol::eqRewrite() entered rat case for " << subject);
287 		DagNode* a1 = d->getArgument(1);
288 		Assert(succSymbol != 0, "succSymbol undefined");
289 		if (succSymbol->isNat(a1))
290 		  {
291 		    const mpz_class& n1 = succSymbol->getNat(a1);
292 		    if (n1 >= 2 && n1 <= 36)
293 		      {
294 			mpz_class numerator;
295 			mpz_class denominator;
296 			if (ropeToNumber(left, n1.get_si(), numerator, denominator))
297 			  {
298 			    DagNode* r;
299 			    if (denominator == 0)
300 			      {
301 				if (numerator >= 0)
302 				  return succSymbol->rewriteToNat(subject, context, numerator);
303 				r = minusSymbol->makeNegDag(numerator);
304 			      }
305 			    else
306 			      r = divisionSymbol->makeRatDag(numerator, denominator);
307 			    return context.builtInReplace(subject, r);
308 			  }
309 			else
310 			  DebugAdvisory("StringOpSymbol::eqRewrite() rope to number failed " << subject);
311 		      }
312 		    else
313 		      DebugAdvisory("StringOpSymbol::eqRewrite() a1 out of range " << subject);
314 		  }
315 		else
316 		  DebugAdvisory("StringOpSymbol::eqRewrite() a1 not a nat " << subject);
317 		DebugAdvisory("StringOpSymbol::eqRewrite() failed to rewrite " << subject);
318 	      }
319 	    break;
320 	  }
321 	case 3:
322 	  {
323 	    switch (op)
324 	      {
325 	      case CODE('s', 'u'):  // substr
326 		{
327 		  DagNode* a1 = d->getArgument(1);
328 		  DagNode* a2 = d->getArgument(2);
329 		  Assert(succSymbol != 0, "succSymbol undefined");
330 		  if (succSymbol->isNat(a1) && succSymbol->isNat(a2))
331 		    {
332 		      const mpz_class& n1 = succSymbol->getNat(a1);
333 		      Uint index = n1.fits_uint_p() ? n1.get_ui() : UINT_MAX;
334 		      const mpz_class& n2 = succSymbol->getNat(a2);
335 		      Uint length = n2.fits_uint_p() ? n2.get_ui() : UINT_MAX;
336 		      return rewriteToString(subject, context, substring(left, index, length));
337 		    }
338 		  break;
339 		}
340 	      default:
341 		{
342 		  DagNode* a1 = d->getArgument(1);
343 		  if (a1->symbol() == stringSymbol)
344 		    {
345 		      const Rope& pattern = safeCast(StringDagNode*, a1)->getValue();
346 		      DagNode* a2 = d->getArgument(2);
347 		      Assert(succSymbol != 0, "succSymbol undefined");
348 		      if (succSymbol->isNat(a2))
349 			{
350 			  const mpz_class& n2 = succSymbol->getNat(a2);
351 			  Uint index = n2.fits_uint_p() ? n2.get_ui() : UINT_MAX;
352 			  int r;
353 			  switch (op)
354 			    {
355 			    case CODE('f', 'i'):  // find
356 			      r = fwdFind(left, pattern, index);
357 			      break;
358 			    case CODE('r', 'f'):  // rfind
359 			      r = revFind(left, pattern, index);
360 			      break;
361 			    default:
362 			      CantHappen("bad string op");
363 			      r = 0;  // avoid compiler warning
364 			    }
365 			  Assert(notFoundTerm.getTerm() != 0, "null notFound for find op");
366 			  if (r == NONE)
367 			    return context.builtInReplace(subject, notFoundTerm.getDag());
368 			  return succSymbol->rewriteToNat(subject, context, r);
369 			}
370 		    }
371 		  break;
372 		}
373 	      }
374 	  }
375 	}
376     }
377   else if (a0->symbol() == floatSymbol)
378     {
379       if (nrArgs == 1 && op == CODE('s', 't'))
380 	{
381 	  double fl = safeCast(FloatDagNode*, a0)->getValue();
382 	  return rewriteToString(subject, context, doubleToString(fl));
383 	}
384       else if (nrArgs == 2 && op == CODE('d', 'e'))
385 	{
386 	  DagNode* a1 = d->getArgument(1);
387 	  Assert(succSymbol != 0, "succSymbol undefined");
388 	  Assert(minusSymbol != 0, "minusSymbol undefined");
389 	  if (succSymbol->isNat(a1))
390 	    {
391 	      double fl = safeCast(FloatDagNode*, a0)->getValue();
392 	      const mpz_class& n1 = succSymbol->getNat(a1);
393 	      int nrDigits = (n1 < MAX_FLOAT_DIGITS) ? n1.get_si() : MAX_FLOAT_DIGITS;
394 	      char buffer[MAX_FLOAT_DIGITS + 1];
395 	      int decPt;
396 	      int sign;
397 	      correctEcvt(fl, nrDigits, buffer, decPt, sign);
398 	      Vector<DagNode*> args(0, 3);
399 	      args.append((sign < 0) ? minusSymbol->makeNegDag(sign) :
400 			  succSymbol->makeNatDag(sign));
401 	      args.append(new StringDagNode(stringSymbol, buffer));
402 	      args.append((decPt < 0) ? minusSymbol->makeNegDag(decPt) :
403 			  succSymbol->makeNatDag(decPt));
404 	      return context.builtInReplace(subject, decFloatSymbol->makeDagNode(args));
405 	    }
406 	}
407     }
408   else if (op == CODE('s', 't') && nrArgs == 2)
409     {
410       DagNode* a1 = d->getArgument(1);
411       Assert(succSymbol != 0, "succSymbol undefined");
412       if (succSymbol->isNat(a1))
413 	{
414 	  const mpz_class& n1 = succSymbol->getNat(a1);
415 	  if (n1 >= 2 && n1 <= 36)
416 	    {
417 	      int base = n1.get_si();
418 	      if (succSymbol->isNat(a0))
419 		{
420 		  if (succSymbol->isNat(a0))
421 		    {
422 		      char* ts = mpz_get_str(0, base, succSymbol->getNat(a0).get_mpz_t());
423 		      Rope tr(ts);
424 		      free(ts);
425 		      return rewriteToString(subject, context, tr);
426 		    }
427 		}
428 	      else if (a0->symbol() == minusSymbol)
429 		{
430 		  if (minusSymbol->isNeg(a0))
431 		    {
432 		      mpz_class result;
433 		      char* ts =
434 			mpz_get_str(0, base, minusSymbol->getNeg(a0, result).get_mpz_t());
435 		      Rope tr(ts);
436 		      free(ts);
437 		      return rewriteToString(subject, context, tr);
438 		    }
439 		}
440 	      else if (a0->symbol() == divisionSymbol)
441 		{
442 		  if (divisionSymbol->isRat(a0))
443 		    {
444 		      mpz_class numerator;
445 		      const mpz_class& denomenator = divisionSymbol->getRat(a0, numerator);
446 		      char* ns = mpz_get_str(0, base, numerator.get_mpz_t());
447 		      Rope tr(ns);
448 		      free(ns);
449 		      tr += '/';
450 		      char* ds = mpz_get_str(0, base, denomenator.get_mpz_t());
451 		      tr += ds;
452 		      free(ds);
453 		      return rewriteToString(subject, context, tr);
454 		    }
455 		}
456 	    }
457 	}
458     }
459   else
460     {
461       switch (op)
462 	{
463 	case CODE('c', 'h'):  // char
464 	  {
465 	    DagNode* a0 = d->getArgument(0);
466 	    Assert(succSymbol != 0, "succSymbol undefined");
467 	    if (succSymbol->isNat(a0))
468 	      {
469 		const mpz_class& n0 = succSymbol->getNat(a0);
470 		if (n0 <= 255)
471 		  {
472 		    char c = n0.get_si();
473 		    return rewriteToString(subject, context, Rope(c));
474 		  }
475 	      }
476 	    break;
477 	  }
478 	default:
479 	  ;  // Can get here if args are bad
480 	}
481     }
482  fail:
483   return FreeSymbol::eqRewrite(subject, context);
484 }
485 
486 bool
rewriteToString(DagNode * subject,RewritingContext & context,const Rope & result)487 StringOpSymbol::rewriteToString(DagNode* subject, RewritingContext& context, const Rope& result)
488 {
489   bool trace = RewritingContext::getTraceStatus();
490   if (trace)
491     {
492       context.tracePreEqRewrite(subject, 0, RewritingContext::BUILTIN);
493       if (context.traceAbort())
494 	return false;
495     }
496   (void) new(subject) StringDagNode(stringSymbol, result);
497   context.incrementEqCount();
498   if (trace)
499     context.tracePostEqRewrite(subject);
500   return true;
501 }
502 
503 Rope
substring(const Rope & subject,Rope::size_type index,Rope::size_type length)504 StringOpSymbol::substring(const Rope& subject, Rope::size_type index, Rope::size_type length)
505 {
506   Rope::size_type sLen = subject.length();
507   //  if (index < 0)
508   //    {
509   //      if (length > 0)
510   //	    length += index;
511   //      index = 0;
512   //    }
513   if (length == 0 || index >= sLen)
514     return Rope();
515   if (length > sLen - index)
516     length = sLen - index;
517   return subject.substr(index, length);
518 }
519 
520 int
fwdFind(const Rope & subject,const Rope & pattern,Rope::size_type start)521 StringOpSymbol::fwdFind(const Rope& subject, const Rope& pattern, Rope::size_type start)
522 {
523   Rope::size_type sLen = subject.length();
524   if (pattern.empty())
525     return (start <= sLen) ? static_cast<int>(start) : NONE;
526   //
527   //	Testing start < sLen is important because otherwise 2nd test
528   //	could succeed by wrap around.
529   //
530   if (start < sLen && start + pattern.length() <= sLen)
531     {
532       Rope::const_iterator b(subject.begin());
533       Rope::const_iterator e(subject.end());
534       Rope::const_iterator p(search(b + start, e, pattern.begin(), pattern.end()));
535       if (p != e)
536 	return p - b;
537     }
538   return NONE;
539 }
540 
541 int
revFind(const Rope & subject,const Rope & pattern,Rope::size_type start)542 StringOpSymbol::revFind(const Rope& subject, const Rope& pattern, Rope::size_type start)
543 {
544   Rope::size_type sLen = subject.length();
545   if (pattern.empty())
546     return (start <= sLen) ? start : sLen;
547   Rope::size_type pLen = pattern.length();
548   if (pLen <= sLen)
549     {
550       Rope::size_type reflect = sLen - pLen;  // pattern can't start after this since we need pLen characters.
551       if (start > reflect)
552 	start = reflect;
553       //
554       //	We are going to search the subject from beginning to beginning + start + pLen - 1
555 
556       Rope::const_iterator b(subject.begin());
557       Rope::const_iterator e(b + (start + pLen));
558       Rope::const_iterator p = find_end(b, e, pattern.begin(), pattern.end());
559       if (p != e)
560 	return p - b;
561     }
562   return NONE;
563 }
564 
565 bool
ropeToNumber(const Rope & subject,int base,mpz_class & numerator,mpz_class & denominator)566 StringOpSymbol::ropeToNumber(const Rope& subject,
567 			     int base,
568 			     mpz_class& numerator,
569 			     mpz_class& denominator)
570 {
571   int len = subject.length();
572   if (len == 0)
573     return false;
574   int i = 0;
575   if (subject[i] == '-')
576     {
577       if (len == 1)
578 	return false;
579       ++i;
580     }
581   char c = subject[i];
582   if (!isalnum(c) || (c == '0' && len > 1))
583     return false;
584   for (i++; i < len; i++)
585     {
586       char c = subject[i];
587       if (!isalnum(c))
588 	{
589 	  if (c == '/')
590 	    {
591 	      int j = i + 1;
592 	      if (j == len || subject[j] == '0')
593 		return false;
594 	      for (; j < len; j++)
595 		{
596 		  if (!isalnum(subject[j]))
597 		    return false;
598 		}
599 	      //
600 	      //	We have detected a fraction form.
601 	      //
602 	      char* numStr = subject.substr(0, i).makeZeroTerminatedString();
603 	      char* denomStr = subject.substr(i + 1, len - (i + 1)).makeZeroTerminatedString();
604 	      bool result = (mpz_set_str(denominator.get_mpz_t(), denomStr, base) == 0 &&
605 			     mpz_set_str(numerator.get_mpz_t(), numStr, base) == 0);
606 	      delete [] numStr;
607 	      delete [] denomStr;
608 	      return result;
609 	    }
610 	  else
611 	    return false;
612 	}
613     }
614   //
615   //	We have a regular integer form.
616   //
617   denominator = 0;
618   char* numStr = subject.makeZeroTerminatedString();
619   bool result = (mpz_set_str(numerator.get_mpz_t(), numStr, base) == 0);
620   delete [] numStr;
621   return result;
622 }
623