1 /*
2
3 This file is part of the Maude 2 interpreter.
4
5 Copyright 2004 SRI International, Menlo Park, CA 94025, USA.
6
7 This program is free software; you can redistribute it and/or modify
8 it under the terms of the GNU General Public License as published by
9 the Free Software Foundation; either version 2 of the License, or
10 (at your option) any later version.
11
12 This program is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with this program; if not, write to the Free Software
19 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA.
20
21 */
22
//
//	Implementation for class MatrixOpSymbol.
//
26 #include <iterator>
27
28 // utility stuff
29 #include "macros.hh"
30 #include "vector.hh"
31 #include "mpzSystem.hh"
32
33 // forward declarations
34 #include "interface.hh"
35 #include "core.hh"
36 #include "freeTheory.hh"
37 #include "ACU_Persistent.hh"
38 #include "ACU_Theory.hh"
39 #include "NA_Theory.hh"
40 #include "builtIn.hh"
41
42 // interface class definitions
43 #include "term.hh"
44
45 // core class definitions
46 #include "rewritingContext.hh"
47 #include "dagArgumentIterator.hh"
48 #include "symbolMap.hh"
49
50 // free theory class definitions
51 #include "freeSymbol.hh"
52 #include "freeDagNode.hh"
53
54 // ACU theory class definitions
55 #include "ACU_Symbol.hh"
56 #include "ACU_DagNode.hh"
57
58 // built in stuff
59 #include "bindingMacros.hh"
60 #include "succSymbol.hh"
61 #include "minusSymbol.hh"
62 #include "stringSymbol.hh"
63 #include "stringDagNode.hh"
64 #include "matrixOpSymbol.hh"
65
MatrixOpSymbol(int id,int arity)66 MatrixOpSymbol::MatrixOpSymbol(int id, int arity)
67 : NumberOpSymbol(id, arity)
68 {
69 #define MACRO(SymbolName, SymbolClass, RequiredFlags, NrArgs) \
70 SymbolName = 0;
71 #include "matrixOpSignature.cc"
72 #undef MACRO
73 }
74
75 bool
attachData(const Vector<Sort * > & opDeclaration,const char * purpose,const Vector<const char * > & data)76 MatrixOpSymbol::attachData(const Vector<Sort*>& opDeclaration,
77 const char* purpose,
78 const Vector<const char*>& data)
79 {
80 if (strcmp(purpose, "MatrixOpSymbol") == 0)
81 return true;
82 return NumberOpSymbol::attachData(opDeclaration, purpose, data);
83 }
84
85 bool
attachSymbol(const char * purpose,Symbol * symbol)86 MatrixOpSymbol::attachSymbol(const char* purpose, Symbol* symbol)
87 {
88 #define MACRO(SymbolName, SymbolClass, RequiredFlags, NrArgs) \
89 BIND_SYMBOL(purpose, symbol, SymbolName, SymbolClass*);
90 #include "matrixOpSignature.cc"
91 #undef MACRO
92 return NumberOpSymbol::attachSymbol(purpose, symbol);
93 }
94
95 void
copyAttachments(Symbol * original,SymbolMap * map)96 MatrixOpSymbol::copyAttachments(Symbol* original, SymbolMap* map)
97 {
98 MatrixOpSymbol* orig = safeCast(MatrixOpSymbol*, original);
99 #define MACRO(SymbolName, SymbolClass, RequiredFlags, NrArgs) \
100 COPY_SYMBOL(orig, SymbolName, map, SymbolClass*);
101 #include "matrixOpSignature.cc"
102 #undef MACRO
103 NumberOpSymbol::copyAttachments(original, map);
104 }
105
106 void
getDataAttachments(const Vector<Sort * > & opDeclaration,Vector<const char * > & purposes,Vector<Vector<const char * >> & data)107 MatrixOpSymbol::getDataAttachments(const Vector<Sort*>& opDeclaration,
108 Vector<const char*>& purposes,
109 Vector<Vector<const char*> >& data)
110 {
111 int nrDataAttachments = purposes.length();
112 purposes.resize(nrDataAttachments + 1);
113 purposes[nrDataAttachments] = "MatrixOpSymbol";
114 data.resize(nrDataAttachments + 1);
115 (data[nrDataAttachments]).resize(1);
116 data[nrDataAttachments][0] = "natSystemSolve";
117 NumberOpSymbol::getDataAttachments(opDeclaration, purposes, data);
118 }
119
120 void
getSymbolAttachments(Vector<const char * > & purposes,Vector<Symbol * > & symbols)121 MatrixOpSymbol::getSymbolAttachments(Vector<const char*>& purposes,
122 Vector<Symbol*>& symbols)
123 {
124 #define MACRO(SymbolName, SymbolClass, RequiredFlags, NrArgs) \
125 APPEND_SYMBOL(purposes, symbols, SymbolName);
126 #include "matrixOpSignature.cc"
127 #undef MACRO
128 NumberOpSymbol::getSymbolAttachments(purposes, symbols);
129 }
130
131 bool
downMatrixEntry(DagNode * dagNode,SparseMatrix & matrix,int & maxRowNr,int & maxColNr)132 MatrixOpSymbol::downMatrixEntry(DagNode* dagNode, SparseMatrix& matrix, int& maxRowNr, int& maxColNr)
133 {
134 if (dagNode->symbol() == matrixEntrySymbol)
135 {
136 FreeDagNode* d = safeCast(FreeDagNode*, dagNode);
137 DagNode* arg = d->getArgument(0);
138 if (arg->symbol() == indexPairSymbol)
139 {
140 FreeDagNode* a = safeCast(FreeDagNode*, arg);
141 int rowNr;
142 int colNr;
143 if (getSuccSymbol()->getSignedInt(a->getArgument(0), rowNr) &&
144 getSuccSymbol()->getSignedInt(a->getArgument(1), colNr) &&
145 getNumber(d->getArgument(1), matrix[rowNr][colNr]))
146 {
147 if (rowNr > maxRowNr)
148 maxRowNr = rowNr;
149 if (colNr > maxColNr)
150 maxColNr = colNr;
151 return true;
152 }
153 }
154 }
155 return false;
156 }
157
158 bool
downMatrix(DagNode * dagNode,SparseMatrix & matrix,int & maxRowNr,int & maxColNr)159 MatrixOpSymbol::downMatrix(DagNode* dagNode, SparseMatrix& matrix, int& maxRowNr, int& maxColNr)
160 {
161 Symbol* s = dagNode->symbol();
162 if (s == matrixSymbol)
163 {
164 for (DagArgumentIterator i(dagNode); i.valid(); i.next())
165 {
166 if (!downMatrixEntry(i.argument(), matrix, maxRowNr, maxColNr))
167 return false;
168 }
169 }
170 else if (s != emptyMatrixSymbol)
171 return downMatrixEntry(dagNode, matrix, maxRowNr, maxColNr);
172 return true;
173 }
174
175 bool
downVectorEntry(DagNode * dagNode,IntVec & vec,int & maxRowNr)176 MatrixOpSymbol::downVectorEntry(DagNode* dagNode, IntVec& vec, int& maxRowNr)
177 {
178 if (dagNode->symbol() == vectorEntrySymbol)
179 {
180 FreeDagNode* d = safeCast(FreeDagNode*, dagNode);
181 int index;
182 if (getSuccSymbol()->getSignedInt(d->getArgument(0), index))
183 {
184 if (index > maxRowNr)
185 {
186 vec.resize(index + 1);
187 for (int i = maxRowNr + 1; i < index; ++i)
188 vec[i] = 0;
189 maxRowNr = index;
190 }
191 if (getNumber(d->getArgument(1), vec[index]))
192 return true;
193 }
194 }
195 return false;
196 }
197
198 bool
downVector(DagNode * dagNode,IntVec & vec,int & maxRowNr)199 MatrixOpSymbol::downVector(DagNode* dagNode, IntVec& vec, int& maxRowNr)
200 {
201 vec.resize(maxRowNr + 1);
202 for (int i = 0; i <= maxRowNr; ++i)
203 vec[i] = 0;
204 Symbol* s = dagNode->symbol();
205 if (s == vectorSymbol)
206 {
207 for (DagArgumentIterator i(dagNode); i.valid(); i.next())
208 {
209 if (!downVectorEntry(i.argument(), vec, maxRowNr))
210 return false;
211 }
212 }
213 else if (s != emptyVectorSymbol)
214 return downVectorEntry(dagNode, vec, maxRowNr);
215 return true;
216 }
217
218 bool
downAlgorithm(DagNode * dagNode,Algorithm & algorithm)219 MatrixOpSymbol::downAlgorithm(DagNode* dagNode, Algorithm& algorithm)
220 {
221 if (dagNode->symbol() == stringSymbol)
222 {
223 const Rope& alg = safeCast(StringDagNode*, dagNode)->getValue();
224 if (alg.empty())
225 algorithm = SYSTEMS_CHOICE;
226 else
227 {
228 char *algStr = alg.makeZeroTerminatedString();
229 if (strcmp(algStr, "cd") == 0)
230 algorithm = CD;
231 else if (strcmp(algStr, "gcd") == 0)
232 algorithm = GCD;
233 else
234 {
235 delete [] algStr;
236 return false;
237 }
238 delete [] algStr;
239 }
240 return true;
241 }
242 return false;
243 }
244
245 DagNode*
upSet(const Vector<DagNode * > & elts)246 MatrixOpSymbol::upSet(const Vector<DagNode*>& elts)
247 {
248 int n = elts.size();
249 if (n == 0)
250 return emptyVectorSetSymbol->makeDagNode();
251 return (n == 1) ? elts[0] : vectorSetSymbol->makeDagNode(elts);
252 }
253
254 DagNode*
upVector(const IntVec & row)255 MatrixOpSymbol::upVector(const IntVec& row)
256 {
257 Vector<DagNode*> elts;
258 Vector<DagNode*> pair(2);
259 int nrRows = row.size();
260 for (int i = 1; i < nrRows; i++)
261 {
262 const mpz_class& v = row[i];
263 Assert(v >= 0, "-ve solution");
264 if (v > 0)
265 {
266 pair[0] = getSuccSymbol()->makeNatDag(i - 1);
267 pair[1] = getSuccSymbol()->makeNatDag(v);
268 elts.append(vectorEntrySymbol->makeDagNode(pair));
269 }
270 }
271 int n = elts.size();
272 if (n == 0)
273 return emptyVectorSymbol->makeDagNode();
274 return (n == 1) ? elts[0] : vectorSymbol->makeDagNode(elts);
275 }
276
277 bool
eqRewrite(DagNode * subject,RewritingContext & context)278 MatrixOpSymbol::eqRewrite(DagNode* subject, RewritingContext& context)
279 {
280 FreeDagNode* d = safeCast(FreeDagNode*, subject);
281 DagNode* m = d->getArgument(0);
282 m->reduce(context);
283 DagNode* v = d->getArgument(1);
284 v->reduce(context);
285 DagNode* a = d->getArgument(2);
286 a->reduce(context);
287
288 Algorithm algorithm;
289 SparseMatrix matrix;
290 IntVec vec;
291 int maxRowNr = -1;
292 int maxColNr = -1;
293 if (downAlgorithm(a, algorithm) &&
294 downMatrix(m, matrix, maxRowNr, maxColNr) &&
295 maxRowNr >= 0 &&
296 downVector(v, vec, maxRowNr))
297 {
298 Vector<DagNode*> homogenous;
299 Vector<DagNode*> inhomogenous;
300 //
301 // Build Diophantine system.
302 //
303 MpzSystem ds;
304 int rowSize = maxColNr + 2;
305 IntVec row(rowSize);
306 for (int i = 0; i <= maxRowNr; i++)
307 {
308 for (int j = 1; j < rowSize; j++)
309 row[j] = 0;
310
311 const mpz_class& v = vec[i];
312 const SparseVector& r = matrix[i];
313 //
314 // If we have an equation with all zero coefficients and nonzero
315 // constant term we can trivially fail.
316 //
317 if (r.empty() && v != 0)
318 goto fail;
319
320 row[0] = -v;
321 FOR_EACH_CONST(j, SparseVector, r)
322 row[j->first + 1] = j->second;
323 ds.insertEqn(row);
324 }
325 for (int j = 1; j < rowSize; j++)
326 row[j] = NONE;
327 row[0] = 1;
328 ds.setUpperBounds(row);
329 //
330 // Extract solutions.
331 //
332 if (algorithm == GCD ||
333 (algorithm == SYSTEMS_CHOICE && maxColNr <= maxRowNr + 1))
334 {
335 while (ds.findNextMinimalSolutionGcd(row))
336 {
337 if (row[0] == 0)
338 homogenous.append(upVector(row));
339 else
340 inhomogenous.append(upVector(row));
341 }
342 }
343 else
344 {
345 while (ds.findNextMinimalSolution(row))
346 {
347 if (row[0] == 0)
348 homogenous.append(upVector(row));
349 else
350 inhomogenous.append(upVector(row));
351 }
352 }
353 //
354 // Build result dag.
355 //
356 fail:
357 Vector<DagNode*> args(2);
358 args[0] = upSet(inhomogenous);
359 args[1] = inhomogenous.empty() ? args[0] : upSet(homogenous);
360 return context.builtInReplace(subject, vectorSetPairSymbol->makeDagNode(args));
361 }
362 //
363 // NumberOpSymbol doesn't know how to deal with this.
364 //
365 return FreeSymbol::eqRewrite(subject, context);
366 }
367