1 /*
2 
3     This file is part of the Maude 2 interpreter.
4 
5     Copyright 1997-2003 SRI International, Menlo Park, CA 94025, USA.
6 
7     This program is free software; you can redistribute it and/or modify
8     it under the terms of the GNU General Public License as published by
9     the Free Software Foundation; either version 2 of the License, or
10     (at your option) any later version.
11 
12     This program is distributed in the hope that it will be useful,
13     but WITHOUT ANY WARRANTY; without even the implied warranty of
14     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15     GNU General Public License for more details.
16 
17     You should have received a copy of the GNU General Public License
18     along with this program; if not, write to the Free Software
19     Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
20 
21 */
22 
23 //
24 //      Implementation for class S_Symbol
25 //
26 #include <map>
27 
28 //	utility stuff
29 #include "macros.hh"
30 #include "vector.hh"
31 
32 //      forward declarations
33 #include "interface.hh"
34 #include "core.hh"
35 #include "variable.hh"
36 #include "S_Theory.hh"
37 
38 //      core class definitions
39 #include "hashConsSet.hh"
40 
41 //	successor theory class definitions
42 #include "S_Symbol.hh"
43 #include "S_DagNode.hh"
44 #include "S_Term.hh"
45 #include "S_ExtensionInfo.hh"
46 
S_Symbol(int id,const Vector<int> & strategy,bool memoFlag)47 S_Symbol::S_Symbol(int id, const Vector<int>& strategy, bool memoFlag)
48   : Symbol(id, 1, memoFlag)
49 {
50   setStrategy(strategy, 1, memoFlag);
51 }
52 
53 void
compileOpDeclarations()54 S_Symbol::compileOpDeclarations()
55 {
56   Symbol::compileOpDeclarations();  // do default sort processing
57   const ConnectedComponent* kind = rangeComponent();
58   int nrSorts =  kind->nrSorts();
59   sortPathTable.resize(nrSorts);
60   for (int i = 0; i < nrSorts; i++)
61     computePath(i, sortPathTable[i]);
62 }
63 
64 void
finalizeSortInfo()65 S_Symbol::finalizeSortInfo()
66 {
67   Symbol::finalizeSortInfo();  // do parent classes stuff
68   WarningCheck(kindLevelDeclarationsOnly() || getSortConstraints().empty(),
69 	       "membership axioms are not guaranteed to work correctly for iterated symbol " <<
70 	       QUOTE(this) << " as it has declarations that are not at the kind level.");
71 }
72 
73 void
computePath(int sortIndex,SortPath & path)74 S_Symbol::computePath(int sortIndex, SortPath& path)
75 {
76   path.nonCtorBound = NONE;
77   bool needToLookForNonCtor = false;
78   switch (getCtorStatus())
79     {
80     case SortTable::IS_NON_CTOR:
81       {
82 	path.nonCtorBound = 0;
83 	break;
84       }
85     case SortTable::IS_COMPLEX:
86       {
87 	needToLookForNonCtor = true;
88 	break;
89       }
90     }
91 
92   map<int, int> seen;
93   Vector<int> indices;
94   for (;;)
95     {
96       if (needToLookForNonCtor && !(ctorTraverse(0, sortIndex)))
97 	{
98 	  path.nonCtorBound = indices.length();
99 	  needToLookForNonCtor = false;
100 	}
101       sortIndex = traverse(0, sortIndex);
102       if (seen.find(sortIndex) != seen.end())
103 	{
104 	  path.leadLength = seen[sortIndex];
105 	  break;
106 	}
107 
108       seen[sortIndex] = indices.length();
109       indices.append(sortIndex);
110     }
111   path.sortIndices = indices;
112 }
113 
114 DagNode*
ruleRewrite(DagNode * subject,RewritingContext & context)115 S_Symbol::ruleRewrite(DagNode* subject, RewritingContext& context)
116 {
117   S_ExtensionInfo extensionInfo(safeCast(S_DagNode*, subject));
118   return applyRules(subject, context, &extensionInfo);
119 }
120 
121 Term*
makeTerm(const Vector<Term * > & args)122 S_Symbol::makeTerm(const Vector<Term*>& args)
123 {
124   Assert(args.length() == 1, "bad number of arguments");
125   Assert(args[0] != 0, "null argument");
126   return new S_Term(this, 1, args[0]);
127 }
128 
129 DagNode*
makeDagNode(const Vector<DagNode * > & args)130 S_Symbol::makeDagNode(const Vector<DagNode*>& args)
131 {
132   Assert(args.length() == 1, "bad number of arguments");
133   Assert(args[0] != 0, "null argument");
134   return new S_DagNode(this, 1, args[0]);
135 }
136 
137 bool
eqRewrite(DagNode * subject,RewritingContext & context)138 S_Symbol::eqRewrite(DagNode* subject, RewritingContext& context)
139 {
140   Assert(this == subject->symbol(), "bad symbol");
141   S_DagNode* s = safeCast(S_DagNode*, subject);
142   if (standardStrategy())
143     {
144       //
145       //	Fast eager strategy case.
146       //
147       s->arg->reduce(context);
148       s->normalizeAtTop();  // always needed because shared node may have rewritten
149       if (equationFree())
150 	return false;
151       S_ExtensionInfo extensionInfo(s);
152       return applyReplace(subject, context, &extensionInfo);
153     }
154   if (isMemoized())
155     {
156       //
157       //	Memoized case - get the reduced form and enter
158       //	it in the memoization table.
159       //
160       MemoTable::SourceSet from;
161       memoStrategy(from, subject, context);
162       memoEnter(from, subject);
163       return false;
164     }
165   //
166   //	Complex strategy case.
167   //
168   S_ExtensionInfo extensionInfo(s);
169   const Vector<int>& userStrategy = getStrategy();
170   int stratLen = userStrategy.length();
171   bool seenZero = false;
172 
173   for (int i = 0; i < stratLen; i++)
174     {
175       if(userStrategy[i] == 0)
176 	{
177 	  if (!seenZero)
178 	    {
179 	      s->arg->computeTrueSort(context);
180 	      seenZero = true;
181 	    }
182 	  s->normalizeAtTop();
183 	  if ((i + 1 == stratLen) ?
184 	      applyReplace(subject, context, &extensionInfo) :
185 	      applyReplaceNoOwise(subject, context, &extensionInfo))
186 	      return true;
187 	}
188       else
189 	{
190 	  if (seenZero)
191 	    {
192 	      s->arg->copyReducible();
193 	      //
194 	      //	A previous call to applyReplace() may have
195 	      //	computed a true sort for our subject which will be
196 	      //	invalidated by the reduce we are about to do.
197 	      //
198 	      subject->repudiateSortInfo();
199 	    }
200 	  s->arg->reduce(context);
201 	}
202     }
203   return false;
204 }
205 
206 void
memoStrategy(MemoTable::SourceSet & from,DagNode * subject,RewritingContext & context)207 S_Symbol::memoStrategy(MemoTable::SourceSet& from,
208 		       DagNode* subject,
209 		       RewritingContext& context)
210 {
211   S_DagNode* s = safeCast(S_DagNode*, subject);
212   //
213   //	Execute user supplied strategy.
214   //
215   const Vector<int>& userStrategy = getStrategy();
216   int stratLen = userStrategy.length();
217   bool seenZero = false;
218   for (int i = 0; i < stratLen; i++)
219     {
220       if(userStrategy[i] == 0)
221 	{
222 	  if (!seenZero)
223 	    {
224 	      s->arg->computeTrueSort(context);
225 	      seenZero = true;
226 	    }
227 	  s->normalizeAtTop();
228 	  if (memoRewrite(from, subject, context))
229 	    return;
230 	  S_ExtensionInfo extensionInfo(s);
231 	  if ((i + 1 == stratLen) ?
232 	      applyReplace(subject, context, &extensionInfo) :
233 	      applyReplaceNoOwise(subject, context, &extensionInfo))
234 	    {
235 	      subject->reduce(context);
236 	      return;
237 	    }
238 	}
239       else
240 	{
241 	  if (seenZero)
242 	    {
243 	      s->arg->copyReducible();
244 	      //
245 	      //	A previous call to applyReplace() may have
246 	      //	computed a true sort for our subject which will be
247 	      //	invalidated by the reduce we are about to do.
248 	      //
249 	      subject->repudiateSortInfo();
250 	    }
251 	  s->arg->reduce(context);
252 	}
253     }
254 }
255 
256 void
computeBaseSort(DagNode * subject)257 S_Symbol::computeBaseSort(DagNode* subject)
258 {
259   Assert(this == subject->symbol(), "bad symbol");
260   S_DagNode* s = safeCast(S_DagNode*, subject);
261   int argSortIndex = s->getArgument()->getSortIndex();
262   Assert(argSortIndex != Sort::SORT_UNKNOWN, "unknown sort");
263   subject->setSortIndex(sortPathTable[argSortIndex].computeSortIndex(s->getNumber()));
264 }
265 
266 void
fillInSortInfo(Term * subject)267 S_Symbol::fillInSortInfo(Term* subject)
268 {
269   Assert(this == subject->symbol(), "bad symbol");
270   S_Term* s = safeCast(S_Term*, subject);
271   Term* arg = s->getArgument();
272   arg->symbol()->fillInSortInfo(arg);
273   subject->setSortInfo(rangeComponent(), sortPathTable[arg->getSortIndex()].computeSortIndex(s->getNumber()));
274 }
275 
276 bool
isConstructor(DagNode * subject)277 S_Symbol::isConstructor(DagNode* subject)
278 {
279   Assert(this == subject->symbol(), "bad symbol");
280   S_DagNode* s = safeCast(S_DagNode*, subject);
281   const SortPath& path = sortPathTable[s->arg->getSortIndex()];
282   if (path.nonCtorBound == NONE)
283     return true;
284   const mpz_class& number = *(s->number);
285   return number <= path.nonCtorBound;
286 }
287 
288 void
normalizeAndComputeTrueSort(DagNode * subject,RewritingContext & context)289 S_Symbol::normalizeAndComputeTrueSort(DagNode* subject, RewritingContext& context)
290 {
291   Assert(this == subject->symbol(), "bad symbol");
292   S_DagNode* s = safeCast(S_DagNode*, subject);
293   s->arg->computeTrueSort(context);
294   s->normalizeAtTop();
295   fastComputeTrueSort(subject, context);
296 }
297 
298 void
stackArguments(DagNode * subject,Vector<RedexPosition> & stack,int parentIndex)299 S_Symbol::stackArguments(DagNode* subject,
300 			 Vector<RedexPosition>& stack,
301 			 int parentIndex)
302 {
303   DagNode* arg = safeCast(S_DagNode*, subject)->arg;
304   if (!(getFrozen().contains(0)) && !(arg->isUnstackable()))
305     stack.append(RedexPosition(arg, parentIndex, 0, eagerArgument(0)));
306 }
307 
308 bool
mightCollapseToOurSymbol(const Term * subterm) const309 S_Symbol::mightCollapseToOurSymbol(const Term* subterm) const
310 {
311   const PointerSet& cs = subterm->collapseSymbols();
312   int nrSymbols = cs.cardinality();
313   for (int i = 0; i < nrSymbols; i++)
314     {
315       Symbol* s = static_cast<Symbol*>(cs.index2Pointer(i));
316       if (static_cast<const Symbol*>(s) == this)
317         return true;
318       VariableSymbol* vs = dynamic_cast<VariableSymbol*>(s);
319       if (vs != 0)  // might want to check that vs has big enough sort
320         return true;
321     }
322   return false;
323 }
324 
325 Term*
termify(DagNode * dagNode)326 S_Symbol::termify(DagNode* dagNode)
327 {
328   S_DagNode* d = safeCast(S_DagNode*, dagNode);
329   DagNode* a = d->getArgument();
330   return new S_Term(this, d->getNumber(), a->symbol()->termify(a));
331 }
332 
333 //
334 //	Unification code.
335 //
336 
337 int
unificationPriority() const338 S_Symbol::unificationPriority() const
339 {
340   //
341   //	We don't expect this to be used by current code since there are no S Theory
342   //	unification subproblems.
343   //
344   return 1;
345 }
346 
347 void
computeSortFunctionBdds(const SortBdds &,Vector<Bdd> &) const348 S_Symbol::computeSortFunctionBdds(const SortBdds& /* sortBdds */, Vector<Bdd>& /* sortFunctionBdds */) const
349 {
350   //
351   //	We don't make use of a precomputed sort function since we need to handle stacks of
352   //	symbols efficiently - therefore we don't waste time and space computing one.
353   //
354 }
355 
356 void
computeGeneralizedSort(const SortBdds & sortBdds,const Vector<int> & realToBdd,DagNode * subject,Vector<Bdd> & generalizedSort)357 S_Symbol::computeGeneralizedSort(const SortBdds& sortBdds,
358 				 const Vector<int>& realToBdd,
359 				 DagNode* subject,
360 				 Vector<Bdd>& generalizedSort)
361 {
362   Assert(this == subject->symbol(), "bad symbol");
363   //
364   //	First we compute the generalized sort of our argument.
365   //
366   S_DagNode* s = safeCast(S_DagNode*, subject);
367   DagNode* arg = s->getArgument();
368   const mpz_class& number = s->getNumber();
369   Vector<Bdd> argGenSort;
370   arg->computeGeneralizedSort(sortBdds, realToBdd, argGenSort);
371   //
372   //	Prepare all false generalized output sort vector.
373   //
374   Assert(generalizedSort.empty(), "non-empty generalizedSort");
375   int nrBits = argGenSort.size();
376   generalizedSort.resize(nrBits);
377   //
378   //	The negation of each input BDD will be used at least once
379   //	(otherwise the bit would always be 1 and hence unneeded) and
380   //	thus we calculate them in advance.
381   //
382   Vector<Bdd> negArgGenSort(nrBits);
383   for (int i = 0; i < nrBits; ++i)
384     negArgGenSort[i] = bdd_not(argGenSort[i]);
385   //
386   //	Then for each possible value of this sort we compute
387   //	the index of the sort produced by our iterated functon symbol.
388   //
389   int nrSorts = sortPathTable.size();
390   for (int i = 0; i < nrSorts; ++i)
391     {
392       //
393       //	equal will hold the BDD that returns true when our argument
394       //	has sort index i.
395       //
396       Bdd equal = bddtrue;
397       int inIndex = i;
398       for (int j = 0; j < nrBits; ++j, inIndex >>= 1)
399 	equal = bdd_and(equal, (inIndex & 1) ? argGenSort[j] : negArgGenSort[j]);
400       //
401       //	We compute the output sort index and OR the equal BDD into each
402       //	output BDD whose corrresponding bit is 1 in the output sort index.
403       //
404       int outIndex = sortPathTable[i].computeSortIndex(number);
405       for (int j = 0; j < nrBits; ++j, outIndex >>= 1)
406 	{
407 	  if (outIndex & 1)
408 	    generalizedSort[j] = bdd_or(generalizedSort[j], equal);
409 	}
410     }
411 }
412 
413 // experimental code for faster sort computations
414 void
computeGeneralizedSort2(const SortBdds & sortBdds,const Vector<int> & realToBdd,DagNode * subject,Vector<Bdd> & outputBdds)415 S_Symbol::computeGeneralizedSort2(const SortBdds& sortBdds,
416 				    const Vector<int>& realToBdd,
417 				    DagNode* subject,
418 				    Vector<Bdd>& outputBdds)
419 {
420   Assert(this == subject->symbol(), "bad symbol");
421   //
422   //	First we compute the generalized sort of our argument.
423   //
424   S_DagNode* s = safeCast(S_DagNode*, subject);
425   DagNode* arg = s->getArgument();
426   Vector<Bdd> inputBdds;
427   arg->computeGeneralizedSort2(sortBdds, realToBdd, inputBdds);
428   //
429   //	Since the range and domain of our operator is necessarily
430   //	the same kind, the number of bits encoding sorts will be
431   //	the same, and we can just get this from our inputBdds.
432   //
433   int nrBits = inputBdds.size();
434   //
435   //	The challenge here is that our exponent may be huge so
436   //	we cannot iterate our symbolic BDD sort function.
437   //
438   //	Instead we rely on the assumption that the number of input
439   //	sorts is relatively small.
440   //	We go through evey possible sort and apply the iterated (ground)
441   //	sort function; we then build the symbolic result as a case
442   //	split on input sort.
443   //
444   const mpz_class& number = s->getNumber();
445   //
446   //	The negation of each input BDD will be used at least once
447   //	(otherwise the bit would always be 1 and hence unneeded) and
448   //	thus we calculate them in advance.
449   //
450   Vector<Bdd> negatedInputBdds(nrBits);
451   for (int i = 0; i < nrBits; ++i)
452     negatedInputBdds[i] = bdd_not(inputBdds[i]);
453   //
454   //	We build the symbolic output sort in a stand alone vector.
455   //
456   Vector<Bdd> resultBdds(nrBits);
457   //
458   //	For each possible index that the input sort could have we compute
459   //	the index of the sort produced by our iterated functon symbol.
460   //
461   int nrSorts = sortPathTable.size();
462   for (int i = 0; i < nrSorts; ++i)
463     {
464       //
465       //	equalBdd will hold the BDD that returns true when our argument
466       //	(whose sort depends on variables whose sort is defined by our BDD
467       //	variables) has sort index i.
468       //
469       Bdd equalBdd = bdd_true();
470       int inIndex = i;
471       for (int j = 0; j < nrBits; ++j, inIndex >>= 1)
472 	equalBdd = bdd_and(equalBdd, (inIndex & 1) ? inputBdds[j] : negatedInputBdds[j]);
473       //
474       //	We now do a ground sort computation to determined what our
475       //	output sort index is with input sort index i.
476       //
477       int outIndex = sortPathTable[i].computeSortIndex(number);
478       //
479       //	Finally we want to add the case:
480       //	  If input sort has index i, then output sort has index outIndex.
481       //	For each bit position, we say it the corresponding bit of outIndex if
482       //	equal is true and 0 otherwise.
483       //	And we want to OR these results for each bit position at we iterate
484       //	over all input sorts in our outer loop.
485       //
486       //	Thus for each bitPosition j, we need to OR in 1 iff
487       //	outIndex has a 1 and equal is true. Of course equal is symbolic
488       //	in terms of our BDD variables, so we OR in equal whenever
489       //	outIndex has a 1 bit.
490       //
491       for (int j = 0; j < nrBits; ++j, outIndex >>= 1)
492 	{
493 	  if (outIndex & 1)
494 	    resultBdds[j] = bdd_or(resultBdds[j], equalBdd);
495 	}
496     }
497   //
498   //	Finally we append our result BDDs to the output BDDs.
499   //
500   FOR_EACH_CONST(i, Vector<Bdd>, resultBdds)
501     outputBdds.append(*i);
502 }
503 
504 bool
isStable() const505 S_Symbol::isStable() const
506 {
507   return true;
508 }
509 
510 //
511 //	Hash cons code.
512 //
513 
514 DagNode*
makeCanonical(DagNode * original,HashConsSet * hcs)515 S_Symbol::makeCanonical(DagNode* original, HashConsSet* hcs)
516 {
517   S_DagNode* s = safeCast(S_DagNode*, original);
518   DagNode* d = s->getArgument();
519   DagNode* c = hcs->getCanonical(hcs->insert(d));
520   if (c == d)
521     return original;  // can use the original dag node as the canonical version
522   //
523   //	Need to make new node.
524   //
525   S_DagNode* n = new S_DagNode(this, s->getNumber(), c);
526   n->copySetRewritingFlags(original);
527   n->setSortIndex(original->getSortIndex());
528   return n;
529 }
530 
531 DagNode*
makeCanonicalCopy(DagNode * original,HashConsSet * hcs)532 S_Symbol::makeCanonicalCopy(DagNode* original, HashConsSet* hcs)
533 {
534   //
535   //	We have a unreduced node - copy forced.
536   //
537   S_DagNode* s = safeCast(S_DagNode*, original);
538   DagNode* c = hcs->getCanonical(hcs->insert(s->getArgument()));
539   S_DagNode* n = new S_DagNode(this, s->getNumber(), c);
540   n->copySetRewritingFlags(original);
541   n->setSortIndex(original->getSortIndex());
542   return n;
543 }
544