1 /*
2  * Copyright © 2003-2019 Dynare Team
3  *
4  * This file is part of Dynare.
5  *
6  * Dynare is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * Dynare is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with Dynare.  If not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #include <cstdlib>
21 #include <cassert>
22 #include <iostream>
23 #include <regex>
24 
25 #include <filesystem>
26 
27 #include "DataTree.hh"
28 
29 void
initConstants()30 DataTree::initConstants()
31 {
32   Zero = AddNonNegativeConstant("0");
33   One = AddNonNegativeConstant("1");
34   Two = AddNonNegativeConstant("2");
35   Three = AddNonNegativeConstant("3");
36 
37   MinusOne = AddUMinus(One);
38 
39   NaN = AddNonNegativeConstant("NaN");
40   Infinity = AddNonNegativeConstant("Inf");
41   MinusInfinity = AddUMinus(Infinity);
42 
43   Pi = AddNonNegativeConstant("3.141592653589793");
44 }
45 
DataTree(SymbolTable & symbol_table_arg,NumericalConstants & num_constants_arg,ExternalFunctionsTable & external_functions_table_arg,bool is_dynamic_arg)46 DataTree::DataTree(SymbolTable &symbol_table_arg,
47                    NumericalConstants &num_constants_arg,
48                    ExternalFunctionsTable &external_functions_table_arg,
49                    bool is_dynamic_arg) :
50   symbol_table{symbol_table_arg},
51   num_constants{num_constants_arg},
52   external_functions_table{external_functions_table_arg},
53   is_dynamic{is_dynamic_arg}
54 {
55   initConstants();
56 }
57 
DataTree(const DataTree & d)58 DataTree::DataTree(const DataTree &d) :
59   symbol_table{d.symbol_table},
60   num_constants{d.num_constants},
61   external_functions_table{d.external_functions_table},
62   is_dynamic{d.is_dynamic},
63   local_variables_vector{d.local_variables_vector}
64 {
65   // Constants must be initialized first because they are used in some Add* methods
66   initConstants();
67 
68   for (const auto &it : d.node_list)
69     it->clone(*this);
70 
71   assert(node_list.size() == d.node_list.size());
72 
73   for (const auto &it : d.local_variables_table)
74     local_variables_table[it.first] = it.second->clone(*this);
75 }
76 
77 DataTree &
operator =(const DataTree & d)78 DataTree::operator=(const DataTree &d)
79 {
80   assert(&symbol_table == &d.symbol_table);
81   assert(&num_constants == &d.num_constants);
82   assert(&external_functions_table == &d.external_functions_table);
83   assert(is_dynamic == d.is_dynamic);
84 
85   num_const_node_map.clear();
86   variable_node_map.clear();
87   unary_op_node_map.clear();
88   binary_op_node_map.clear();
89   trinary_op_node_map.clear();
90   external_function_node_map.clear();
91   var_expectation_node_map.clear();
92   pac_expectation_node_map.clear();
93   first_deriv_external_function_node_map.clear();
94   second_deriv_external_function_node_map.clear();
95 
96   node_list.clear();
97 
98   // Constants must be initialized first because they are used in some Add* methods
99   initConstants();
100 
101   /* Model local variables must be next, because they can be evaluated in Add*
102      methods when the model equations are added */
103   for (const auto &it : d.local_variables_table)
104     local_variables_table[it.first] = it.second->clone(*this);
105 
106   for (const auto &it : d.node_list)
107     it->clone(*this);
108 
109   assert(node_list.size() == d.node_list.size());
110 
111   local_variables_vector = d.local_variables_vector;
112 
113   return *this;
114 }
115 
116 expr_t
AddNonNegativeConstant(const string & value)117 DataTree::AddNonNegativeConstant(const string &value)
118 {
119   int id = num_constants.AddNonNegativeConstant(value);
120 
121   if (auto it = num_const_node_map.find(id);
122       it != num_const_node_map.end())
123     return it->second;
124 
125   auto sp = make_unique<NumConstNode>(*this, node_list.size(), id);
126   auto p = sp.get();
127   node_list.push_back(move(sp));
128   num_const_node_map[id] = p;
129   return p;
130 }
131 
132 VariableNode *
AddVariable(int symb_id,int lag)133 DataTree::AddVariable(int symb_id, int lag)
134 {
135   if (lag != 0 && !is_dynamic)
136     {
137       cerr << "Leads/lags not authorized in this DataTree" << endl;
138       exit(EXIT_FAILURE);
139     }
140 
141   if (auto it = variable_node_map.find({ symb_id, lag });
142       it != variable_node_map.end())
143     return it->second;
144 
145   auto sp = make_unique<VariableNode>(*this, node_list.size(), symb_id, lag);
146   auto p = sp.get();
147   node_list.push_back(move(sp));
148   variable_node_map[{ symb_id, lag }] = p;
149   return p;
150 }
151 
152 VariableNode *
getVariable(int symb_id,int lag) const153 DataTree::getVariable(int symb_id, int lag) const
154 {
155   auto it = variable_node_map.find({ symb_id, lag });
156   if (it == variable_node_map.end())
157     {
158       cerr << "DataTree::getVariable: unknown variable node for symb_id=" << symb_id << " and lag=" << lag << endl;
159       exit(EXIT_FAILURE);
160     }
161   return it->second;
162 }
163 
164 bool
ParamUsedWithLeadLagInternal() const165 DataTree::ParamUsedWithLeadLagInternal() const
166 {
167   for (const auto &it : variable_node_map)
168     if (symbol_table.getType(it.first.first) == SymbolType::parameter && it.first.second != 0)
169       return true;
170   return false;
171 }
172 
173 expr_t
AddPlus(expr_t iArg1,expr_t iArg2)174 DataTree::AddPlus(expr_t iArg1, expr_t iArg2)
175 {
176   if (iArg2 == Zero)
177     return iArg1;
178 
179   if (iArg1 == Zero)
180     return iArg2;
181 
182   // Simplify x+(-y) in x-y
183   if (auto uarg2 = dynamic_cast<UnaryOpNode *>(iArg2);
184       uarg2 && uarg2->op_code == UnaryOpcode::uminus)
185     return AddMinus(iArg1, uarg2->arg);
186 
187   // Simplify (-x)+y in y-x
188   if (auto uarg1 = dynamic_cast<UnaryOpNode *>(iArg1);
189       uarg1 && uarg1->op_code == UnaryOpcode::uminus)
190     return AddMinus(iArg2, uarg1->arg);
191 
192   // Simplify (x-y)+y in x
193   if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1);
194       barg1 && barg1->op_code == BinaryOpcode::minus && barg1->arg2 == iArg2)
195     return barg1->arg1;
196 
197   // Simplify y+(x-y) in x
198   if (auto barg2 = dynamic_cast<BinaryOpNode *>(iArg2);
199       barg2 && barg2->op_code == BinaryOpcode::minus && barg2->arg2 == iArg1)
200     return barg2->arg1;
201 
202   // To treat commutativity of "+"
203   // Nodes iArg1 and iArg2 are sorted by index
204   if (iArg1->idx > iArg2->idx)
205     swap(iArg1, iArg2);
206   return AddBinaryOp(iArg1, BinaryOpcode::plus, iArg2);
207 }
208 
209 expr_t
AddMinus(expr_t iArg1,expr_t iArg2)210 DataTree::AddMinus(expr_t iArg1, expr_t iArg2)
211 {
212   if (iArg2 == Zero)
213     return iArg1;
214 
215   if (iArg1 == Zero)
216     return AddUMinus(iArg2);
217 
218   if (iArg1 == iArg2)
219     return Zero;
220 
221   // Simplify x-(-y) in x+y
222   if (auto uarg2 = dynamic_cast<UnaryOpNode *>(iArg2);
223       uarg2 && uarg2->op_code == UnaryOpcode::uminus)
224     return AddPlus(iArg1, uarg2->arg);
225 
226   // Simplify (x+y)-y and (y+x)-y in x
227   if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1);
228       barg1 && barg1->op_code == BinaryOpcode::plus)
229     {
230       if (barg1->arg2 == iArg2)
231         return barg1->arg1;
232       if (barg1->arg1 == iArg2)
233         return barg1->arg2;
234     }
235 
236   return AddBinaryOp(iArg1, BinaryOpcode::minus, iArg2);
237 }
238 
239 expr_t
AddUMinus(expr_t iArg1)240 DataTree::AddUMinus(expr_t iArg1)
241 {
242   if (iArg1 == Zero)
243     return Zero;
244 
245   // Simplify -(-x) in x
246   if (auto uarg = dynamic_cast<UnaryOpNode *>(iArg1);
247       uarg && uarg->op_code == UnaryOpcode::uminus)
248     return uarg->arg;
249 
250   return AddUnaryOp(UnaryOpcode::uminus, iArg1);
251 }
252 
253 expr_t
AddTimes(expr_t iArg1,expr_t iArg2)254 DataTree::AddTimes(expr_t iArg1, expr_t iArg2)
255 {
256   if (iArg1 == Zero || iArg2 == Zero)
257     return Zero;
258 
259   if (iArg1 == One)
260     return iArg2;
261 
262   if (iArg2 == One)
263     return iArg1;
264 
265   if (iArg1 == MinusOne)
266     return AddUMinus(iArg2);
267 
268   if (iArg2 == MinusOne)
269     return AddUMinus(iArg1);
270 
271   // Simplify (x/y)*y in x
272   if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1);
273       barg1 && barg1->op_code == BinaryOpcode::divide && barg1->arg2 == iArg2)
274     return barg1->arg1;
275 
276   // Simplify y*(x/y) in x
277   if (auto barg2 = dynamic_cast<BinaryOpNode *>(iArg2);
278       barg2 && barg2->op_code == BinaryOpcode::divide && barg2->arg2 == iArg1)
279     return barg2->arg1;
280 
281   // To treat commutativity of "*"
282   // Nodes iArg1 and iArg2 are sorted by index
283   if (iArg1->idx > iArg2->idx)
284     swap(iArg1, iArg2);
285   return AddBinaryOp(iArg1, BinaryOpcode::times, iArg2);
286 }
287 
288 expr_t
AddDivide(expr_t iArg1,expr_t iArg2)289 DataTree::AddDivide(expr_t iArg1, expr_t iArg2) noexcept(false)
290 {
291   if (iArg2 == One)
292     return iArg1;
293 
294   // This test should be before the next two, otherwise 0/0 won't be rejected
295   if (iArg2 == Zero)
296     {
297       cerr << "ERROR: Division by zero!" << endl;
298       throw DivisionByZeroException();
299     }
300 
301   if (iArg1 == Zero)
302     return Zero;
303 
304   if (iArg1 == iArg2)
305     return One;
306 
307   // Simplify x/(1/y) in x*y
308   if (auto barg2 = dynamic_cast<BinaryOpNode *>(iArg2);
309       barg2 && barg2->op_code == BinaryOpcode::divide && barg2->arg1 == One)
310     return AddTimes(iArg1, barg2->arg2);
311 
312   // Simplify (x*y)/y and (y*x)/y in x
313   if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1);
314       barg1 && barg1->op_code == BinaryOpcode::times)
315     {
316       if (barg1->arg2 == iArg2)
317         return barg1->arg1;
318       if (barg1->arg1 == iArg2)
319         return barg1->arg2;
320     }
321 
322   return AddBinaryOp(iArg1, BinaryOpcode::divide, iArg2);
323 }
324 
325 expr_t
AddLess(expr_t iArg1,expr_t iArg2)326 DataTree::AddLess(expr_t iArg1, expr_t iArg2)
327 {
328   return AddBinaryOp(iArg1, BinaryOpcode::less, iArg2);
329 }
330 
331 expr_t
AddGreater(expr_t iArg1,expr_t iArg2)332 DataTree::AddGreater(expr_t iArg1, expr_t iArg2)
333 {
334   return AddBinaryOp(iArg1, BinaryOpcode::greater, iArg2);
335 }
336 
337 expr_t
AddLessEqual(expr_t iArg1,expr_t iArg2)338 DataTree::AddLessEqual(expr_t iArg1, expr_t iArg2)
339 {
340   return AddBinaryOp(iArg1, BinaryOpcode::lessEqual, iArg2);
341 }
342 
343 expr_t
AddGreaterEqual(expr_t iArg1,expr_t iArg2)344 DataTree::AddGreaterEqual(expr_t iArg1, expr_t iArg2)
345 {
346   return AddBinaryOp(iArg1, BinaryOpcode::greaterEqual, iArg2);
347 }
348 
349 expr_t
AddEqualEqual(expr_t iArg1,expr_t iArg2)350 DataTree::AddEqualEqual(expr_t iArg1, expr_t iArg2)
351 {
352   return AddBinaryOp(iArg1, BinaryOpcode::equalEqual, iArg2);
353 }
354 
355 expr_t
AddDifferent(expr_t iArg1,expr_t iArg2)356 DataTree::AddDifferent(expr_t iArg1, expr_t iArg2)
357 {
358   return AddBinaryOp(iArg1, BinaryOpcode::different, iArg2);
359 }
360 
361 expr_t
AddPower(expr_t iArg1,expr_t iArg2)362 DataTree::AddPower(expr_t iArg1, expr_t iArg2)
363 {
364   // This one comes first, because 0⁰=1
365   if (iArg2 == Zero)
366     return One;
367 
368   if (iArg1 == Zero)
369     return Zero;
370 
371   if (iArg1 == One)
372     return One;
373 
374   if (iArg2 == One)
375     return iArg1;
376 
377   return AddBinaryOp(iArg1, BinaryOpcode::power, iArg2);
378 }
379 
380 expr_t
AddPowerDeriv(expr_t iArg1,expr_t iArg2,int powerDerivOrder)381 DataTree::AddPowerDeriv(expr_t iArg1, expr_t iArg2, int powerDerivOrder)
382 {
383   assert(powerDerivOrder > 0);
384   return AddBinaryOp(iArg1, BinaryOpcode::powerDeriv, iArg2, powerDerivOrder);
385 }
386 
387 expr_t
AddDiff(expr_t iArg1)388 DataTree::AddDiff(expr_t iArg1)
389 {
390   if (iArg1->maxLead() > 0)
391     // Issue preprocessor#21: always expand diffs with leads
392     return AddMinus(iArg1, iArg1->decreaseLeadsLags(1));
393   return AddUnaryOp(UnaryOpcode::diff, iArg1);
394 }
395 
396 expr_t
AddAdl(expr_t iArg1,const string & name,const vector<int> & lags)397 DataTree::AddAdl(expr_t iArg1, const string &name, const vector<int> &lags)
398 {
399   return AddUnaryOp(UnaryOpcode::adl, iArg1, 0, 0, 0, string(name), lags);
400 }
401 
402 expr_t
AddExp(expr_t iArg1)403 DataTree::AddExp(expr_t iArg1)
404 {
405   if (iArg1 == Zero)
406     return One;
407 
408   return AddUnaryOp(UnaryOpcode::exp, iArg1);
409 }
410 
411 expr_t
AddLog(expr_t iArg1)412 DataTree::AddLog(expr_t iArg1)
413 {
414   if (iArg1 == One)
415     return Zero;
416 
417   if (iArg1 == Zero)
418     {
419       cerr << "ERROR: log(0) not defined!" << endl;
420       exit(EXIT_FAILURE);
421     }
422 
423   // Simplify log(1/x) in −log(x)
424   if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1);
425       barg1 && barg1->op_code == BinaryOpcode::divide && barg1->arg1 == One)
426     return AddUMinus(AddLog(barg1->arg2));
427 
428   return AddUnaryOp(UnaryOpcode::log, iArg1);
429 }
430 
431 expr_t
AddLog10(expr_t iArg1)432 DataTree::AddLog10(expr_t iArg1)
433 {
434   if (iArg1 == One)
435     return Zero;
436 
437   if (iArg1 == Zero)
438     {
439       cerr << "ERROR: log10(0) not defined!" << endl;
440       exit(EXIT_FAILURE);
441     }
442 
443   // Simplify log₁₀(1/x) in −log₁₀(x)
444   if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1);
445       barg1 && barg1->op_code == BinaryOpcode::divide && barg1->arg1 == One)
446     return AddUMinus(AddLog10(barg1->arg2));
447 
448   return AddUnaryOp(UnaryOpcode::log10, iArg1);
449 }
450 
451 expr_t
AddCos(expr_t iArg1)452 DataTree::AddCos(expr_t iArg1)
453 {
454   if (iArg1 == Zero)
455     return One;
456 
457   return AddUnaryOp(UnaryOpcode::cos, iArg1);
458 }
459 
460 expr_t
AddSin(expr_t iArg1)461 DataTree::AddSin(expr_t iArg1)
462 {
463   if (iArg1 == Zero)
464     return Zero;
465 
466   return AddUnaryOp(UnaryOpcode::sin, iArg1);
467 }
468 
469 expr_t
AddTan(expr_t iArg1)470 DataTree::AddTan(expr_t iArg1)
471 {
472   if (iArg1 == Zero)
473     return Zero;
474 
475   return AddUnaryOp(UnaryOpcode::tan, iArg1);
476 }
477 
478 expr_t
AddAcos(expr_t iArg1)479 DataTree::AddAcos(expr_t iArg1)
480 {
481   if (iArg1 == One)
482     return Zero;
483 
484   return AddUnaryOp(UnaryOpcode::acos, iArg1);
485 }
486 
487 expr_t
AddAsin(expr_t iArg1)488 DataTree::AddAsin(expr_t iArg1)
489 {
490   if (iArg1 == Zero)
491     return Zero;
492 
493   return AddUnaryOp(UnaryOpcode::asin, iArg1);
494 }
495 
496 expr_t
AddAtan(expr_t iArg1)497 DataTree::AddAtan(expr_t iArg1)
498 {
499   if (iArg1 == Zero)
500     return Zero;
501 
502   return AddUnaryOp(UnaryOpcode::atan, iArg1);
503 }
504 
505 expr_t
AddCosh(expr_t iArg1)506 DataTree::AddCosh(expr_t iArg1)
507 {
508   if (iArg1 == Zero)
509     return One;
510 
511   return AddUnaryOp(UnaryOpcode::cosh, iArg1);
512 }
513 
514 expr_t
AddSinh(expr_t iArg1)515 DataTree::AddSinh(expr_t iArg1)
516 {
517   if (iArg1 == Zero)
518     return Zero;
519 
520   return AddUnaryOp(UnaryOpcode::sinh, iArg1);
521 }
522 
523 expr_t
AddTanh(expr_t iArg1)524 DataTree::AddTanh(expr_t iArg1)
525 {
526   if (iArg1 == Zero)
527     return Zero;
528 
529   return AddUnaryOp(UnaryOpcode::tanh, iArg1);
530 }
531 
532 expr_t
AddAcosh(expr_t iArg1)533 DataTree::AddAcosh(expr_t iArg1)
534 {
535   if (iArg1 == One)
536     return Zero;
537 
538   return AddUnaryOp(UnaryOpcode::acosh, iArg1);
539 }
540 
541 expr_t
AddAsinh(expr_t iArg1)542 DataTree::AddAsinh(expr_t iArg1)
543 {
544   if (iArg1 == Zero)
545     return Zero;
546 
547   return AddUnaryOp(UnaryOpcode::asinh, iArg1);
548 }
549 
550 expr_t
AddAtanh(expr_t iArg1)551 DataTree::AddAtanh(expr_t iArg1)
552 {
553   if (iArg1 == Zero)
554     return Zero;
555 
556   return AddUnaryOp(UnaryOpcode::atanh, iArg1);
557 }
558 
559 expr_t
AddSqrt(expr_t iArg1)560 DataTree::AddSqrt(expr_t iArg1)
561 {
562   if (iArg1 == Zero)
563     return Zero;
564 
565   if (iArg1 == One)
566     return One;
567 
568   return AddUnaryOp(UnaryOpcode::sqrt, iArg1);
569 }
570 
571 expr_t
AddCbrt(expr_t iArg1)572 DataTree::AddCbrt(expr_t iArg1)
573 {
574   if (iArg1 == Zero)
575     return Zero;
576 
577   if (iArg1 == One)
578     return One;
579 
580   return AddUnaryOp(UnaryOpcode::cbrt, iArg1);
581 }
582 
583 expr_t
AddAbs(expr_t iArg1)584 DataTree::AddAbs(expr_t iArg1)
585 {
586   if (iArg1 == Zero)
587     return Zero;
588 
589   if (iArg1 == One)
590     return One;
591 
592   return AddUnaryOp(UnaryOpcode::abs, iArg1);
593 }
594 
595 expr_t
AddSign(expr_t iArg1)596 DataTree::AddSign(expr_t iArg1)
597 {
598   if (iArg1 == Zero)
599     return Zero;
600 
601   if (iArg1 == One)
602     return One;
603 
604   return AddUnaryOp(UnaryOpcode::sign, iArg1);
605 }
606 
607 expr_t
AddErf(expr_t iArg1)608 DataTree::AddErf(expr_t iArg1)
609 {
610   if (iArg1 == Zero)
611     return Zero;
612 
613   return AddUnaryOp(UnaryOpcode::erf, iArg1);
614 }
615 
616 expr_t
AddMax(expr_t iArg1,expr_t iArg2)617 DataTree::AddMax(expr_t iArg1, expr_t iArg2)
618 {
619   return AddBinaryOp(iArg1, BinaryOpcode::max, iArg2);
620 }
621 
622 expr_t
AddMin(expr_t iArg1,expr_t iArg2)623 DataTree::AddMin(expr_t iArg1, expr_t iArg2)
624 {
625   return AddBinaryOp(iArg1, BinaryOpcode::min, iArg2);
626 }
627 
628 expr_t
AddNormcdf(expr_t iArg1,expr_t iArg2,expr_t iArg3)629 DataTree::AddNormcdf(expr_t iArg1, expr_t iArg2, expr_t iArg3)
630 {
631   return AddTrinaryOp(iArg1, TrinaryOpcode::normcdf, iArg2, iArg3);
632 }
633 
634 expr_t
AddNormpdf(expr_t iArg1,expr_t iArg2,expr_t iArg3)635 DataTree::AddNormpdf(expr_t iArg1, expr_t iArg2, expr_t iArg3)
636 {
637   return AddTrinaryOp(iArg1, TrinaryOpcode::normpdf, iArg2, iArg3);
638 }
639 
640 expr_t
AddSteadyState(expr_t iArg1)641 DataTree::AddSteadyState(expr_t iArg1)
642 {
643   return AddUnaryOp(UnaryOpcode::steadyState, iArg1);
644 }
645 
646 expr_t
AddSteadyStateParamDeriv(expr_t iArg1,int param_symb_id)647 DataTree::AddSteadyStateParamDeriv(expr_t iArg1, int param_symb_id)
648 {
649   return AddUnaryOp(UnaryOpcode::steadyStateParamDeriv, iArg1, 0, param_symb_id);
650 }
651 
652 expr_t
AddSteadyStateParam2ndDeriv(expr_t iArg1,int param1_symb_id,int param2_symb_id)653 DataTree::AddSteadyStateParam2ndDeriv(expr_t iArg1, int param1_symb_id, int param2_symb_id)
654 {
655   return AddUnaryOp(UnaryOpcode::steadyStateParam2ndDeriv, iArg1, 0, param1_symb_id, param2_symb_id);
656 }
657 
658 expr_t
AddExpectation(int iArg1,expr_t iArg2)659 DataTree::AddExpectation(int iArg1, expr_t iArg2)
660 {
661   return AddUnaryOp(UnaryOpcode::expectation, iArg2, iArg1);
662 }
663 
664 expr_t
AddVarExpectation(const string & model_name)665 DataTree::AddVarExpectation(const string &model_name)
666 {
667   if (auto it = var_expectation_node_map.find(model_name);
668       it != var_expectation_node_map.end())
669     return it->second;
670 
671   auto sp = make_unique<VarExpectationNode>(*this, node_list.size(), model_name);
672   auto p = sp.get();
673   node_list.push_back(move(sp));
674   var_expectation_node_map[model_name] = p;
675   return p;
676 }
677 
678 expr_t
AddPacExpectation(const string & model_name)679 DataTree::AddPacExpectation(const string &model_name)
680 {
681   if (auto it = pac_expectation_node_map.find(model_name);
682       it != pac_expectation_node_map.end())
683     return it->second;
684 
685   auto sp = make_unique<PacExpectationNode>(*this, node_list.size(), model_name);
686   auto p = sp.get();
687   node_list.push_back(move(sp));
688   pac_expectation_node_map[model_name] = p;
689   return p;
690 }
691 
692 expr_t
AddEqual(expr_t iArg1,expr_t iArg2)693 DataTree::AddEqual(expr_t iArg1, expr_t iArg2)
694 {
695   return AddBinaryOp(iArg1, BinaryOpcode::equal, iArg2);
696 }
697 
698 void
AddLocalVariable(int symb_id,expr_t value)699 DataTree::AddLocalVariable(int symb_id, expr_t value) noexcept(false)
700 {
701   assert(symbol_table.getType(symb_id) == SymbolType::modelLocalVariable);
702 
703   // Throw an exception if symbol already declared
704   if (auto it = local_variables_table.find(symb_id);
705       it != local_variables_table.end())
706     throw LocalVariableException(symbol_table.getName(symb_id));
707 
708   local_variables_table[symb_id] = value;
709   local_variables_vector.push_back(symb_id);
710 }
711 
712 expr_t
AddExternalFunction(int symb_id,const vector<expr_t> & arguments)713 DataTree::AddExternalFunction(int symb_id, const vector<expr_t> &arguments)
714 {
715   assert(symbol_table.getType(symb_id) == SymbolType::externalFunction);
716 
717   if (auto it = external_function_node_map.find({ arguments, symb_id });
718       it != external_function_node_map.end())
719     return it->second;
720 
721   auto sp = make_unique<ExternalFunctionNode>(*this, node_list.size(), symb_id, arguments);
722   auto p = sp.get();
723   node_list.push_back(move(sp));
724   external_function_node_map[{ arguments, symb_id }] = p;
725   return p;
726 }
727 
728 expr_t
AddFirstDerivExternalFunction(int top_level_symb_id,const vector<expr_t> & arguments,int input_index)729 DataTree::AddFirstDerivExternalFunction(int top_level_symb_id, const vector<expr_t> &arguments, int input_index)
730 {
731   assert(symbol_table.getType(top_level_symb_id) == SymbolType::externalFunction);
732 
733   if (auto it = first_deriv_external_function_node_map.find({ arguments, input_index, top_level_symb_id });
734       it != first_deriv_external_function_node_map.end())
735     return it->second;
736 
737   auto sp = make_unique<FirstDerivExternalFunctionNode>(*this, node_list.size(), top_level_symb_id, arguments, input_index);
738   auto p = sp.get();
739   node_list.push_back(move(sp));
740   first_deriv_external_function_node_map[{ arguments, input_index, top_level_symb_id }] = p;
741   return p;
742 }
743 
744 expr_t
AddSecondDerivExternalFunction(int top_level_symb_id,const vector<expr_t> & arguments,int input_index1,int input_index2)745 DataTree::AddSecondDerivExternalFunction(int top_level_symb_id, const vector<expr_t> &arguments, int input_index1, int input_index2)
746 {
747   assert(symbol_table.getType(top_level_symb_id) == SymbolType::externalFunction);
748 
749   if (auto it = second_deriv_external_function_node_map.find({ arguments, input_index1, input_index2,
750                                                                top_level_symb_id });
751     it != second_deriv_external_function_node_map.end())
752     return it->second;
753 
754   auto sp = make_unique<SecondDerivExternalFunctionNode>(*this, node_list.size(), top_level_symb_id, arguments, input_index1, input_index2);
755   auto p = sp.get();
756   node_list.push_back(move(sp));
757   second_deriv_external_function_node_map[{ arguments, input_index1, input_index2, top_level_symb_id }] = p;
758   return p;
759 }
760 
761 bool
isSymbolUsed(int symb_id) const762 DataTree::isSymbolUsed(int symb_id) const
763 {
764   for (const auto &it : variable_node_map)
765     if (it.first.first == symb_id)
766       return true;
767 
768   if (local_variables_table.find(symb_id) != local_variables_table.end())
769     return true;
770 
771   return false;
772 }
773 
774 int
getDerivID(int symb_id,int lag) const775 DataTree::getDerivID(int symb_id, int lag) const noexcept(false)
776 {
777   throw UnknownDerivIDException();
778 }
779 
780 SymbolType
getTypeByDerivID(int deriv_id) const781 DataTree::getTypeByDerivID(int deriv_id) const noexcept(false)
782 {
783   throw UnknownDerivIDException();
784 }
785 
786 int
getLagByDerivID(int deriv_id) const787 DataTree::getLagByDerivID(int deriv_id) const noexcept(false)
788 {
789   throw UnknownDerivIDException();
790 }
791 
792 int
getSymbIDByDerivID(int deriv_id) const793 DataTree::getSymbIDByDerivID(int deriv_id) const noexcept(false)
794 {
795   throw UnknownDerivIDException();
796 }
797 
798 void
addAllParamDerivId(set<int> & deriv_id_set)799 DataTree::addAllParamDerivId(set<int> &deriv_id_set)
800 {
801 }
802 
803 int
getDynJacobianCol(int deriv_id) const804 DataTree::getDynJacobianCol(int deriv_id) const noexcept(false)
805 {
806   throw UnknownDerivIDException();
807 }
808 
809 bool
isUnaryOpUsed(UnaryOpcode opcode) const810 DataTree::isUnaryOpUsed(UnaryOpcode opcode) const
811 {
812   for (const auto &it : unary_op_node_map)
813     if (get<1>(it.first) == opcode)
814       return true;
815 
816   return false;
817 }
818 
819 bool
isUnaryOpUsedOnType(SymbolType type,UnaryOpcode opcode) const820 DataTree::isUnaryOpUsedOnType(SymbolType type, UnaryOpcode opcode) const
821 {
822   set<int> var;
823   for (const auto &it : unary_op_node_map)
824     if (get<1>(it.first) == opcode)
825       {
826         it.second->collectVariables(type, var);
827         if (!var.empty())
828           return true;
829       }
830   return false;
831 }
832 
833 bool
isBinaryOpUsed(BinaryOpcode opcode) const834 DataTree::isBinaryOpUsed(BinaryOpcode opcode) const
835 {
836   for (const auto &it : binary_op_node_map)
837     if (get<2>(it.first) == opcode)
838       return true;
839 
840   return false;
841 }
842 
843 bool
isBinaryOpUsedOnType(SymbolType type,BinaryOpcode opcode) const844 DataTree::isBinaryOpUsedOnType(SymbolType type, BinaryOpcode opcode) const
845 {
846   set<int> var;
847   for (const auto &it : binary_op_node_map)
848     if (get<2>(it.first) == opcode)
849       {
850         it.second->collectVariables(type, var);
851         if (!var.empty())
852           return true;
853       }
854   return false;
855 }
856 
857 bool
isTrinaryOpUsed(TrinaryOpcode opcode) const858 DataTree::isTrinaryOpUsed(TrinaryOpcode opcode) const
859 {
860   for (const auto &it : trinary_op_node_map)
861     if (get<3>(it.first) == opcode)
862       return true;
863 
864   return false;
865 }
866 
867 bool
isExternalFunctionUsed(int symb_id) const868 DataTree::isExternalFunctionUsed(int symb_id) const
869 {
870   for (const auto &it : external_function_node_map)
871     if (it.first.second == symb_id)
872       return true;
873 
874   return false;
875 }
876 
877 bool
isFirstDerivExternalFunctionUsed(int symb_id) const878 DataTree::isFirstDerivExternalFunctionUsed(int symb_id) const
879 {
880   for (const auto &it : first_deriv_external_function_node_map)
881     if (get<2>(it.first) == symb_id)
882       return true;
883 
884   return false;
885 }
886 
887 bool
isSecondDerivExternalFunctionUsed(int symb_id) const888 DataTree::isSecondDerivExternalFunctionUsed(int symb_id) const
889 {
890   for (const auto &it : second_deriv_external_function_node_map)
891     if (get<3>(it.first) == symb_id)
892       return true;
893 
894   return false;
895 }
896 
897 int
minLagForSymbol(int symb_id) const898 DataTree::minLagForSymbol(int symb_id) const
899 {
900   int r = 0;
901   for (const auto &it : variable_node_map)
902     if (it.first.first == symb_id && it.first.second < r)
903       r = it.first.second;
904   return r;
905 }
906 
907 void
writePowerDerivCHeader(ostream & output) const908 DataTree::writePowerDerivCHeader(ostream &output) const
909 {
910   if (isBinaryOpUsed(BinaryOpcode::powerDeriv))
911     output << "double getPowerDeriv(double, double, int);" << endl;
912 }
913 
914 void
writePowerDeriv(ostream & output) const915 DataTree::writePowerDeriv(ostream &output) const
916 {
917   if (isBinaryOpUsed(BinaryOpcode::powerDeriv))
918     output << "/*" << endl
919            << " * The k-th derivative of x^p" << endl
920            << " */" << endl
921            << "double getPowerDeriv(double x, double p, int k)" << endl
922            << "{" << endl
923            << "#ifdef _MSC_VER" << endl
924            << "# define nearbyint(x) (fabs((x)-floor(x)) < fabs((x)-ceil(x)) ? floor(x) : ceil(x))" << endl
925            << "#endif" << endl
926            << "  if ( fabs(x) < " << near_zero << " && p > 0 && k > p && fabs(p-nearbyint(p)) < " << near_zero << " )" << endl
927            << "    return 0.0;" << endl
928            << "  else" << endl
929            << "    {" << endl
930            << "      int i = 0;" << endl
931            << "      double dxp = pow(x, p-k);" << endl
932            << "      for (; i<k; i++)" << endl
933            << "        dxp *= p--;" << endl
934            << "      return dxp;" << endl
935            << "    }" << endl
936            << "}" << endl;
937 }
938 
939 string
packageDir(const string & package)940 DataTree::packageDir(const string &package)
941 {
942   regex pat{R"(\.)"};
943   string dirname = "+" + regex_replace(package, pat, "/+");
944   filesystem::create_directories(dirname);
945   return dirname;
946 }
947