1 /*
2  * Copyright © 2007-2021 Dynare Team
3  *
4  * This file is part of Dynare.
5  *
6  * Dynare is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * Dynare is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with Dynare.  If not, see <http://www.gnu.org/licenses/>.
18  */
19 
20 #include <iostream>
21 #include <algorithm>
22 #include <cassert>
23 #include <cmath>
24 #include <utility>
25 #include <limits>
26 
27 #include "ExprNode.hh"
28 #include "DataTree.hh"
29 #include "ModFile.hh"
30 
ExprNode(DataTree & datatree_arg,int idx_arg)31 ExprNode::ExprNode(DataTree &datatree_arg, int idx_arg) : datatree{datatree_arg}, idx{idx_arg}
32 {
33 }
34 
35 expr_t
getDerivative(int deriv_id)36 ExprNode::getDerivative(int deriv_id)
37 {
38   if (!preparedForDerivation)
39     prepareForDerivation();
40 
41   // Return zero if derivative is necessarily null (using symbolic a priori)
42   if (auto it = non_null_derivatives.find(deriv_id); it == non_null_derivatives.end())
43     return datatree.Zero;
44 
45   // If derivative is stored in cache, use the cached value, otherwise compute it (and cache it)
46   if (auto it2 = derivatives.find(deriv_id); it2 != derivatives.end())
47     return it2->second;
48   else
49     {
50       expr_t d = computeDerivative(deriv_id);
51       derivatives[deriv_id] = d;
52       return d;
53     }
54 }
55 
56 int
precedence(ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms) const57 ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
58 {
59   // For a constant, a variable, or a unary op, the precedence is maximal
60   return 100;
61 }
62 
63 int
precedenceJson(const temporary_terms_t & temporary_terms) const64 ExprNode::precedenceJson(const temporary_terms_t &temporary_terms) const
65 {
66   // For a constant, a variable, or a unary op, the precedence is maximal
67   return 100;
68 }
69 
70 int
cost(int cost,bool is_matlab) const71 ExprNode::cost(int cost, bool is_matlab) const
72 {
73   // For a terminal node, the cost is null
74   return 0;
75 }
76 
77 int
cost(const temporary_terms_t & temp_terms_map,bool is_matlab) const78 ExprNode::cost(const temporary_terms_t &temp_terms_map, bool is_matlab) const
79 {
80   // For a terminal node, the cost is null
81   return 0;
82 }
83 
84 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const85 ExprNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
86 {
87   // For a terminal node, the cost is null
88   return 0;
89 }
90 
91 bool
checkIfTemporaryTermThenWrite(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs) const92 ExprNode::checkIfTemporaryTermThenWrite(ostream &output, ExprNodeOutputType output_type,
93                                         const temporary_terms_t &temporary_terms,
94                                         const temporary_terms_idxs_t &temporary_terms_idxs) const
95 {
96   if (auto it = temporary_terms.find(const_cast<ExprNode *>(this)); it == temporary_terms.end())
97     return false;
98 
99   if (output_type == ExprNodeOutputType::matlabDynamicModelSparse)
100     output << "T" << idx << "(it_)";
101   else
102     if (output_type == ExprNodeOutputType::matlabStaticModelSparse)
103       output << "T" << idx;
104     else
105       {
106         auto it2 = temporary_terms_idxs.find(const_cast<ExprNode *>(this));
107         // It is the responsibility of the caller to ensure that all temporary terms have their index
108         assert(it2 != temporary_terms_idxs.end());
109         output << "T" << LEFT_ARRAY_SUBSCRIPT(output_type)
110                << it2->second + ARRAY_SUBSCRIPT_OFFSET(output_type)
111                << RIGHT_ARRAY_SUBSCRIPT(output_type);
112       }
113   return true;
114 }
115 
116 pair<expr_t, int>
getLagEquivalenceClass() const117 ExprNode::getLagEquivalenceClass() const
118 {
119   int index = maxLead();
120 
121   if (index == numeric_limits<int>::min())
122     index = 0; // If no variable in the expression, the equivalence class has size 1
123 
124   return { decreaseLeadsLags(index), index };
125 }
126 
127 void
collectVariables(SymbolType type,set<int> & result) const128 ExprNode::collectVariables(SymbolType type, set<int> &result) const
129 {
130   set<pair<int, int>> symbs_lags;
131   collectDynamicVariables(type, symbs_lags);
132   transform(symbs_lags.begin(), symbs_lags.end(), inserter(result, result.begin()),
133             [](auto x) { return x.first; });
134 }
135 
136 void
collectEndogenous(set<pair<int,int>> & result) const137 ExprNode::collectEndogenous(set<pair<int, int>> &result) const
138 {
139   set<pair<int, int>> symb_ids;
140   collectDynamicVariables(SymbolType::endogenous, symb_ids);
141   for (const auto &symb_id : symb_ids)
142     result.emplace(datatree.symbol_table.getTypeSpecificID(symb_id.first), symb_id.second);
143 }
144 
145 void
collectExogenous(set<pair<int,int>> & result) const146 ExprNode::collectExogenous(set<pair<int, int>> &result) const
147 {
148   set<pair<int, int>> symb_ids;
149   collectDynamicVariables(SymbolType::exogenous, symb_ids);
150   for (const auto &symb_id : symb_ids)
151     result.emplace(datatree.symbol_table.getTypeSpecificID(symb_id.first), symb_id.second);
152 }
153 
154 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const155 ExprNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
156                                 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
157                                 map<expr_t, pair<int, pair<int, int>>> &reference_count,
158                                 bool is_matlab) const
159 {
160   // Nothing to do for a terminal node
161 }
162 
163 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const164 ExprNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
165                                 temporary_terms_t &temporary_terms,
166                                 map<expr_t, pair<int, int>> &first_occurence,
167                                 int Curr_block,
168                                 vector<vector<temporary_terms_t>> &v_temporary_terms,
169                                 int equation) const
170 {
171   // Nothing to do for a terminal node
172 }
173 
174 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const175 ExprNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
176 {
177   /* nothing to do */
178   return { 0, nullptr };
179 }
180 
181 void
writeOutput(ostream & output) const182 ExprNode::writeOutput(ostream &output) const
183 {
184   writeOutput(output, ExprNodeOutputType::matlabOutsideModel, {}, {});
185 }
186 
187 void
writeOutput(ostream & output,ExprNodeOutputType output_type) const188 ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type) const
189 {
190   writeOutput(output, output_type, {}, {});
191 }
192 
193 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs) const194 ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const
195 {
196   writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, {});
197 }
198 
199 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic) const200 ExprNode::compile(ostream &CompileCode, unsigned int &instruction_number,
201                   bool lhs_rhs, const temporary_terms_t &temporary_terms,
202                   const map_idx_t &map_idx, bool dynamic, bool steady_dynamic) const
203 {
204   compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, {});
205 }
206 
207 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const208 ExprNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
209                                       const temporary_terms_t &temporary_terms,
210                                       const temporary_terms_idxs_t &temporary_terms_idxs,
211                                       deriv_node_temp_terms_t &tef_terms) const
212 {
213   // Nothing to do
214 }
215 
216 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const217 ExprNode::writeJsonExternalFunctionOutput(vector<string> &efout,
218                                           const temporary_terms_t &temporary_terms,
219                                           deriv_node_temp_terms_t &tef_terms,
220                                           bool isdynamic) const
221 {
222   // Nothing to do
223 }
224 
225 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const226 ExprNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
227                                         bool lhs_rhs, const temporary_terms_t &temporary_terms,
228                                         const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
229                                         deriv_node_temp_terms_t &tef_terms) const
230 {
231   // Nothing to do
232 }
233 
/* Returns a VariableNode that can replace this expression, whose maximum
   endogenous lead n must be at least 2 (asserted).
   A chain of endogenous lead auxiliary variables is created, one period apart,
   starting from the copy of the expression shifted back by n-1. For each new
   auxiliary variable, the defining equation (auxvar = substexpr) is appended
   to neweqs, and the mapping from the shifted expression to auxvar(+1) is
   recorded in subst_table. Shifted expressions already present in subst_table
   are reused rather than recreated.
   If this expression itself is already in subst_table, the stored replacement
   is returned immediately. */
VariableNode *
ExprNode::createEndoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  int n = maxEndoLead();
  assert(n >= 2);

  if (auto it = subst_table.find(this);
      it != subst_table.end())
    return const_cast<VariableNode *>(it->second);

  // Start from the expression shifted back so that its maximum lead is 1
  expr_t substexpr = decreaseLeadsLags(n-1);
  int lag = n-2;

  // Each iteration tries to create an auxvar such that auxvar(+1)=expr(-lag)
  // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to expr(-lag-1) (resp. expr(-lag))
  while (lag >= 0)
    {
      expr_t orig_expr = decreaseLeadsLags(lag);
      if (auto it = subst_table.find(orig_expr); it == subst_table.end())
        {
          int symb_id = datatree.symbol_table.addEndoLeadAuxiliaryVar(orig_expr->idx, substexpr);
          // Defining equation: auxvar = substexpr (i.e. auxvar(+1) = expr(-lag))
          neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr)));
          substexpr = datatree.AddVariable(symb_id, +1);
          assert(dynamic_cast<VariableNode *>(substexpr));
          subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr);
        }
      else
        // Already substituted earlier: reuse the existing auxiliary variable
        substexpr = const_cast<VariableNode *>(it->second);

      lag--;
    }

  // After the loop (lag == -1), substexpr is equivalent to the unshifted expression
  return dynamic_cast<VariableNode *>(substexpr);
}
268 
/* Returns a VariableNode that can replace this expression, whose maximum
   exogenous lead n must be at least 1 (asserted).
   A chain of exogenous lead auxiliary variables is created, one period apart,
   starting from the copy of the expression shifted back by n (i.e. with no
   exogenous lead left). For each new auxiliary variable, the defining equation
   (auxvar = substexpr) is appended to neweqs, and the mapping from the shifted
   expression to auxvar(+1) is recorded in subst_table. Shifted expressions
   already present in subst_table are reused rather than recreated.
   If this expression itself is already in subst_table, the stored replacement
   is returned immediately. */
VariableNode *
ExprNode::createExoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  int n = maxExoLead();
  assert(n >= 1);

  if (auto it = subst_table.find(this);
      it != subst_table.end())
    return const_cast<VariableNode *>(it->second);

  // Start from the expression shifted back so that its maximum exogenous lead is 0
  expr_t substexpr = decreaseLeadsLags(n);
  int lag = n-1;

  // Each iteration tries to create an auxvar such that auxvar(+1)=expr(-lag)
  // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to expr(-lag-1) (resp. expr(-lag))
  while (lag >= 0)
    {
      expr_t orig_expr = decreaseLeadsLags(lag);
      if (auto it = subst_table.find(orig_expr); it == subst_table.end())
        {
          int symb_id = datatree.symbol_table.addExoLeadAuxiliaryVar(orig_expr->idx, substexpr);
          // Defining equation: auxvar = substexpr (i.e. auxvar(+1) = expr(-lag))
          neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr)));
          substexpr = datatree.AddVariable(symb_id, +1);
          assert(dynamic_cast<VariableNode *>(substexpr));
          subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr);
        }
      else
        // Already substituted earlier: reuse the existing auxiliary variable
        substexpr = const_cast<VariableNode *>(it->second);

      lag--;
    }

  // After the loop (lag == -1), substexpr is equivalent to the unshifted expression
  return dynamic_cast<VariableNode *>(substexpr);
}
303 
304 bool
isNumConstNodeEqualTo(double value) const305 ExprNode::isNumConstNodeEqualTo(double value) const
306 {
307   return false;
308 }
309 
310 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const311 ExprNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
312 {
313   return false;
314 }
315 
316 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const317 ExprNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
318 {
319 }
320 
/* Fills row eqn of the error-correction matrices of a trend component model
   from the additive terms of this expression.
   - Terms that do not match "parameter times linear combination of variables",
     or that contain a variable outside target_lhs ∪ nontarget_lhs (after
     mapping diff auxiliary variables back one step), are skipped.
   - Variables not in target_lhs (i.e. in nontarget_lhs) fill A0 at
     (eqn, -orig_lag, column in nontarget_lhs). They must appear with constant 1
     and no extra parameter, otherwise the program exits with an error. A second
     entry at the same coordinates is also a fatal error.
   - Variables in target_lhs fill A0star at (eqn, -orig_lag, column in
     target_lhs), accumulating with AddPlus when several terms hit the same
     coordinates.
   NOTE(review): m.first is passed to AddVariable, so it is presumably the
   symb_id of the parameter multiplying the linear combination; confirm in
   matchParamTimesLinearCombinationOfVariables. */
void
ExprNode::fillErrorCorrectionRow(int eqn,
                                 const vector<int> &nontarget_lhs,
                                 const vector<int> &target_lhs,
                                 map<tuple<int, int, int>, expr_t> &A0,
                                 map<tuple<int, int, int>, expr_t> &A0star) const
{
  vector<pair<expr_t, int>> terms;
  decomposeAdditiveTerms(terms, 1);

  for (const auto &it : terms)
    {
      // m = (multiplying parameter, list of (var_id, lag, param_id, constant))
      pair<int, vector<tuple<int, int, int, double>>> m;
      try
        {
          m = it.first->matchParamTimesLinearCombinationOfVariables();
          // it.second is the sign attached to the term by decomposeAdditiveTerms (presumably ±1)
          for (auto &t : m.second)
            get<3>(t) *= it.second; // Update sign of constants
        }
      catch (MatchFailureException &e)
        {
          /* FIXME: we should not just skip them, but rather verify that they are
             autoregressive terms or residuals (probably by merging the two "fill" procedures) */
          continue;
        }

      /* Helper function: for an auxiliary variable, returns the original symbol
         via getOrigSymbIdForDiffAuxVar; otherwise returns symb_id unchanged.
         NOTE(review): this assumes every auxiliary variable met here is a diff
         auxiliary variable — confirm. */
      auto one_step_orig = [this](int symb_id) {
                             return datatree.symbol_table.isAuxiliaryVariable(symb_id) ?
                               datatree.symbol_table.getOrigSymbIdForDiffAuxVar(symb_id) : symb_id;
                           };

      /* Verify that all variables belong to the error-correction term.
         FIXME: same remark as above about skipping terms. */
      bool not_ec = false;
      for (const auto &t : m.second)
        {
          int vid = one_step_orig(get<0>(t));
          not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), vid) == target_lhs.end()
                              && find(nontarget_lhs.begin(), nontarget_lhs.end(), vid) == nontarget_lhs.end());
        }
      if (not_ec)
        continue;

      // Now fill the matrices
      for (auto [var_id, lag, param_id, constant] : m.second)
        {
          int orig_vid = one_step_orig(var_id);
          // For an auxiliary variable, the lag is taken from its original lead/lag (sign-flipped)
          int orig_lag = datatree.symbol_table.isAuxiliaryVariable(var_id) ? -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(var_id) : lag;
          if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end())
            {
              // This is a non-target (LHS) variable, so fill A0
              if (constant != 1)
                {
                  cerr << "ERROR in trend component model: LHS variable should not appear with a multiplicative constant in error correction term" << endl;
                  exit(EXIT_FAILURE);
                }
              if (param_id != -1)
                {
                  cerr << "ERROR in trend component model: spurious parameter in error correction term" << endl;
                  exit(EXIT_FAILURE);
                }
              // Column index of the variable within nontarget_lhs (guaranteed present by the not_ec check)
              int colidx = static_cast<int>(distance(nontarget_lhs.begin(), find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_vid)));
              if (A0.find({eqn, -orig_lag, colidx}) != A0.end())
                {
                  cerr << "ExprNode::fillErrorCorrection: Error filling A0 matrix: "
                       << "lag/symb_id encountered more than once in equation" << endl;
                  exit(EXIT_FAILURE);
                }
              A0[{eqn, -orig_lag, colidx}] = datatree.AddVariable(m.first);
            }
          else
            {
              // This is a target, so fill A0star
              int colidx = static_cast<int>(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid)));
              // Entry is m.first × (-constant), times param_id if present; note the sign flip on constant
              expr_t e = datatree.AddTimes(datatree.AddVariable(m.first), datatree.AddPossiblyNegativeConstant(-constant));
              if (param_id != -1)
                e = datatree.AddTimes(e, datatree.AddVariable(param_id));
              // Several terms may contribute to the same coordinates: accumulate them
              if (auto coor = tuple(eqn, -orig_lag, colidx); A0star.find(coor) == A0star.end())
                A0star[coor] = e;
              else
                A0star[coor] = datatree.AddPlus(e, A0star[coor]);
            }
        }
    }
}
407 
NumConstNode(DataTree & datatree_arg,int idx_arg,int id_arg)408 NumConstNode::NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg) :
409   ExprNode{datatree_arg, idx_arg},
410   id{id_arg}
411 {
412 }
413 
414 int
countDiffs() const415 NumConstNode::countDiffs() const
416 {
417   return 0;
418 }
419 
420 void
prepareForDerivation()421 NumConstNode::prepareForDerivation()
422 {
423   preparedForDerivation = true;
424   // All derivatives are null, so non_null_derivatives is left empty
425 }
426 
427 expr_t
computeDerivative(int deriv_id)428 NumConstNode::computeDerivative(int deriv_id)
429 {
430   return datatree.Zero;
431 }
432 
433 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const434 NumConstNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
435 {
436   if (temporary_terms.find(const_cast<NumConstNode *>(this)) != temporary_terms.end())
437     temporary_terms_inuse.insert(idx);
438 }
439 
440 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const441 NumConstNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
442                           const temporary_terms_t &temporary_terms,
443                           const temporary_terms_idxs_t &temporary_terms_idxs,
444                           const deriv_node_temp_terms_t &tef_terms) const
445 {
446   if (!checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
447     output << datatree.num_constants.get(id);
448 }
449 
450 void
writeJsonAST(ostream & output) const451 NumConstNode::writeJsonAST(ostream &output) const
452 {
453   output << R"({"node_type" : "NumConstNode", "value" : )";
454   output << std::stof(datatree.num_constants.get(id)) << "}";
455 }
456 
457 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const458 NumConstNode::writeJsonOutput(ostream &output,
459                               const temporary_terms_t &temporary_terms,
460                               const deriv_node_temp_terms_t &tef_terms,
461                               bool isdynamic) const
462 {
463   output << datatree.num_constants.get(id);
464 }
465 
466 bool
containsExternalFunction() const467 NumConstNode::containsExternalFunction() const
468 {
469   return false;
470 }
471 
472 double
eval(const eval_context_t & eval_context) const473 NumConstNode::eval(const eval_context_t &eval_context) const noexcept(false)
474 {
475   return datatree.num_constants.getDouble(id);
476 }
477 
478 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const479 NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number,
480                       bool lhs_rhs, const temporary_terms_t &temporary_terms,
481                       const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
482                       const deriv_node_temp_terms_t &tef_terms) const
483 {
484   FLDC_ fldc(datatree.num_constants.getDouble(id));
485   fldc.write(CompileCode, instruction_number);
486 }
487 
488 void
collectVARLHSVariable(set<expr_t> & result) const489 NumConstNode::collectVARLHSVariable(set<expr_t> &result) const
490 {
491   cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
492   exit(EXIT_FAILURE);
493 }
494 
495 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const496 NumConstNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
497 {
498 }
499 
500 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const501 NumConstNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
502 {
503   /* return the numercial constant */
504   return { 0, datatree.AddNonNegativeConstant(datatree.num_constants.get(id)) };
505 }
506 
507 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)508 NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
509 {
510   return datatree.Zero;
511 }
512 
513 expr_t
toStatic(DataTree & static_datatree) const514 NumConstNode::toStatic(DataTree &static_datatree) const
515 {
516   return static_datatree.AddNonNegativeConstant(datatree.num_constants.get(id));
517 }
518 
519 void
computeXrefs(EquationInfo & ei) const520 NumConstNode::computeXrefs(EquationInfo &ei) const
521 {
522 }
523 
524 expr_t
clone(DataTree & datatree) const525 NumConstNode::clone(DataTree &datatree) const
526 {
527   return datatree.AddNonNegativeConstant(datatree.num_constants.get(id));
528 }
529 
530 int
maxEndoLead() const531 NumConstNode::maxEndoLead() const
532 {
533   return 0;
534 }
535 
536 int
maxExoLead() const537 NumConstNode::maxExoLead() const
538 {
539   return 0;
540 }
541 
542 int
maxEndoLag() const543 NumConstNode::maxEndoLag() const
544 {
545   return 0;
546 }
547 
548 int
maxExoLag() const549 NumConstNode::maxExoLag() const
550 {
551   return 0;
552 }
553 
554 int
maxLead() const555 NumConstNode::maxLead() const
556 {
557   return numeric_limits<int>::min();
558 }
559 
560 int
maxLag() const561 NumConstNode::maxLag() const
562 {
563   return numeric_limits<int>::min();
564 }
565 
566 int
maxLagWithDiffsExpanded() const567 NumConstNode::maxLagWithDiffsExpanded() const
568 {
569   return numeric_limits<int>::min();
570 }
571 
572 expr_t
undiff() const573 NumConstNode::undiff() const
574 {
575   return const_cast<NumConstNode *>(this);
576 }
577 
578 int
VarMinLag() const579 NumConstNode::VarMinLag() const
580 {
581   return 1;
582 }
583 
584 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const585 NumConstNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
586 {
587   return 0;
588 }
589 
590 int
PacMaxLag(int lhs_symb_id) const591 NumConstNode::PacMaxLag(int lhs_symb_id) const
592 {
593   return 0;
594 }
595 
596 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const597 NumConstNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
598 {
599   return -1;
600 }
601 
602 expr_t
decreaseLeadsLags(int n) const603 NumConstNode::decreaseLeadsLags(int n) const
604 {
605   return const_cast<NumConstNode *>(this);
606 }
607 
608 expr_t
decreaseLeadsLagsPredeterminedVariables() const609 NumConstNode::decreaseLeadsLagsPredeterminedVariables() const
610 {
611   return const_cast<NumConstNode *>(this);
612 }
613 
614 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const615 NumConstNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
616 {
617   return const_cast<NumConstNode *>(this);
618 }
619 
620 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const621 NumConstNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
622 {
623   return const_cast<NumConstNode *>(this);
624 }
625 
626 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const627 NumConstNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
628 {
629   return const_cast<NumConstNode *>(this);
630 }
631 
632 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const633 NumConstNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
634 {
635   return const_cast<NumConstNode *>(this);
636 }
637 
638 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const639 NumConstNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
640 {
641   return const_cast<NumConstNode *>(this);
642 }
643 
644 expr_t
substituteAdl() const645 NumConstNode::substituteAdl() const
646 {
647   return const_cast<NumConstNode *>(this);
648 }
649 
650 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const651 NumConstNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
652 {
653   return const_cast<NumConstNode *>(this);
654 }
655 
656 void
findDiffNodes(lag_equivalence_table_t & nodes) const657 NumConstNode::findDiffNodes(lag_equivalence_table_t &nodes) const
658 {
659 }
660 
661 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const662 NumConstNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
663 {
664 }
665 
666 int
findTargetVariable(int lhs_symb_id) const667 NumConstNode::findTargetVariable(int lhs_symb_id) const
668 {
669   return -1;
670 }
671 
672 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const673 NumConstNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
674 {
675   return const_cast<NumConstNode *>(this);
676 }
677 
678 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const679 NumConstNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
680 {
681   return const_cast<NumConstNode *>(this);
682 }
683 
684 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)685 NumConstNode::substitutePacExpectation(const string &name, expr_t subexpr)
686 {
687   return const_cast<NumConstNode *>(this);
688 }
689 
690 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const691 NumConstNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
692 {
693   return const_cast<NumConstNode *>(this);
694 }
695 
696 bool
isNumConstNodeEqualTo(double value) const697 NumConstNode::isNumConstNodeEqualTo(double value) const
698 {
699   if (datatree.num_constants.getDouble(id) == value)
700     return true;
701   else
702     return false;
703 }
704 
705 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const706 NumConstNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
707 {
708   return false;
709 }
710 
711 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const712 NumConstNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
713 {
714 }
715 
716 bool
containsPacExpectation(const string & pac_model_name) const717 NumConstNode::containsPacExpectation(const string &pac_model_name) const
718 {
719   return false;
720 }
721 
722 bool
containsEndogenous() const723 NumConstNode::containsEndogenous() const
724 {
725   return false;
726 }
727 
728 bool
containsExogenous() const729 NumConstNode::containsExogenous() const
730 {
731   return false;
732 }
733 
734 expr_t
replaceTrendVar() const735 NumConstNode::replaceTrendVar() const
736 {
737   return const_cast<NumConstNode *>(this);
738 }
739 
740 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const741 NumConstNode::detrend(int symb_id, bool log_trend, expr_t trend) const
742 {
743   return const_cast<NumConstNode *>(this);
744 }
745 
746 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const747 NumConstNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
748 {
749   return const_cast<NumConstNode *>(this);
750 }
751 
752 bool
isInStaticForm() const753 NumConstNode::isInStaticForm() const
754 {
755   return true;
756 }
757 
758 bool
isParamTimesEndogExpr() const759 NumConstNode::isParamTimesEndogExpr() const
760 {
761   return false;
762 }
763 
764 bool
isVarModelReferenced(const string & model_info_name) const765 NumConstNode::isVarModelReferenced(const string &model_info_name) const
766 {
767   return false;
768 }
769 
770 expr_t
substituteStaticAuxiliaryVariable() const771 NumConstNode::substituteStaticAuxiliaryVariable() const
772 {
773   return const_cast<NumConstNode *>(this);
774 }
775 
776 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const777 NumConstNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
778 {
779   return;
780 }
781 
782 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const783 NumConstNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
784 {
785   return const_cast<NumConstNode *>(this);
786 }
787 
VariableNode(DataTree & datatree_arg,int idx_arg,int symb_id_arg,int lag_arg)788 VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
789   ExprNode{datatree_arg, idx_arg},
790   symb_id{symb_id_arg},
791   lag{lag_arg}
792 {
793   // It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous and parameters can be swapped
794   assert(get_type() != SymbolType::externalFunction
795          && (lag == 0 || (get_type() != SymbolType::modelLocalVariable && get_type() != SymbolType::modFileLocalVariable)));
796 }
797 
798 void
prepareForDerivation()799 VariableNode::prepareForDerivation()
800 {
801   if (preparedForDerivation)
802     return;
803 
804   preparedForDerivation = true;
805 
806   // Fill in non_null_derivatives
807   switch (get_type())
808     {
809     case SymbolType::endogenous:
810     case SymbolType::exogenous:
811     case SymbolType::exogenousDet:
812     case SymbolType::parameter:
813     case SymbolType::trend:
814     case SymbolType::logTrend:
815       // For a variable or a parameter, the only non-null derivative is with respect to itself
816       non_null_derivatives.insert(datatree.getDerivID(symb_id, lag));
817       break;
818     case SymbolType::modelLocalVariable:
819       datatree.getLocalVariable(symb_id)->prepareForDerivation();
820       // Non null derivatives are those of the value of the local parameter
821       non_null_derivatives = datatree.getLocalVariable(symb_id)->non_null_derivatives;
822       break;
823     case SymbolType::modFileLocalVariable:
824     case SymbolType::statementDeclaredVariable:
825     case SymbolType::unusedEndogenous:
826       // Such a variable is never derived
827       break;
828     case SymbolType::externalFunction:
829     case SymbolType::endogenousVAR:
830     case SymbolType::epilogue:
831       cerr << "VariableNode::prepareForDerivation: impossible case" << endl;
832       exit(EXIT_FAILURE);
833     case SymbolType::excludedVariable:
834       cerr << "VariableNode::prepareForDerivation: impossible case: "
835            << "You are trying to derive a variable that has been excluded via include_eqs/exclude_eqs: "
836            << datatree.symbol_table.getName(symb_id) << endl;
837       exit(EXIT_FAILURE);
838     }
839 }
840 
841 expr_t
computeDerivative(int deriv_id)842 VariableNode::computeDerivative(int deriv_id)
843 {
844   switch (get_type())
845     {
846     case SymbolType::endogenous:
847     case SymbolType::exogenous:
848     case SymbolType::exogenousDet:
849     case SymbolType::parameter:
850     case SymbolType::trend:
851     case SymbolType::logTrend:
852       if (deriv_id == datatree.getDerivID(symb_id, lag))
853         return datatree.One;
854       else
855         return datatree.Zero;
856     case SymbolType::modelLocalVariable:
857       return datatree.getLocalVariable(symb_id)->getDerivative(deriv_id);
858     case SymbolType::modFileLocalVariable:
859       cerr << "modFileLocalVariable is not derivable" << endl;
860       exit(EXIT_FAILURE);
861     case SymbolType::statementDeclaredVariable:
862       cerr << "statementDeclaredVariable is not derivable" << endl;
863       exit(EXIT_FAILURE);
864     case SymbolType::unusedEndogenous:
865       cerr << "unusedEndogenous is not derivable" << endl;
866       exit(EXIT_FAILURE);
867     case SymbolType::externalFunction:
868     case SymbolType::endogenousVAR:
869     case SymbolType::epilogue:
870     case SymbolType::excludedVariable:
871       cerr << "VariableNode::computeDerivative: Impossible case!" << endl;
872       exit(EXIT_FAILURE);
873     }
874   // Suppress GCC warning
875   exit(EXIT_FAILURE);
876 }
877 
878 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const879 VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
880 {
881   if (temporary_terms.find(const_cast<VariableNode *>(this)) != temporary_terms.end())
882     temporary_terms_inuse.insert(idx);
883   if (get_type() == SymbolType::modelLocalVariable)
884     datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
885 }
886 
887 bool
containsExternalFunction() const888 VariableNode::containsExternalFunction() const
889 {
890   if (get_type() == SymbolType::modelLocalVariable)
891     return datatree.getLocalVariable(symb_id)->containsExternalFunction();
892 
893   return false;
894 }
895 
896 void
writeJsonAST(ostream & output) const897 VariableNode::writeJsonAST(ostream &output) const
898 {
899   output << R"({"node_type" : "VariableNode", )"
900          << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "type" : ")";
901   switch (get_type())
902     {
903     case SymbolType::endogenous:
904       output << "endogenous";
905       break;
906     case SymbolType::exogenous:
907       output << "exogenous";
908       break;
909     case SymbolType::exogenousDet:
910       output << "exogenousDet";
911       break;
912     case SymbolType::parameter:
913       output << "parameter";
914       break;
915     case SymbolType::modelLocalVariable:
916       output << "modelLocalVariable";
917       break;
918     case SymbolType::modFileLocalVariable:
919       output << "modFileLocalVariable";
920       break;
921     case SymbolType::externalFunction:
922       output << "externalFunction";
923       break;
924     case SymbolType::trend:
925       output << "trend";
926       break;
927     case SymbolType::statementDeclaredVariable:
928       output << "statementDeclaredVariable";
929       break;
930     case SymbolType::logTrend:
931       output << "logTrend:";
932       break;
933     case SymbolType::unusedEndogenous:
934       output << "unusedEndogenous";
935       break;
936     case SymbolType::endogenousVAR:
937       output << "endogenousVAR";
938       break;
939     case SymbolType::epilogue:
940       output << "epilogue";
941       break;
942     case SymbolType::excludedVariable:
943       cerr << "VariableNode::computeDerivative: Impossible case!" << endl;
944       exit(EXIT_FAILURE);
945     }
946   output << R"(", "lag" : )" << lag << "}";
947 }
948 
949 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const950 VariableNode::writeJsonOutput(ostream &output,
951                               const temporary_terms_t &temporary_terms,
952                               const deriv_node_temp_terms_t &tef_terms,
953                               bool isdynamic) const
954 {
955   if (temporary_terms.find(const_cast<VariableNode *>(this)) != temporary_terms.end())
956     {
957       output << "T" << idx;
958       return;
959     }
960 
961   output << datatree.symbol_table.getName(symb_id);
962   if (isdynamic && lag != 0)
963     output << "(" << lag << ")";
964 }
965 
966 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const967 VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
968                           const temporary_terms_t &temporary_terms,
969                           const temporary_terms_idxs_t &temporary_terms_idxs,
970                           const deriv_node_temp_terms_t &tef_terms) const
971 {
972   auto type = get_type();
973   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
974     return;
975 
976   if (isLatexOutput(output_type))
977     {
978       if (output_type == ExprNodeOutputType::latexDynamicSteadyStateOperator)
979         output << R"(\bar)";
980       output << "{" << datatree.symbol_table.getTeXName(symb_id) << "}";
981       if (output_type == ExprNodeOutputType::latexDynamicModel
982           && (type == SymbolType::endogenous || type == SymbolType::exogenous || type == SymbolType::exogenousDet || type == SymbolType::trend || type == SymbolType::logTrend))
983         {
984           output << "_{t";
985           if (lag != 0)
986             {
987               if (lag > 0)
988                 output << "+";
989               output << lag;
990             }
991           output << "}";
992         }
993       return;
994     }
995 
996   int i;
997   switch (int tsid = datatree.symbol_table.getTypeSpecificID(symb_id); type)
998     {
999     case SymbolType::parameter:
1000       if (output_type == ExprNodeOutputType::matlabOutsideModel)
1001         output << "M_.params" << "(" << tsid + 1 << ")";
1002       else
1003         output << "params" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + ARRAY_SUBSCRIPT_OFFSET(output_type) << RIGHT_ARRAY_SUBSCRIPT(output_type);
1004       break;
1005 
1006     case SymbolType::modelLocalVariable:
1007       if (output_type == ExprNodeOutputType::matlabDynamicModelSparse || output_type == ExprNodeOutputType::matlabStaticModelSparse
1008           || output_type == ExprNodeOutputType::matlabDynamicSteadyStateOperator || output_type == ExprNodeOutputType::matlabDynamicSparseSteadyStateOperator
1009           || output_type == ExprNodeOutputType::CDynamicSteadyStateOperator)
1010         {
1011           output << "(";
1012           datatree.getLocalVariable(symb_id)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
1013           output << ")";
1014         }
1015       else
1016         /* We append underscores to avoid name clashes with "g1" or "oo_".
1017            But we probably never arrive here because MLV are temporary terms… */
1018         output << datatree.symbol_table.getName(symb_id) << "__";
1019       break;
1020 
1021     case SymbolType::modFileLocalVariable:
1022       output << datatree.symbol_table.getName(symb_id);
1023       break;
1024 
1025     case SymbolType::endogenous:
1026       switch (output_type)
1027         {
1028         case ExprNodeOutputType::juliaDynamicModel:
1029         case ExprNodeOutputType::matlabDynamicModel:
1030         case ExprNodeOutputType::CDynamicModel:
1031           i = datatree.getDynJacobianCol(datatree.getDerivID(symb_id, lag)) + ARRAY_SUBSCRIPT_OFFSET(output_type);
1032           output <<  "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1033           break;
1034         case ExprNodeOutputType::CStaticModel:
1035         case ExprNodeOutputType::juliaStaticModel:
1036         case ExprNodeOutputType::matlabStaticModel:
1037         case ExprNodeOutputType::matlabStaticModelSparse:
1038           i = tsid + ARRAY_SUBSCRIPT_OFFSET(output_type);
1039           output <<  "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1040           break;
1041         case ExprNodeOutputType::matlabDynamicModelSparse:
1042           i = tsid + ARRAY_SUBSCRIPT_OFFSET(output_type);
1043           if (lag > 0)
1044             output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1045           else if (lag < 0)
1046             output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1047           else
1048             output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1049           break;
1050         case ExprNodeOutputType::matlabOutsideModel:
1051           output << "oo_.steady_state(" << tsid + 1 << ")";
1052           break;
1053         case ExprNodeOutputType::juliaDynamicSteadyStateOperator:
1054         case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
1055         case ExprNodeOutputType::matlabDynamicSparseSteadyStateOperator:
1056           output << "steady_state" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + 1 << RIGHT_ARRAY_SUBSCRIPT(output_type);
1057           break;
1058         case ExprNodeOutputType::CDynamicSteadyStateOperator:
1059           output << "steady_state[" << tsid << "]";
1060           break;
1061         case ExprNodeOutputType::juliaSteadyStateFile:
1062         case ExprNodeOutputType::steadyStateFile:
1063           output << "ys_" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + 1 << RIGHT_ARRAY_SUBSCRIPT(output_type);
1064           break;
1065         case ExprNodeOutputType::matlabDseries:
1066           output << "ds." << datatree.symbol_table.getName(symb_id);
1067           if (lag != 0)
1068             output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
1069           break;
1070         case ExprNodeOutputType::epilogueFile:
1071           output << "ds." << datatree.symbol_table.getName(symb_id);
1072           output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
1073           if (lag != 0)
1074             output << lag;
1075           output << RIGHT_ARRAY_SUBSCRIPT(output_type);
1076           break;
1077         default:
1078           cerr << "VariableNode::writeOutput: should not reach this point" << endl;
1079           exit(EXIT_FAILURE);
1080         }
1081       break;
1082 
1083     case SymbolType::exogenous:
1084       i = tsid + ARRAY_SUBSCRIPT_OFFSET(output_type);
1085       switch (output_type)
1086         {
1087         case ExprNodeOutputType::juliaDynamicModel:
1088         case ExprNodeOutputType::matlabDynamicModel:
1089         case ExprNodeOutputType::matlabDynamicModelSparse:
1090           if (lag > 0)
1091             output <<  "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i
1092                    << RIGHT_ARRAY_SUBSCRIPT(output_type);
1093           else if (lag < 0)
1094             output <<  "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i
1095                    << RIGHT_ARRAY_SUBSCRIPT(output_type);
1096           else
1097             output <<  "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i
1098                    << RIGHT_ARRAY_SUBSCRIPT(output_type);
1099           break;
1100         case ExprNodeOutputType::CDynamicModel:
1101           if (lag == 0)
1102             output <<  "x[it_+" << i << "*nb_row_x]";
1103           else if (lag > 0)
1104             output <<  "x[it_+" << lag << "+" << i << "*nb_row_x]";
1105           else
1106             output <<  "x[it_" << lag << "+" << i << "*nb_row_x]";
1107           break;
1108         case ExprNodeOutputType::CStaticModel:
1109         case ExprNodeOutputType::juliaStaticModel:
1110         case ExprNodeOutputType::matlabStaticModel:
1111         case ExprNodeOutputType::matlabStaticModelSparse:
1112           output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1113           break;
1114         case ExprNodeOutputType::matlabOutsideModel:
1115           assert(lag == 0);
1116           output <<  "oo_.exo_steady_state(" << i << ")";
1117           break;
1118         case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
1119           output <<  "oo_.exo_steady_state(" << i << ")";
1120           break;
1121         case ExprNodeOutputType::juliaSteadyStateFile:
1122         case ExprNodeOutputType::steadyStateFile:
1123           output << "exo_" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1124           break;
1125         case ExprNodeOutputType::matlabDseries:
1126           output << "ds." << datatree.symbol_table.getName(symb_id);
1127           if (lag != 0)
1128             output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
1129           break;
1130         case ExprNodeOutputType::epilogueFile:
1131           output << "ds." << datatree.symbol_table.getName(symb_id);
1132           output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
1133           if (lag != 0)
1134             output << lag;
1135           output << RIGHT_ARRAY_SUBSCRIPT(output_type);
1136           break;
1137         default:
1138           cerr << "VariableNode::writeOutput: should not reach this point" << endl;
1139           exit(EXIT_FAILURE);
1140         }
1141       break;
1142 
1143     case SymbolType::exogenousDet:
1144       i = tsid + datatree.symbol_table.exo_nbr() + ARRAY_SUBSCRIPT_OFFSET(output_type);
1145       switch (output_type)
1146         {
1147         case ExprNodeOutputType::juliaDynamicModel:
1148         case ExprNodeOutputType::matlabDynamicModel:
1149         case ExprNodeOutputType::matlabDynamicModelSparse:
1150           if (lag > 0)
1151             output <<  "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i
1152                    << RIGHT_ARRAY_SUBSCRIPT(output_type);
1153           else if (lag < 0)
1154             output <<  "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i
1155                    << RIGHT_ARRAY_SUBSCRIPT(output_type);
1156           else
1157             output <<  "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i
1158                    << RIGHT_ARRAY_SUBSCRIPT(output_type);
1159           break;
1160         case ExprNodeOutputType::CDynamicModel:
1161           if (lag == 0)
1162             output <<  "x[it_+" << i << "*nb_row_x]";
1163           else if (lag > 0)
1164             output <<  "x[it_+" << lag << "+" << i << "*nb_row_x]";
1165           else
1166             output <<  "x[it_" << lag << "+" << i << "*nb_row_x]";
1167           break;
1168         case ExprNodeOutputType::CStaticModel:
1169         case ExprNodeOutputType::juliaStaticModel:
1170         case ExprNodeOutputType::matlabStaticModel:
1171         case ExprNodeOutputType::matlabStaticModelSparse:
1172           output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1173           break;
1174         case ExprNodeOutputType::matlabOutsideModel:
1175           assert(lag == 0);
1176           output <<  "oo_.exo_det_steady_state(" << tsid + 1 << ")";
1177           break;
1178         case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
1179           output <<  "oo_.exo_det_steady_state(" << tsid + 1 << ")";
1180           break;
1181         case ExprNodeOutputType::juliaSteadyStateFile:
1182         case ExprNodeOutputType::steadyStateFile:
1183           output << "exo_" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
1184           break;
1185         case ExprNodeOutputType::matlabDseries:
1186           output << "ds." << datatree.symbol_table.getName(symb_id);
1187           if (lag != 0)
1188             output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
1189           break;
1190         case ExprNodeOutputType::epilogueFile:
1191           output << "ds." << datatree.symbol_table.getName(symb_id);
1192           output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
1193           if (lag != 0)
1194             output << lag;
1195           output << RIGHT_ARRAY_SUBSCRIPT(output_type);
1196           break;
1197         default:
1198           cerr << "VariableNode::writeOutput: should not reach this point" << endl;
1199           exit(EXIT_FAILURE);
1200         }
1201       break;
1202     case SymbolType::epilogue:
1203       if (output_type == ExprNodeOutputType::epilogueFile)
1204         {
1205           output << "ds." << datatree.symbol_table.getName(symb_id);
1206           output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
1207           if (lag != 0)
1208             output << lag;
1209           output << RIGHT_ARRAY_SUBSCRIPT(output_type);
1210         }
1211       else if (output_type == ExprNodeOutputType::matlabDseries)
1212         // Only writing dseries for epilogue_static, hence no need to check lag
1213         output << "ds." << datatree.symbol_table.getName(symb_id);
1214       else
1215         {
1216           cerr << "VariableNode::writeOutput: Impossible case" << endl;
1217           exit(EXIT_FAILURE);
1218         }
1219       break;
1220     case SymbolType::unusedEndogenous:
1221       cerr << "ERROR: You cannot use an endogenous variable in an expression if that variable has not been used in the model block." << endl;
1222       exit(EXIT_FAILURE);
1223     case SymbolType::externalFunction:
1224     case SymbolType::trend:
1225     case SymbolType::logTrend:
1226     case SymbolType::statementDeclaredVariable:
1227     case SymbolType::endogenousVAR:
1228     case SymbolType::excludedVariable:
1229       cerr << "VariableNode::writeOutput: Impossible case" << endl;
1230       exit(EXIT_FAILURE);
1231     }
1232 }
1233 
1234 expr_t
substituteStaticAuxiliaryVariable() const1235 VariableNode::substituteStaticAuxiliaryVariable() const
1236 {
1237   if (get_type() == SymbolType::endogenous)
1238     try
1239       {
1240         return datatree.symbol_table.getAuxiliaryVarsExprNode(symb_id)->substituteStaticAuxiliaryVariable();
1241       }
1242     catch (SymbolTable::SearchFailedException &e)
1243       {
1244       }
1245   return const_cast<VariableNode *>(this);
1246 }
1247 
1248 double
eval(const eval_context_t & eval_context) const1249 VariableNode::eval(const eval_context_t &eval_context) const noexcept(false)
1250 {
1251   if (get_type() == SymbolType::modelLocalVariable)
1252     return datatree.getLocalVariable(symb_id)->eval(eval_context);
1253 
1254   auto it = eval_context.find(symb_id);
1255   if (it == eval_context.end())
1256     throw EvalException();
1257 
1258   return it->second;
1259 }
1260 
1261 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const1262 VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
1263                       bool lhs_rhs, const temporary_terms_t &temporary_terms,
1264                       const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
1265                       const deriv_node_temp_terms_t &tef_terms) const
1266 {
1267   auto type = get_type();
1268   if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
1269     datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
1270   else
1271     {
1272       int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
1273       if (type == SymbolType::exogenousDet)
1274         tsid += datatree.symbol_table.exo_nbr();
1275       if (!lhs_rhs)
1276         {
1277           if (dynamic)
1278             {
1279               if (steady_dynamic) // steady state values in a dynamic model
1280                 {
1281                   FLDVS_ fldvs{static_cast<uint8_t>(type), static_cast<unsigned int>(tsid)};
1282                   fldvs.write(CompileCode, instruction_number);
1283                 }
1284               else
1285                 {
1286                   if (type == SymbolType::parameter)
1287                     {
1288                       FLDV_ fldv{static_cast<int>(type), static_cast<unsigned int>(tsid)};
1289                       fldv.write(CompileCode, instruction_number);
1290                     }
1291                   else
1292                     {
1293                       FLDV_ fldv{static_cast<int>(type), static_cast<unsigned int>(tsid), lag};
1294                       fldv.write(CompileCode, instruction_number);
1295                     }
1296                 }
1297             }
1298           else
1299             {
1300               FLDSV_ fldsv{static_cast<uint8_t>(type), static_cast<unsigned int>(tsid)};
1301               fldsv.write(CompileCode, instruction_number);
1302             }
1303         }
1304       else
1305         {
1306           if (dynamic)
1307             {
1308               if (steady_dynamic) // steady state values in a dynamic model
1309                 {
1310                   cerr << "Impossible case: steady_state in rhs of equation" << endl;
1311                   exit(EXIT_FAILURE);
1312                 }
1313               else
1314                 {
1315                   if (type == SymbolType::parameter)
1316                     {
1317                       FSTPV_ fstpv{static_cast<int>(type), static_cast<unsigned int>(tsid)};
1318                       fstpv.write(CompileCode, instruction_number);
1319                     }
1320                   else
1321                     {
1322                       FSTPV_ fstpv{static_cast<int>(type), static_cast<unsigned int>(tsid), lag};
1323                       fstpv.write(CompileCode, instruction_number);
1324                     }
1325                 }
1326             }
1327           else
1328             {
1329               FSTPSV_ fstpsv{static_cast<uint8_t>(type), static_cast<unsigned int>(tsid)};
1330               fstpsv.write(CompileCode, instruction_number);
1331             }
1332         }
1333     }
1334 }
1335 
1336 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const1337 VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
1338                                     temporary_terms_t &temporary_terms,
1339                                     map<expr_t, pair<int, int>> &first_occurence,
1340                                     int Curr_block,
1341                                     vector<vector<temporary_terms_t>> &v_temporary_terms,
1342                                     int equation) const
1343 {
1344   if (get_type() == SymbolType::modelLocalVariable)
1345     datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
1346 }
1347 
1348 void
collectVARLHSVariable(set<expr_t> & result) const1349 VariableNode::collectVARLHSVariable(set<expr_t> &result) const
1350 {
1351   if (get_type() == SymbolType::endogenous && lag == 0)
1352     result.insert(const_cast<VariableNode *>(this));
1353   else
1354     {
1355       cerr << "ERROR: you can only have endogenous variables or unary ops on LHS of VAR" << endl;
1356       exit(EXIT_FAILURE);
1357     }
1358 }
1359 
1360 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const1361 VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
1362 {
1363   if (get_type() == type_arg)
1364     result.emplace(symb_id, lag);
1365   if (get_type() == SymbolType::modelLocalVariable)
1366     datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
1367 }
1368 
1369 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const1370 VariableNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
1371 {
1372   /* The equation has to be normalized with respect to the current endogenous variable ascribed to it.
1373      The two input arguments are :
1374      - The ID of the endogenous variable associated to the equation.
1375      - The list of operators and operands needed to normalize the equation*
1376 
1377      The pair returned by NormalizeEquation is composed of
1378      - a flag indicating if the expression returned contains (flag = 1) or not (flag = 0)
1379      the endogenous variable related to the equation.
1380      If the expression contains more than one occurence of the associated endogenous variable,
1381      the flag is equal to 2.
1382      - an expression equal to the RHS if flag = 0 and equal to NULL elsewhere
1383   */
1384   if (get_type() == SymbolType::endogenous)
1385     {
1386       if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0)
1387         /* the endogenous variable */
1388         return { 1, nullptr };
1389       else
1390         return { 0, datatree.AddVariable(symb_id, lag) };
1391     }
1392   else
1393     {
1394       if (get_type() == SymbolType::parameter)
1395         return { 0, datatree.AddVariable(symb_id, 0) };
1396       else
1397         return { 0, datatree.AddVariable(symb_id, lag) };
1398     }
1399 }
1400 
1401 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)1402 VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
1403 {
1404   switch (get_type())
1405     {
1406     case SymbolType::endogenous:
1407     case SymbolType::exogenous:
1408     case SymbolType::exogenousDet:
1409     case SymbolType::parameter:
1410     case SymbolType::trend:
1411     case SymbolType::logTrend:
1412       if (deriv_id == datatree.getDerivID(symb_id, lag))
1413         return datatree.One;
1414       // If there is in the equation a recursive variable we could use a chaine rule derivation
1415       else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
1416                it != recursive_variables.end())
1417         {
1418           map<int, expr_t> recursive_vars2(recursive_variables);
1419           recursive_vars2.erase(it->first);
1420           return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2));
1421         }
1422       else
1423         return datatree.Zero;
1424 
1425     case SymbolType::modelLocalVariable:
1426       return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
1427     case SymbolType::modFileLocalVariable:
1428       cerr << "modFileLocalVariable is not derivable" << endl;
1429       exit(EXIT_FAILURE);
1430     case SymbolType::statementDeclaredVariable:
1431       cerr << "statementDeclaredVariable is not derivable" << endl;
1432       exit(EXIT_FAILURE);
1433     case SymbolType::unusedEndogenous:
1434       cerr << "unusedEndogenous is not derivable" << endl;
1435       exit(EXIT_FAILURE);
1436     case SymbolType::externalFunction:
1437     case SymbolType::endogenousVAR:
1438     case SymbolType::epilogue:
1439     case SymbolType::excludedVariable:
1440       cerr << "VariableNode::getChainRuleDerivative: Impossible case" << endl;
1441       exit(EXIT_FAILURE);
1442     }
1443   // Suppress GCC warning
1444   exit(EXIT_FAILURE);
1445 }
1446 
1447 expr_t
toStatic(DataTree & static_datatree) const1448 VariableNode::toStatic(DataTree &static_datatree) const
1449 {
1450   return static_datatree.AddVariable(symb_id);
1451 }
1452 
1453 void
computeXrefs(EquationInfo & ei) const1454 VariableNode::computeXrefs(EquationInfo &ei) const
1455 {
1456   switch (get_type())
1457     {
1458     case SymbolType::endogenous:
1459       ei.endo.emplace(symb_id, lag);
1460       break;
1461     case SymbolType::exogenous:
1462       ei.exo.emplace(symb_id, lag);
1463       break;
1464     case SymbolType::exogenousDet:
1465       ei.exo_det.emplace(symb_id, lag);
1466       break;
1467     case SymbolType::parameter:
1468       ei.param.emplace(symb_id, 0);
1469       break;
1470     case SymbolType::modFileLocalVariable:
1471       datatree.getLocalVariable(symb_id)->computeXrefs(ei);
1472       break;
1473     case SymbolType::trend:
1474     case SymbolType::logTrend:
1475     case SymbolType::modelLocalVariable:
1476     case SymbolType::statementDeclaredVariable:
1477     case SymbolType::unusedEndogenous:
1478     case SymbolType::externalFunction:
1479     case SymbolType::endogenousVAR:
1480     case SymbolType::epilogue:
1481     case SymbolType::excludedVariable:
1482       break;
1483     }
1484 }
1485 
1486 SymbolType
get_type() const1487 VariableNode::get_type() const
1488 {
1489   return datatree.symbol_table.getType(symb_id);
1490 }
1491 
1492 expr_t
clone(DataTree & datatree) const1493 VariableNode::clone(DataTree &datatree) const
1494 {
1495   return datatree.AddVariable(symb_id, lag);
1496 }
1497 
1498 int
maxEndoLead() const1499 VariableNode::maxEndoLead() const
1500 {
1501   switch (get_type())
1502     {
1503     case SymbolType::endogenous:
1504       return max(lag, 0);
1505     case SymbolType::modelLocalVariable:
1506       return datatree.getLocalVariable(symb_id)->maxEndoLead();
1507     default:
1508       return 0;
1509     }
1510 }
1511 
1512 int
maxExoLead() const1513 VariableNode::maxExoLead() const
1514 {
1515   switch (get_type())
1516     {
1517     case SymbolType::exogenous:
1518       return max(lag, 0);
1519     case SymbolType::modelLocalVariable:
1520       return datatree.getLocalVariable(symb_id)->maxExoLead();
1521     default:
1522       return 0;
1523     }
1524 }
1525 
1526 int
maxEndoLag() const1527 VariableNode::maxEndoLag() const
1528 {
1529   switch (get_type())
1530     {
1531     case SymbolType::endogenous:
1532       return max(-lag, 0);
1533     case SymbolType::modelLocalVariable:
1534       return datatree.getLocalVariable(symb_id)->maxEndoLag();
1535     default:
1536       return 0;
1537     }
1538 }
1539 
1540 int
maxExoLag() const1541 VariableNode::maxExoLag() const
1542 {
1543   switch (get_type())
1544     {
1545     case SymbolType::exogenous:
1546       return max(-lag, 0);
1547     case SymbolType::modelLocalVariable:
1548       return datatree.getLocalVariable(symb_id)->maxExoLag();
1549     default:
1550       return 0;
1551     }
1552 }
1553 
1554 int
maxLead() const1555 VariableNode::maxLead() const
1556 {
1557   switch (get_type())
1558     {
1559     case SymbolType::endogenous:
1560     case SymbolType::exogenous:
1561     case SymbolType::exogenousDet:
1562       return lag;
1563     case SymbolType::modelLocalVariable:
1564       return datatree.getLocalVariable(symb_id)->maxLead();
1565     default:
1566       return 0;
1567     }
1568 }
1569 
1570 int
VarMinLag() const1571 VariableNode::VarMinLag() const
1572 {
1573   switch (get_type())
1574     {
1575     case SymbolType::endogenous:
1576       return -lag;
1577     case SymbolType::exogenous:
1578       if (lag > 0)
1579         return -lag;
1580       else
1581         return 1; // Can have contemporaneus exog in VAR
1582     case SymbolType::modelLocalVariable:
1583       return datatree.getLocalVariable(symb_id)->VarMinLag();
1584     default:
1585       return 1;
1586     }
1587 }
1588 
1589 int
maxLag() const1590 VariableNode::maxLag() const
1591 {
1592   switch (get_type())
1593     {
1594     case SymbolType::endogenous:
1595     case SymbolType::exogenous:
1596     case SymbolType::exogenousDet:
1597       return -lag;
1598     case SymbolType::modelLocalVariable:
1599       return datatree.getLocalVariable(symb_id)->maxLag();
1600     default:
1601       return 0;
1602     }
1603 }
1604 
1605 int
maxLagWithDiffsExpanded() const1606 VariableNode::maxLagWithDiffsExpanded() const
1607 {
1608   switch (get_type())
1609     {
1610     case SymbolType::endogenous:
1611     case SymbolType::exogenous:
1612     case SymbolType::exogenousDet:
1613     case SymbolType::epilogue:
1614       return -lag;
1615     case SymbolType::modelLocalVariable:
1616       return datatree.getLocalVariable(symb_id)->maxLagWithDiffsExpanded();
1617     default:
1618       return 0;
1619     }
1620 }
1621 
1622 expr_t
undiff() const1623 VariableNode::undiff() const
1624 {
1625   return const_cast<VariableNode *>(this);
1626 }
1627 
1628 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const1629 VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
1630 {
1631   auto [lag_equiv_repr, index] = getLagEquivalenceClass();
1632   if (lhs_lag_equiv.find(lag_equiv_repr) == lhs_lag_equiv.end())
1633     return 0;
1634   return maxLag();
1635 }
1636 
1637 int
PacMaxLag(int lhs_symb_id) const1638 VariableNode::PacMaxLag(int lhs_symb_id) const
1639 {
1640   if (get_type() == SymbolType::modelLocalVariable)
1641     return datatree.getLocalVariable(symb_id)->PacMaxLag(lhs_symb_id);
1642 
1643   if (lhs_symb_id == symb_id)
1644     return -lag;
1645   return 0;
1646 }
1647 
1648 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const1649 VariableNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
1650 {
1651   return -1;
1652 }
1653 
1654 expr_t
substituteAdl() const1655 VariableNode::substituteAdl() const
1656 {
1657   /* Do not recurse into model-local variables definition, rather do it at the
1658      DynamicModel method level (see the comment there) */
1659   return const_cast<VariableNode *>(this);
1660 }
1661 
1662 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const1663 VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
1664 {
1665   if (get_type() == SymbolType::modelLocalVariable)
1666     return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table);
1667 
1668   return const_cast<VariableNode *>(this);
1669 }
1670 
1671 void
findDiffNodes(lag_equivalence_table_t & nodes) const1672 VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const
1673 {
1674   if (get_type() == SymbolType::modelLocalVariable)
1675     datatree.getLocalVariable(symb_id)->findDiffNodes(nodes);
1676 }
1677 
1678 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const1679 VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
1680 {
1681   if (get_type() == SymbolType::modelLocalVariable)
1682     datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes);
1683 }
1684 
1685 int
findTargetVariable(int lhs_symb_id) const1686 VariableNode::findTargetVariable(int lhs_symb_id) const
1687 {
1688   if (get_type() == SymbolType::modelLocalVariable)
1689     return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id);
1690 
1691   return -1;
1692 }
1693 
1694 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1695 VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
1696                              vector<BinaryOpNode *> &neweqs) const
1697 {
1698   if (get_type() == SymbolType::modelLocalVariable)
1699     return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs);
1700 
1701   return const_cast<VariableNode *>(this);
1702 }
1703 
1704 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1705 VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1706 {
1707   if (get_type() == SymbolType::modelLocalVariable)
1708     return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs);
1709 
1710   return const_cast<VariableNode *>(this);
1711 }
1712 
1713 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)1714 VariableNode::substitutePacExpectation(const string &name, expr_t subexpr)
1715 {
1716   if (get_type() == SymbolType::modelLocalVariable)
1717     return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr);
1718 
1719   return const_cast<VariableNode *>(this);
1720 }
1721 
1722 expr_t
decreaseLeadsLags(int n) const1723 VariableNode::decreaseLeadsLags(int n) const
1724 {
1725   switch (get_type())
1726     {
1727     case SymbolType::endogenous:
1728     case SymbolType::exogenous:
1729     case SymbolType::exogenousDet:
1730     case SymbolType::trend:
1731     case SymbolType::logTrend:
1732       return datatree.AddVariable(symb_id, lag-n);
1733     case SymbolType::modelLocalVariable:
1734       return datatree.getLocalVariable(symb_id)->decreaseLeadsLags(n);
1735     default:
1736       return const_cast<VariableNode *>(this);
1737     }
1738 }
1739 
1740 expr_t
decreaseLeadsLagsPredeterminedVariables() const1741 VariableNode::decreaseLeadsLagsPredeterminedVariables() const
1742 {
1743   /* Do not recurse into model-local variables definitions, since MLVs are
1744      already handled by DynamicModel::transformPredeterminedVariables().
1745      This is also necessary because of #65. */
1746   if (datatree.symbol_table.isPredetermined(symb_id))
1747     return decreaseLeadsLags(1);
1748   else
1749     return const_cast<VariableNode *>(this);
1750 }
1751 
1752 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const1753 VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
1754 {
1755   switch (get_type())
1756     {
1757     case SymbolType::endogenous:
1758       if (lag <= 1)
1759         return const_cast<VariableNode *>(this);
1760       else
1761         return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
1762     case SymbolType::modelLocalVariable:
1763       if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLead() <= 1)
1764         return const_cast<VariableNode *>(this);
1765       else
1766         return value->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
1767     default:
1768       return const_cast<VariableNode *>(this);
1769     }
1770 }
1771 
1772 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1773 VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1774 {
1775   VariableNode *substexpr;
1776   int cur_lag;
1777   switch (get_type())
1778     {
1779     case SymbolType::endogenous:
1780       if (lag >= -1)
1781         return const_cast<VariableNode *>(this);
1782 
1783       if (auto it = subst_table.find(this); it != subst_table.end())
1784         return const_cast<VariableNode *>(it->second);
1785 
1786       substexpr = datatree.AddVariable(symb_id, -1);
1787       cur_lag = -2;
1788 
1789       // Each iteration tries to create an auxvar such that auxvar(-1)=curvar(cur_lag)
1790       // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to curvar(cur_lag+1) (resp. curvar(cur_lag))
1791       while (cur_lag >= lag)
1792         {
1793           VariableNode *orig_expr = datatree.AddVariable(symb_id, cur_lag);
1794           if (auto it = subst_table.find(orig_expr); it == subst_table.end())
1795             {
1796               int aux_symb_id = datatree.symbol_table.addEndoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr);
1797               neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr)));
1798               substexpr = datatree.AddVariable(aux_symb_id, -1);
1799               subst_table[orig_expr] = substexpr;
1800             }
1801           else
1802             substexpr = const_cast<VariableNode *>(it->second);
1803 
1804           cur_lag--;
1805         }
1806       return substexpr;
1807 
1808     case SymbolType::modelLocalVariable:
1809       if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLag() <= 1)
1810         return const_cast<VariableNode *>(this);
1811       else
1812         return value->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
1813     default:
1814       return const_cast<VariableNode *>(this);
1815     }
1816 }
1817 
1818 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const1819 VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
1820 {
1821   switch (get_type())
1822     {
1823     case SymbolType::exogenous:
1824       if (lag <= 0)
1825         return const_cast<VariableNode *>(this);
1826       else
1827         return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
1828     case SymbolType::modelLocalVariable:
1829       if (expr_t value = datatree.getLocalVariable(symb_id); value->maxExoLead() == 0)
1830         return const_cast<VariableNode *>(this);
1831       else
1832         return value->substituteExoLead(subst_table, neweqs, deterministic_model);
1833     default:
1834       return const_cast<VariableNode *>(this);
1835     }
1836 }
1837 
1838 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1839 VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1840 {
1841   VariableNode *substexpr;
1842   int cur_lag;
1843   switch (get_type())
1844     {
1845     case SymbolType::exogenous:
1846       if (lag >= 0)
1847         return const_cast<VariableNode *>(this);
1848 
1849       if (auto it = subst_table.find(this); it != subst_table.end())
1850         return const_cast<VariableNode *>(it->second);
1851 
1852       substexpr = datatree.AddVariable(symb_id, 0);
1853       cur_lag = -1;
1854 
1855       // Each iteration tries to create an auxvar such that auxvar(-1)=curvar(cur_lag)
1856       // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to curvar(cur_lag+1) (resp. curvar(cur_lag))
1857       while (cur_lag >= lag)
1858         {
1859           VariableNode *orig_expr = datatree.AddVariable(symb_id, cur_lag);
1860           if (auto it = subst_table.find(orig_expr); it == subst_table.end())
1861             {
1862               int aux_symb_id = datatree.symbol_table.addExoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr);
1863               neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr)));
1864               substexpr = datatree.AddVariable(aux_symb_id, -1);
1865               subst_table[orig_expr] = substexpr;
1866             }
1867           else
1868             substexpr = const_cast<VariableNode *>(it->second);
1869 
1870           cur_lag--;
1871         }
1872       return substexpr;
1873 
1874     case SymbolType::modelLocalVariable:
1875       if (expr_t value = datatree.getLocalVariable(symb_id); value->maxExoLag() == 0)
1876         return const_cast<VariableNode *>(this);
1877       else
1878         return value->substituteExoLag(subst_table, neweqs);
1879     default:
1880       return const_cast<VariableNode *>(this);
1881     }
1882 }
1883 
1884 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const1885 VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
1886 {
1887   if (get_type() == SymbolType::modelLocalVariable)
1888     return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model);
1889 
1890   return const_cast<VariableNode *>(this);
1891 }
1892 
1893 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1894 VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1895 {
1896   switch (get_type())
1897     {
1898     case SymbolType::endogenous:
1899       assert(lag <= 1);
1900       if (lag <= 0
1901           || (subset.size() > 0
1902               && find(subset.begin(), subset.end(), datatree.symbol_table.getName(symb_id)) == subset.end()))
1903         return const_cast<VariableNode *>(this);
1904       else
1905         {
1906           VariableNode *diffvar;
1907           if (auto it = subst_table.find(this); it != subst_table.end())
1908             diffvar = const_cast<VariableNode *>(it->second);
1909           else
1910             {
1911               int aux_symb_id = datatree.symbol_table.addDiffForwardAuxiliaryVar(symb_id, datatree.AddMinus(datatree.AddVariable(symb_id, 0),
1912                                                                                                             datatree.AddVariable(symb_id, -1)));
1913               neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), datatree.AddMinus(datatree.AddVariable(symb_id, 0),
1914                                                                                                                                       datatree.AddVariable(symb_id, -1)))));
1915               diffvar = datatree.AddVariable(aux_symb_id, 1);
1916               subst_table[this] = diffvar;
1917             }
1918           return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar);
1919         }
1920     case SymbolType::modelLocalVariable:
1921       if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLead() <= 0)
1922         return const_cast<VariableNode *>(this);
1923       else
1924         return value->differentiateForwardVars(subset, subst_table, neweqs);
1925     default:
1926       return const_cast<VariableNode *>(this);
1927     }
1928 }
1929 
1930 bool
isNumConstNodeEqualTo(double value) const1931 VariableNode::isNumConstNodeEqualTo(double value) const
1932 {
1933   return false;
1934 }
1935 
1936 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const1937 VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
1938 {
1939   if (get_type() == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg)
1940     return true;
1941   else
1942     return false;
1943 }
1944 
1945 bool
containsPacExpectation(const string & pac_model_name) const1946 VariableNode::containsPacExpectation(const string &pac_model_name) const
1947 {
1948   if (get_type() == SymbolType::modelLocalVariable)
1949     return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name);
1950 
1951   return false;
1952 }
1953 
1954 bool
containsEndogenous() const1955 VariableNode::containsEndogenous() const
1956 {
1957   if (get_type() == SymbolType::modelLocalVariable)
1958     return datatree.getLocalVariable(symb_id)->containsEndogenous();
1959 
1960   if (get_type() == SymbolType::endogenous)
1961     return true;
1962   else
1963     return false;
1964 }
1965 
1966 bool
containsExogenous() const1967 VariableNode::containsExogenous() const
1968 {
1969   if (get_type() == SymbolType::modelLocalVariable)
1970     return datatree.getLocalVariable(symb_id)->containsExogenous();
1971 
1972   return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet;
1973 }
1974 
1975 expr_t
replaceTrendVar() const1976 VariableNode::replaceTrendVar() const
1977 {
1978   if (get_type() == SymbolType::modelLocalVariable)
1979     return datatree.getLocalVariable(symb_id)->replaceTrendVar();
1980 
1981   if (get_type() == SymbolType::trend)
1982     return datatree.One;
1983   else if (get_type() == SymbolType::logTrend)
1984     return datatree.Zero;
1985   else
1986     return const_cast<VariableNode *>(this);
1987 }
1988 
1989 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const1990 VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
1991 {
1992   if (get_type() == SymbolType::modelLocalVariable)
1993     return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend);
1994 
1995   if (this->symb_id != symb_id)
1996     return const_cast<VariableNode *>(this);
1997 
1998   if (log_trend)
1999     {
2000       if (lag == 0)
2001         return datatree.AddPlus(const_cast<VariableNode *>(this), trend);
2002       else
2003         return datatree.AddPlus(const_cast<VariableNode *>(this), trend->decreaseLeadsLags(-lag));
2004     }
2005   else
2006     {
2007       if (lag == 0)
2008         return datatree.AddTimes(const_cast<VariableNode *>(this), trend);
2009       else
2010         return datatree.AddTimes(const_cast<VariableNode *>(this), trend->decreaseLeadsLags(-lag));
2011     }
2012 }
2013 
2014 int
countDiffs() const2015 VariableNode::countDiffs() const
2016 {
2017   if (get_type() == SymbolType::modelLocalVariable)
2018     return datatree.getLocalVariable(symb_id)->countDiffs();
2019 
2020   return 0;
2021 }
2022 
2023 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const2024 VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
2025 {
2026   if (get_type() == SymbolType::modelLocalVariable)
2027     return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map);
2028 
2029   if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0)
2030     return const_cast<VariableNode *>(this);
2031 
2032   auto it = trend_symbols_map.find(symb_id);
2033   expr_t noTrendLeadLagNode = datatree.AddVariable(it->first);
2034   bool log_trend = get_type() == SymbolType::logTrend;
2035   expr_t trend = it->second;
2036 
2037   if (lag > 0)
2038     {
2039       expr_t growthFactorSequence = trend->decreaseLeadsLags(-1);
2040       if (log_trend)
2041         {
2042           for (int i = 1; i < lag; i++)
2043             growthFactorSequence = datatree.AddPlus(growthFactorSequence, trend->decreaseLeadsLags(-1*(i+1)));
2044           return datatree.AddPlus(noTrendLeadLagNode, growthFactorSequence);
2045         }
2046       else
2047         {
2048           for (int i = 1; i < lag; i++)
2049             growthFactorSequence = datatree.AddTimes(growthFactorSequence, trend->decreaseLeadsLags(-1*(i+1)));
2050           return datatree.AddTimes(noTrendLeadLagNode, growthFactorSequence);
2051         }
2052     }
2053   else //get_lag < 0
2054     {
2055       expr_t growthFactorSequence = trend;
2056       if (log_trend)
2057         {
2058           for (int i = 1; i < abs(lag); i++)
2059             growthFactorSequence = datatree.AddPlus(growthFactorSequence, trend->decreaseLeadsLags(i));
2060           return datatree.AddMinus(noTrendLeadLagNode, growthFactorSequence);
2061         }
2062       else
2063         {
2064           for (int i = 1; i < abs(lag); i++)
2065             growthFactorSequence = datatree.AddTimes(growthFactorSequence, trend->decreaseLeadsLags(i));
2066           return datatree.AddDivide(noTrendLeadLagNode, growthFactorSequence);
2067         }
2068     }
2069 }
2070 
2071 bool
isInStaticForm() const2072 VariableNode::isInStaticForm() const
2073 {
2074   if (get_type() == SymbolType::modelLocalVariable)
2075     return datatree.getLocalVariable(symb_id)->isInStaticForm();
2076 
2077   return lag == 0;
2078 }
2079 
2080 bool
isParamTimesEndogExpr() const2081 VariableNode::isParamTimesEndogExpr() const
2082 {
2083   if (get_type() == SymbolType::modelLocalVariable)
2084     return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr();
2085 
2086   return false;
2087 }
2088 
2089 bool
isVarModelReferenced(const string & model_info_name) const2090 VariableNode::isVarModelReferenced(const string &model_info_name) const
2091 {
2092   if (get_type() == SymbolType::modelLocalVariable)
2093     return datatree.getLocalVariable(symb_id)->isVarModelReferenced(model_info_name);
2094 
2095   return false;
2096 }
2097 
2098 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const2099 VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
2100 {
2101   if (get_type() == SymbolType::modelLocalVariable)
2102     return datatree.getLocalVariable(symb_id)->getEndosAndMaxLags(model_endos_and_lags);
2103 
2104   if (get_type() == SymbolType::endogenous)
2105     if (string varname = datatree.symbol_table.getName(symb_id);
2106         model_endos_and_lags.find(varname) == model_endos_and_lags.end())
2107       model_endos_and_lags[varname] = min(model_endos_and_lags[varname], lag);
2108     else
2109       model_endos_and_lags[varname] = lag;
2110 }
2111 
2112 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const2113 VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
2114 {
2115   return;
2116 }
2117 
2118 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const2119 VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
2120 {
2121   /* Do not recurse into model-local variables definitions, since MLVs are
2122      already handled by ModelTree::simplifyEquations().
2123      This is also necessary because of #65. */
2124   for (auto &it : table)
2125     if (it.first->symb_id == symb_id)
2126       return it.second;
2127   return const_cast<VariableNode *>(this);
2128 }
2129 
/* Build a unary operator node applying op_code_arg to arg_arg, owned by
   datatree_arg with node index idx_arg.
   - expectation_information_set_arg: stored for the expectation operator
     (presumably the information set date — confirm against header);
   - param1_symb_id_arg / param2_symb_id_arg: parameter symbols used by the
     STEADY_STATE parameter-derivative opcodes (composeDerivatives passes
     param1_symb_id to AddSteadyStateParam2ndDeriv);
   - adl_param_name_arg / adl_lags_arg: presumably the data of the adl()
     operator; taken by value and moved into the members.
   NB: initializers are listed in declaration order (see the header). */
UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg) :
  ExprNode{datatree_arg, idx_arg},
  arg{arg_arg},
  expectation_information_set{expectation_information_set_arg},
  param1_symb_id{param1_symb_id_arg},
  param2_symb_id{param2_symb_id_arg},
  op_code{op_code_arg},
  adl_param_name{move(adl_param_name_arg)},
  adl_lags{move(adl_lags_arg)}
{
}
2141 
2142 void
prepareForDerivation()2143 UnaryOpNode::prepareForDerivation()
2144 {
2145   if (preparedForDerivation)
2146     return;
2147 
2148   preparedForDerivation = true;
2149 
2150   arg->prepareForDerivation();
2151 
2152   // Non-null derivatives are those of the argument (except for STEADY_STATE)
2153   non_null_derivatives = arg->non_null_derivatives;
2154   if (op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv
2155       || op_code == UnaryOpcode::steadyStateParam2ndDeriv)
2156     datatree.addAllParamDerivId(non_null_derivatives);
2157 }
2158 
2159 expr_t
composeDerivatives(expr_t darg,int deriv_id)2160 UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
2161 {
2162   expr_t t11, t12, t13, t14;
2163 
2164   switch (op_code)
2165     {
2166     case UnaryOpcode::uminus:
2167       return datatree.AddUMinus(darg);
2168     case UnaryOpcode::exp:
2169       return datatree.AddTimes(darg, this);
2170     case UnaryOpcode::log:
2171       return datatree.AddDivide(darg, arg);
2172     case UnaryOpcode::log10:
2173       t11 = datatree.AddExp(datatree.One);
2174       t12 = datatree.AddLog10(t11);
2175       t13 = datatree.AddDivide(darg, arg);
2176       return datatree.AddTimes(t12, t13);
2177     case UnaryOpcode::cos:
2178       t11 = datatree.AddSin(arg);
2179       t12 = datatree.AddUMinus(t11);
2180       return datatree.AddTimes(darg, t12);
2181     case UnaryOpcode::sin:
2182       t11 = datatree.AddCos(arg);
2183       return datatree.AddTimes(darg, t11);
2184     case UnaryOpcode::tan:
2185       t11 = datatree.AddTimes(this, this);
2186       t12 = datatree.AddPlus(t11, datatree.One);
2187       return datatree.AddTimes(darg, t12);
2188     case UnaryOpcode::acos:
2189       t11 = datatree.AddSin(this);
2190       t12 = datatree.AddDivide(darg, t11);
2191       return datatree.AddUMinus(t12);
2192     case UnaryOpcode::asin:
2193       t11 = datatree.AddCos(this);
2194       return datatree.AddDivide(darg, t11);
2195     case UnaryOpcode::atan:
2196       t11 = datatree.AddTimes(arg, arg);
2197       t12 = datatree.AddPlus(datatree.One, t11);
2198       return datatree.AddDivide(darg, t12);
2199     case UnaryOpcode::cosh:
2200       t11 = datatree.AddSinh(arg);
2201       return datatree.AddTimes(darg, t11);
2202     case UnaryOpcode::sinh:
2203       t11 = datatree.AddCosh(arg);
2204       return datatree.AddTimes(darg, t11);
2205     case UnaryOpcode::tanh:
2206       t11 = datatree.AddTimes(this, this);
2207       t12 = datatree.AddMinus(datatree.One, t11);
2208       return datatree.AddTimes(darg, t12);
2209     case UnaryOpcode::acosh:
2210       t11 = datatree.AddSinh(this);
2211       return datatree.AddDivide(darg, t11);
2212     case UnaryOpcode::asinh:
2213       t11 = datatree.AddCosh(this);
2214       return datatree.AddDivide(darg, t11);
2215     case UnaryOpcode::atanh:
2216       t11 = datatree.AddTimes(arg, arg);
2217       t12 = datatree.AddMinus(datatree.One, t11);
2218       return datatree.AddTimes(darg, t12);
2219     case UnaryOpcode::sqrt:
2220       t11 = datatree.AddPlus(this, this);
2221       return datatree.AddDivide(darg, t11);
2222     case UnaryOpcode::cbrt:
2223       t11 = datatree.AddPower(arg, datatree.AddDivide(datatree.Two, datatree.Three));
2224       t12 = datatree.AddTimes(datatree.Three, t11);
2225       return datatree.AddDivide(darg, t12);
2226     case UnaryOpcode::abs:
2227       t11 = datatree.AddSign(arg);
2228       return datatree.AddTimes(t11, darg);
2229     case UnaryOpcode::sign:
2230       return datatree.Zero;
2231     case UnaryOpcode::steadyState:
2232       if (datatree.isDynamic())
2233         {
2234           if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
2235             {
2236               auto varg = dynamic_cast<VariableNode *>(arg);
2237               if (!varg)
2238                 {
2239                   cerr << "UnaryOpNode::composeDerivatives: STEADY_STATE() should only be used on "
2240                        << "standalone variables (like STEADY_STATE(y)) to be derivable w.r.t. parameters" << endl;
2241                   exit(EXIT_FAILURE);
2242                 }
2243               if (datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous)
2244                 return datatree.AddSteadyStateParamDeriv(arg, datatree.getSymbIDByDerivID(deriv_id));
2245               else
2246                 return datatree.Zero;
2247             }
2248           else
2249             return datatree.Zero;
2250         }
2251       else
2252         return darg;
2253     case UnaryOpcode::steadyStateParamDeriv:
2254       assert(datatree.isDynamic());
2255       if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
2256         {
2257           auto varg = dynamic_cast<VariableNode *>(arg);
2258           assert(varg);
2259           assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2260           return datatree.AddSteadyStateParam2ndDeriv(arg, param1_symb_id, datatree.getSymbIDByDerivID(deriv_id));
2261         }
2262       else
2263         return datatree.Zero;
2264     case UnaryOpcode::steadyStateParam2ndDeriv:
2265       assert(datatree.isDynamic());
2266       if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
2267         {
2268           cerr << "3rd derivative of STEADY_STATE node w.r.t. three parameters not implemented" << endl;
2269           exit(EXIT_FAILURE);
2270         }
2271       else
2272         return datatree.Zero;
2273     case UnaryOpcode::expectation:
2274       cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::expectation" << endl;
2275       exit(EXIT_FAILURE);
2276     case UnaryOpcode::erf:
2277       // x^2
2278       t11 = datatree.AddPower(arg, datatree.Two);
2279       // exp(x^2)
2280       t12 = datatree.AddExp(t11);
2281       // sqrt(pi)
2282       t11 = datatree.AddSqrt(datatree.Pi);
2283       // sqrt(pi)*exp(x^2)
2284       t13 = datatree.AddTimes(t11, t12);
2285       // 2/(sqrt(pi)*exp(x^2));
2286       t14 = datatree.AddDivide(datatree.Two, t13);
2287       // (2/(sqrt(pi)*exp(x^2)))*dx;
2288       return datatree.AddTimes(t14, darg);
2289     case UnaryOpcode::diff:
2290       cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::diff" << endl;
2291       exit(EXIT_FAILURE);
2292     case UnaryOpcode::adl:
2293       cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::adl" << endl;
2294       exit(EXIT_FAILURE);
2295     }
2296   // Suppress GCC warning
2297   exit(EXIT_FAILURE);
2298 }
2299 
2300 expr_t
computeDerivative(int deriv_id)2301 UnaryOpNode::computeDerivative(int deriv_id)
2302 {
2303   expr_t darg = arg->getDerivative(deriv_id);
2304   return composeDerivatives(darg, deriv_id);
2305 }
2306 
2307 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const2308 UnaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
2309 {
2310   // For a temporary term, the cost is null
2311   for (const auto &it : temp_terms_map)
2312     if (it.second.find(const_cast<UnaryOpNode *>(this)) != it.second.end())
2313       return 0;
2314 
2315   return cost(arg->cost(temp_terms_map, is_matlab), is_matlab);
2316 }
2317 
2318 int
cost(const temporary_terms_t & temporary_terms,bool is_matlab) const2319 UnaryOpNode::cost(const temporary_terms_t &temporary_terms, bool is_matlab) const
2320 {
2321   // For a temporary term, the cost is null
2322   if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
2323     return 0;
2324 
2325   return cost(arg->cost(temporary_terms, is_matlab), is_matlab);
2326 }
2327 
2328 int
cost(int cost,bool is_matlab) const2329 UnaryOpNode::cost(int cost, bool is_matlab) const
2330 {
2331   if (is_matlab)
2332     // Cost for Matlab files
2333     switch (op_code)
2334       {
2335       case UnaryOpcode::uminus:
2336       case UnaryOpcode::sign:
2337         return cost + 70;
2338       case UnaryOpcode::exp:
2339         return cost + 160;
2340       case UnaryOpcode::log:
2341         return cost + 300;
2342       case UnaryOpcode::log10:
2343       case UnaryOpcode::erf:
2344         return cost + 16000;
2345       case UnaryOpcode::cos:
2346       case UnaryOpcode::sin:
2347       case UnaryOpcode::cosh:
2348         return cost + 210;
2349       case UnaryOpcode::tan:
2350         return cost + 230;
2351       case UnaryOpcode::acos:
2352         return cost + 300;
2353       case UnaryOpcode::asin:
2354         return cost + 310;
2355       case UnaryOpcode::atan:
2356         return cost + 140;
2357       case UnaryOpcode::sinh:
2358         return cost + 240;
2359       case UnaryOpcode::tanh:
2360         return cost + 190;
2361       case UnaryOpcode::acosh:
2362         return cost + 770;
2363       case UnaryOpcode::asinh:
2364         return cost + 460;
2365       case UnaryOpcode::atanh:
2366         return cost + 350;
2367       case UnaryOpcode::sqrt:
2368       case UnaryOpcode::cbrt:
2369       case UnaryOpcode::abs:
2370         return cost + 570;
2371       case UnaryOpcode::steadyState:
2372       case UnaryOpcode::steadyStateParamDeriv:
2373       case UnaryOpcode::steadyStateParam2ndDeriv:
2374       case UnaryOpcode::expectation:
2375         return cost;
2376       case UnaryOpcode::diff:
2377         cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::diff" << endl;
2378         exit(EXIT_FAILURE);
2379       case UnaryOpcode::adl:
2380         cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::adl" << endl;
2381         exit(EXIT_FAILURE);
2382       }
2383   else
2384     // Cost for C files
2385     switch (op_code)
2386       {
2387       case UnaryOpcode::uminus:
2388       case UnaryOpcode::sign:
2389         return cost + 3;
2390       case UnaryOpcode::exp:
2391       case UnaryOpcode::acosh:
2392         return cost + 210;
2393       case UnaryOpcode::log:
2394         return cost + 137;
2395       case UnaryOpcode::log10:
2396         return cost + 139;
2397       case UnaryOpcode::cos:
2398       case UnaryOpcode::sin:
2399         return cost + 160;
2400       case UnaryOpcode::tan:
2401         return cost + 170;
2402       case UnaryOpcode::acos:
2403       case UnaryOpcode::atan:
2404         return cost + 190;
2405       case UnaryOpcode::asin:
2406         return cost + 180;
2407       case UnaryOpcode::cosh:
2408       case UnaryOpcode::sinh:
2409       case UnaryOpcode::tanh:
2410       case UnaryOpcode::erf:
2411         return cost + 240;
2412       case UnaryOpcode::asinh:
2413         return cost + 220;
2414       case UnaryOpcode::atanh:
2415         return cost + 150;
2416       case UnaryOpcode::sqrt:
2417       case UnaryOpcode::cbrt:
2418       case UnaryOpcode::abs:
2419         return cost + 90;
2420       case UnaryOpcode::steadyState:
2421       case UnaryOpcode::steadyStateParamDeriv:
2422       case UnaryOpcode::steadyStateParam2ndDeriv:
2423       case UnaryOpcode::expectation:
2424         return cost;
2425       case UnaryOpcode::diff:
2426         cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::diff" << endl;
2427         exit(EXIT_FAILURE);
2428       case UnaryOpcode::adl:
2429         cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::adl" << endl;
2430         exit(EXIT_FAILURE);
2431       }
2432   exit(EXIT_FAILURE);
2433 }
2434 
2435 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const2436 UnaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
2437                                    map<pair<int, int>, temporary_terms_t> &temp_terms_map,
2438                                    map<expr_t, pair<int, pair<int, int>>> &reference_count,
2439                                    bool is_matlab) const
2440 {
2441   expr_t this2 = const_cast<UnaryOpNode *>(this);
2442   if (auto it = reference_count.find(this2);
2443       it == reference_count.end())
2444     {
2445       reference_count[this2] = { 1, derivOrder };
2446       arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
2447     }
2448   else
2449     {
2450       reference_count[this2] = { it->second.first + 1, it->second.second };
2451       if (reference_count[this2].first * cost(temp_terms_map, is_matlab) > min_cost(is_matlab))
2452         temp_terms_map[reference_count[this2].second].insert(this2);
2453     }
2454 }
2455 
2456 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const2457 UnaryOpNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
2458                                    temporary_terms_t &temporary_terms,
2459                                    map<expr_t, pair<int, int>> &first_occurence,
2460                                    int Curr_block,
2461                                    vector< vector<temporary_terms_t>> &v_temporary_terms,
2462                                    int equation) const
2463 {
2464   expr_t this2 = const_cast<UnaryOpNode *>(this);
2465   if (auto it = reference_count.find(this2);
2466       it == reference_count.end())
2467     {
2468       reference_count[this2] = 1;
2469       first_occurence[this2] = { Curr_block, equation };
2470       arg->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
2471     }
2472   else
2473     {
2474       reference_count[this2]++;
2475       if (reference_count[this2] * cost(temporary_terms, false) > min_cost_c)
2476         {
2477           temporary_terms.insert(this2);
2478           v_temporary_terms[first_occurence[this2].first][first_occurence[this2].second].insert(this2);
2479         }
2480     }
2481 }
2482 
2483 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const2484 UnaryOpNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
2485 {
2486   if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
2487     temporary_terms_inuse.insert(idx);
2488   else
2489     arg->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
2490 }
2491 
2492 bool
containsExternalFunction() const2493 UnaryOpNode::containsExternalFunction() const
2494 {
2495   return arg->containsExternalFunction();
2496 }
2497 
2498 void
writeJsonAST(ostream & output) const2499 UnaryOpNode::writeJsonAST(ostream &output) const
2500 {
2501   output << R"({"node_type" : "UnaryOpNode", "op" : ")";
2502   switch (op_code)
2503     {
2504     case UnaryOpcode::uminus:
2505       output << "uminus";
2506       break;
2507     case UnaryOpcode::exp:
2508       output << "exp";
2509       break;
2510     case UnaryOpcode::log:
2511       output << "log";
2512       break;
2513     case UnaryOpcode::log10:
2514       output << "log10";
2515       break;
2516     case UnaryOpcode::cos:
2517       output << "cos";
2518       break;
2519     case UnaryOpcode::sin:
2520       output << "sin";
2521       break;
2522     case UnaryOpcode::tan:
2523       output << "tan";
2524       break;
2525     case UnaryOpcode::acos:
2526       output << "acos";
2527       break;
2528     case UnaryOpcode::asin:
2529       output << "asin";
2530       break;
2531     case UnaryOpcode::atan:
2532       output << "atan";
2533       break;
2534     case UnaryOpcode::cosh:
2535       output << "cosh";
2536       break;
2537     case UnaryOpcode::sinh:
2538       output << "sinh";
2539       break;
2540     case UnaryOpcode::tanh:
2541       output << "tanh";
2542       break;
2543     case UnaryOpcode::acosh:
2544       output << "acosh";
2545       break;
2546     case UnaryOpcode::asinh:
2547       output << "asinh";
2548       break;
2549     case UnaryOpcode::atanh:
2550       output << "atanh";
2551       break;
2552     case UnaryOpcode::sqrt:
2553       output << "sqrt";
2554       break;
2555     case UnaryOpcode::cbrt:
2556       output << "cbrt";
2557       break;
2558     case UnaryOpcode::abs:
2559       output << "abs";
2560       break;
2561     case UnaryOpcode::sign:
2562       output << "sign";
2563       break;
2564     case UnaryOpcode::diff:
2565       output << "diff";
2566       break;
2567     case UnaryOpcode::adl:
2568       output << "adl";
2569       break;
2570     case UnaryOpcode::steadyState:
2571       output << "steady_state";
2572     case UnaryOpcode::steadyStateParamDeriv:
2573       output << "steady_state_param_deriv";
2574       break;
2575     case UnaryOpcode::steadyStateParam2ndDeriv:
2576       output << "steady_state_param_second_deriv";
2577       break;
2578     case UnaryOpcode::expectation:
2579       output << "expectation";
2580       break;
2581     case UnaryOpcode::erf:
2582       output << "erf";
2583       break;
2584     }
2585   output << R"(", "arg" : )";
2586   arg->writeJsonAST(output);
2587   switch (op_code)
2588     {
2589     case UnaryOpcode::adl:
2590       output << R"(, "adl_param_name" : ")" << adl_param_name << R"(")"
2591              << R"(, "lags" : [)";
2592       for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
2593         {
2594           if (it != adl_lags.begin())
2595             output << ", ";
2596           output << *it;
2597         }
2598       output << "]";
2599       break;
2600     default:
2601       break;
2602     }
2603   output << "}";
2604 }
2605 
2606 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const2607 UnaryOpNode::writeJsonOutput(ostream &output,
2608                              const temporary_terms_t &temporary_terms,
2609                              const deriv_node_temp_terms_t &tef_terms,
2610                              bool isdynamic) const
2611 {
2612   if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
2613     {
2614       output << "T" << idx;
2615       return;
2616     }
2617 
2618   // Always put parenthesis around uminus nodes
2619   if (op_code == UnaryOpcode::uminus)
2620     output << "(";
2621 
2622   switch (op_code)
2623     {
2624     case UnaryOpcode::uminus:
2625       output << "-";
2626       break;
2627     case UnaryOpcode::exp:
2628       output << "exp";
2629       break;
2630     case UnaryOpcode::log:
2631       output << "log";
2632       break;
2633     case UnaryOpcode::log10:
2634       output << "log10";
2635       break;
2636     case UnaryOpcode::cos:
2637       output << "cos";
2638       break;
2639     case UnaryOpcode::sin:
2640       output << "sin";
2641       break;
2642     case UnaryOpcode::tan:
2643       output << "tan";
2644       break;
2645     case UnaryOpcode::acos:
2646       output << "acos";
2647       break;
2648     case UnaryOpcode::asin:
2649       output << "asin";
2650       break;
2651     case UnaryOpcode::atan:
2652       output << "atan";
2653       break;
2654     case UnaryOpcode::cosh:
2655       output << "cosh";
2656       break;
2657     case UnaryOpcode::sinh:
2658       output << "sinh";
2659       break;
2660     case UnaryOpcode::tanh:
2661       output << "tanh";
2662       break;
2663     case UnaryOpcode::acosh:
2664       output << "acosh";
2665       break;
2666     case UnaryOpcode::asinh:
2667       output << "asinh";
2668       break;
2669     case UnaryOpcode::atanh:
2670       output << "atanh";
2671       break;
2672     case UnaryOpcode::sqrt:
2673       output << "sqrt";
2674       break;
2675     case UnaryOpcode::cbrt:
2676       output << "cbrt";
2677       break;
2678     case UnaryOpcode::abs:
2679       output << "abs";
2680       break;
2681     case UnaryOpcode::sign:
2682       output << "sign";
2683       break;
2684     case UnaryOpcode::diff:
2685       output << "diff";
2686       break;
2687     case UnaryOpcode::adl:
2688       output << "adl(";
2689       arg->writeJsonOutput(output, temporary_terms, tef_terms);
2690       output << ", '" << adl_param_name << "', [";
2691       for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
2692         {
2693           if (it != adl_lags.begin())
2694             output << ", ";
2695           output << *it;
2696         }
2697       output << "])";
2698       return;
2699     case UnaryOpcode::steadyState:
2700       output << "(";
2701       arg->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
2702       output << ")";
2703       return;
2704     case UnaryOpcode::steadyStateParamDeriv:
2705       {
2706         auto varg = dynamic_cast<VariableNode *>(arg);
2707         assert(varg);
2708         assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2709         assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
2710         int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
2711         int tsid_param = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
2712         output << "ss_param_deriv(" << tsid_endo+1 << "," << tsid_param+1 << ")";
2713       }
2714       return;
2715     case UnaryOpcode::steadyStateParam2ndDeriv:
2716       {
2717         auto varg = dynamic_cast<VariableNode *>(arg);
2718         assert(varg);
2719         assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2720         assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
2721         assert(datatree.symbol_table.getType(param2_symb_id) == SymbolType::parameter);
2722         int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
2723         int tsid_param1 = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
2724         int tsid_param2 = datatree.symbol_table.getTypeSpecificID(param2_symb_id);
2725         output << "ss_param_2nd_deriv(" << tsid_endo+1 << "," << tsid_param1+1
2726                << "," << tsid_param2+1 << ")";
2727       }
2728       return;
2729     case UnaryOpcode::expectation:
2730       output << "EXPECTATION(" << expectation_information_set << ")";
2731       break;
2732     case UnaryOpcode::erf:
2733       output << "erf";
2734       break;
2735     }
2736 
2737   bool close_parenthesis = false;
2738 
2739   /* Enclose argument with parentheses if:
2740      - current opcode is not uminus, or
2741      - current opcode is uminus and argument has lowest precedence
2742   */
2743   if (op_code != UnaryOpcode::uminus
2744       || (op_code == UnaryOpcode::uminus
2745           && arg->precedenceJson(temporary_terms) < precedenceJson(temporary_terms)))
2746     {
2747       output << "(";
2748       close_parenthesis = true;
2749     }
2750 
2751   // Write argument
2752   arg->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
2753 
2754   if (close_parenthesis)
2755     output << ")";
2756 
2757   // Close parenthesis for uminus
2758   if (op_code == UnaryOpcode::uminus)
2759     output << ")";
2760 }
2761 
/* Writes this unary node as an expression in the language selected by
   output_type (Matlab, C, Julia, LaTeX, ...).
   - If the node is a temporary term, checkIfTemporaryTermThenWrite() emits its
     reference and nothing else is written.
   - uminus is always wrapped in parentheses.
   - Some operator/target combinations use a custom form and return early:
     LaTeX \sqrt{...}, Matlab nthroot(x, 3) and LaTeX \sqrt[3]{...} for cbrt,
     steady_state (argument rewritten with a *SteadyStateOperator output type),
     and the ss_param_deriv / ss_param_2nd_deriv references.
   - Otherwise the function name is written followed by the parenthesized argument.
     For sign in C models the call becomes copysign(1.0, arg). */
void
UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                         const temporary_terms_t &temporary_terms,
                         const temporary_terms_idxs_t &temporary_terms_idxs,
                         const deriv_node_temp_terms_t &tef_terms) const
{
  if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
    return;

  // Always put parenthesis around uminus nodes
  if (op_code == UnaryOpcode::uminus)
    output << LEFT_PAR(output_type);

  switch (op_code)
    {
    case UnaryOpcode::uminus:
      output << "-";
      break;
    case UnaryOpcode::exp:
      if (isLatexOutput(output_type))
        output << R"(\exp)";
      else
        output << "exp";
      break;
    case UnaryOpcode::log:
      if (isLatexOutput(output_type))
        output << R"(\log)";
      else
        output << "log";
      break;
    case UnaryOpcode::log10:
      if (isLatexOutput(output_type))
        output << R"(\log_{10})";
      else
        output << "log10";
      break;
    case UnaryOpcode::cos:
      if (isLatexOutput(output_type))
        output << R"(\cos)";
      else
        output << "cos";
      break;
    case UnaryOpcode::sin:
      if (isLatexOutput(output_type))
        output << R"(\sin)";
      else
        output << "sin";
      break;
    case UnaryOpcode::tan:
      if (isLatexOutput(output_type))
        output << R"(\tan)";
      else
        output << "tan";
      break;
    case UnaryOpcode::acos:
      output << "acos";
      break;
    case UnaryOpcode::asin:
      output << "asin";
      break;
    case UnaryOpcode::atan:
      output << "atan";
      break;
    case UnaryOpcode::cosh:
      output << "cosh";
      break;
    case UnaryOpcode::sinh:
      output << "sinh";
      break;
    case UnaryOpcode::tanh:
      output << "tanh";
      break;
    case UnaryOpcode::acosh:
      output << "acosh";
      break;
    case UnaryOpcode::asinh:
      output << "asinh";
      break;
    case UnaryOpcode::atanh:
      output << "atanh";
      break;
    case UnaryOpcode::sqrt:
      // LaTeX uses braces instead of parentheses, so write the whole thing here
      if (isLatexOutput(output_type))
        {
          output << R"(\sqrt{)";
          arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << "}";
          return;
        }
      output << "sqrt";
      break;
    case UnaryOpcode::cbrt:
      // Matlab output uses nthroot(x, 3) instead of a cbrt call
      if (isMatlabOutput(output_type))
        {
          output << "nthroot(";
          arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << ", 3)";
          return;
        }
      else if (isLatexOutput(output_type))
        {
          output << R"(\sqrt[3]{)";
          arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << "}";
          return;
        }
      else
        output << "cbrt";
      break;
    case UnaryOpcode::abs:
      output << "abs";
      break;
    case UnaryOpcode::sign:
      // C models: copysign(1.0, x); the "1.0," is written after the opening parenthesis below
      if (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel)
        output << "copysign";
      else
        output << "sign";
      break;
    case UnaryOpcode::steadyState:
      // The argument is written with the steady-state variant of the dynamic
      // output type (other output types are passed through unchanged)
      ExprNodeOutputType new_output_type;
      switch (output_type)
        {
        case ExprNodeOutputType::matlabDynamicModel:
          new_output_type = ExprNodeOutputType::matlabDynamicSteadyStateOperator;
          break;
        case ExprNodeOutputType::latexDynamicModel:
          new_output_type = ExprNodeOutputType::latexDynamicSteadyStateOperator;
          break;
        case ExprNodeOutputType::CDynamicModel:
          new_output_type = ExprNodeOutputType::CDynamicSteadyStateOperator;
          break;
        case ExprNodeOutputType::juliaDynamicModel:
          new_output_type = ExprNodeOutputType::juliaDynamicSteadyStateOperator;
          break;
        case ExprNodeOutputType::matlabDynamicModelSparse:
          new_output_type = ExprNodeOutputType::matlabDynamicSparseSteadyStateOperator;
          break;
        default:
          new_output_type = output_type;
          break;
        }
      output << "(";
      arg->writeOutput(output, new_output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      output << ")";
      return;
    case UnaryOpcode::steadyStateParamDeriv:
      {
        // Only supported for Matlab output; indices are written 1-based
        auto varg = dynamic_cast<VariableNode *>(arg);
        assert(varg);
        assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
        assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
        int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
        int tsid_param = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
        assert(isMatlabOutput(output_type));
        output << "ss_param_deriv(" << tsid_endo+1 << "," << tsid_param+1 << ")";
      }
      return;
    case UnaryOpcode::steadyStateParam2ndDeriv:
      {
        // Only supported for Matlab output; indices are written 1-based
        auto varg = dynamic_cast<VariableNode *>(arg);
        assert(varg);
        assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
        assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
        assert(datatree.symbol_table.getType(param2_symb_id) == SymbolType::parameter);
        int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
        int tsid_param1 = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
        int tsid_param2 = datatree.symbol_table.getTypeSpecificID(param2_symb_id);
        assert(isMatlabOutput(output_type));
        output << "ss_param_2nd_deriv(" << tsid_endo+1 << "," << tsid_param1+1
               << "," << tsid_param2+1 << ")";
      }
      return;
    case UnaryOpcode::expectation:
      // Only LaTeX can render an expectation operator: \mathbb{E}_{t[+/-k]}
      if (!isLatexOutput(output_type))
        {
          cerr << "UnaryOpNode::writeOutput: not implemented on UnaryOpcode::expectation" << endl;
          exit(EXIT_FAILURE);
        }
      output << R"(\mathbb{E}_{t)";
      if (expectation_information_set != 0)
        {
          if (expectation_information_set > 0)
            output << "+";
          output << expectation_information_set;
        }
      output << "}";
      break;
    case UnaryOpcode::erf:
      output << "erf";
      break;
    case UnaryOpcode::diff:
      output << "diff";
      break;
    case UnaryOpcode::adl:
      output << "adl";
      break;
    }

  bool close_parenthesis = false;

  /* Enclose argument with parentheses if:
     - current opcode is not uminus, or
     - current opcode is uminus and argument has lowest precedence
  */
  if (op_code != UnaryOpcode::uminus
      || (op_code == UnaryOpcode::uminus
          && arg->precedence(output_type, temporary_terms) < precedence(output_type, temporary_terms)))
    {
      output << LEFT_PAR(output_type);
      // First argument of copysign(1.0, x) for sign in C models
      if (op_code == UnaryOpcode::sign && (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel))
        output << "1.0,";
      close_parenthesis = true;
    }

  // Write argument
  arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);

  if (close_parenthesis)
    output << RIGHT_PAR(output_type);

  // Close parenthesis for uminus
  if (op_code == UnaryOpcode::uminus)
    output << RIGHT_PAR(output_type);
}
2986 
2987 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const2988 UnaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
2989                                          const temporary_terms_t &temporary_terms,
2990                                          const temporary_terms_idxs_t &temporary_terms_idxs,
2991                                          deriv_node_temp_terms_t &tef_terms) const
2992 {
2993   arg->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
2994 }
2995 
2996 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const2997 UnaryOpNode::writeJsonExternalFunctionOutput(vector<string> &efout,
2998                                              const temporary_terms_t &temporary_terms,
2999                                              deriv_node_temp_terms_t &tef_terms,
3000                                              bool isdynamic) const
3001 {
3002   arg->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
3003 }
3004 
3005 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const3006 UnaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
3007                                            bool lhs_rhs, const temporary_terms_t &temporary_terms,
3008                                            const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
3009                                            deriv_node_temp_terms_t &tef_terms) const
3010 {
3011   arg->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
3012                                      dynamic, steady_dynamic, tef_terms);
3013 }
3014 
3015 double
eval_opcode(UnaryOpcode op_code,double v)3016 UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) noexcept(false)
3017 {
3018   switch (op_code)
3019     {
3020     case UnaryOpcode::uminus:
3021       return -v;
3022     case UnaryOpcode::exp:
3023       return exp(v);
3024     case UnaryOpcode::log:
3025       return log(v);
3026     case UnaryOpcode::log10:
3027       return log10(v);
3028     case UnaryOpcode::cos:
3029       return cos(v);
3030     case UnaryOpcode::sin:
3031       return sin(v);
3032     case UnaryOpcode::tan:
3033       return tan(v);
3034     case UnaryOpcode::acos:
3035       return acos(v);
3036     case UnaryOpcode::asin:
3037       return asin(v);
3038     case UnaryOpcode::atan:
3039       return atan(v);
3040     case UnaryOpcode::cosh:
3041       return cosh(v);
3042     case UnaryOpcode::sinh:
3043       return sinh(v);
3044     case UnaryOpcode::tanh:
3045       return tanh(v);
3046     case UnaryOpcode::acosh:
3047       return acosh(v);
3048     case UnaryOpcode::asinh:
3049       return asinh(v);
3050     case UnaryOpcode::atanh:
3051       return atanh(v);
3052     case UnaryOpcode::sqrt:
3053       return sqrt(v);
3054     case UnaryOpcode::cbrt:
3055       return cbrt(v);
3056     case UnaryOpcode::abs:
3057       return abs(v);
3058     case UnaryOpcode::sign:
3059       return (v > 0) ? 1 : ((v < 0) ? -1 : 0);
3060     case UnaryOpcode::steadyState:
3061       return v;
3062     case UnaryOpcode::steadyStateParamDeriv:
3063       cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::steadyStateParamDeriv" << endl;
3064       exit(EXIT_FAILURE);
3065     case UnaryOpcode::steadyStateParam2ndDeriv:
3066       cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::steadyStateParam2ndDeriv" << endl;
3067       exit(EXIT_FAILURE);
3068     case UnaryOpcode::expectation:
3069       cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::expectation" << endl;
3070       exit(EXIT_FAILURE);
3071     case UnaryOpcode::erf:
3072       return erf(v);
3073     case UnaryOpcode::diff:
3074       cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::diff" << endl;
3075       exit(EXIT_FAILURE);
3076     case UnaryOpcode::adl:
3077       cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::adl" << endl;
3078       exit(EXIT_FAILURE);
3079     }
3080   // Suppress GCC warning
3081   exit(EXIT_FAILURE);
3082 }
3083 
3084 double
eval(const eval_context_t & eval_context) const3085 UnaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false)
3086 {
3087   double v = arg->eval(eval_context);
3088   return eval_opcode(op_code, v);
3089 }
3090 
3091 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const3092 UnaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
3093                      bool lhs_rhs, const temporary_terms_t &temporary_terms,
3094                      const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
3095                      const deriv_node_temp_terms_t &tef_terms) const
3096 {
3097   if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
3098     {
3099       if (dynamic)
3100         {
3101           auto ii = map_idx.find(idx);
3102           FLDT_ fldt(ii->second);
3103           fldt.write(CompileCode, instruction_number);
3104         }
3105       else
3106         {
3107           auto ii = map_idx.find(idx);
3108           FLDST_ fldst(ii->second);
3109           fldst.write(CompileCode, instruction_number);
3110         }
3111       return;
3112     }
3113   if (op_code == UnaryOpcode::steadyState)
3114     arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, true, tef_terms);
3115   else
3116     {
3117       arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
3118       FUNARY_ funary{static_cast<uint8_t>(op_code)};
3119       funary.write(CompileCode, instruction_number);
3120     }
3121 }
3122 
3123 void
collectVARLHSVariable(set<expr_t> & result) const3124 UnaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
3125 {
3126   if (op_code == UnaryOpcode::diff)
3127     result.insert(const_cast<UnaryOpNode *>(this));
3128   else
3129     arg->collectVARLHSVariable(result);
3130 }
3131 
3132 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const3133 UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
3134 {
3135   arg->collectDynamicVariables(type_arg, result);
3136 }
3137 
3138 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const3139 UnaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
3140 {
3141   pair<int, expr_t> res = arg->normalizeEquation(var_endo, List_of_Op_RHS);
3142   int is_endogenous_present = res.first;
3143   expr_t New_expr_t = res.second;
3144 
3145   if (is_endogenous_present == 2) /* The equation could not be normalized and the process is given-up*/
3146     return { 2, nullptr };
3147   else if (is_endogenous_present) /* The argument of the function contains the current values of
3148                                      the endogenous variable associated to the equation.
3149                                      In order to normalized, we have to apply the invert function to the RHS.*/
3150     {
3151       switch (op_code)
3152         {
3153         case UnaryOpcode::uminus:
3154           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::uminus), nullptr, nullptr);
3155           return { 1, nullptr };
3156         case UnaryOpcode::exp:
3157           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::log), nullptr, nullptr);
3158           return { 1, nullptr };
3159         case UnaryOpcode::log:
3160           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::exp), nullptr, nullptr);
3161           return { 1, nullptr };
3162         case UnaryOpcode::log10:
3163           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.AddNonNegativeConstant("10"));
3164           return { 1, nullptr };
3165         case UnaryOpcode::cos:
3166           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::acos), nullptr, nullptr);
3167           return { 1, nullptr };
3168         case UnaryOpcode::sin:
3169           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::asin), nullptr, nullptr);
3170           return { 1, nullptr };
3171         case UnaryOpcode::tan:
3172           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::atan), nullptr, nullptr);
3173           return { 1, nullptr };
3174         case UnaryOpcode::acos:
3175           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::cos), nullptr, nullptr);
3176           return { 1, nullptr };
3177         case UnaryOpcode::asin:
3178           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::sin), nullptr, nullptr);
3179           return { 1, nullptr };
3180         case UnaryOpcode::atan:
3181           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::tan), nullptr, nullptr);
3182           return { 1, nullptr };
3183         case UnaryOpcode::cosh:
3184           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::acosh), nullptr, nullptr);
3185           return { 1, nullptr };
3186         case UnaryOpcode::sinh:
3187           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::asinh), nullptr, nullptr);
3188           return { 1, nullptr };
3189         case UnaryOpcode::tanh:
3190           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::atanh), nullptr, nullptr);
3191           return { 1, nullptr };
3192         case UnaryOpcode::acosh:
3193           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::cosh), nullptr, nullptr);
3194           return { 1, nullptr };
3195         case UnaryOpcode::asinh:
3196           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::sinh), nullptr, nullptr);
3197           return { 1, nullptr };
3198         case UnaryOpcode::atanh:
3199           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::tanh), nullptr, nullptr);
3200           return { 1, nullptr };
3201         case UnaryOpcode::sqrt:
3202           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Two);
3203           return { 1, nullptr };
3204         case UnaryOpcode::cbrt:
3205           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Three);
3206           return { 1, nullptr };
3207         case UnaryOpcode::abs:
3208           return { 2, nullptr };
3209         case UnaryOpcode::sign:
3210           return { 2, nullptr };
3211         case UnaryOpcode::steadyState:
3212           return { 2, nullptr };
3213         case UnaryOpcode::erf:
3214           return { 2, nullptr };
3215         default:
3216           cerr << "Unary operator not handled during the normalization process" << endl;
3217           return { 2, nullptr }; // Could not be normalized
3218         }
3219     }
3220   else
3221     { /* If the argument of the function do not contain the current values of the endogenous variable
3222          related to the equation, the function with its argument is stored in the RHS*/
3223       switch (op_code)
3224         {
3225         case UnaryOpcode::uminus:
3226           return { 0, datatree.AddUMinus(New_expr_t) };
3227         case UnaryOpcode::exp:
3228           return { 0, datatree.AddExp(New_expr_t) };
3229         case UnaryOpcode::log:
3230           return { 0, datatree.AddLog(New_expr_t) };
3231         case UnaryOpcode::log10:
3232           return { 0, datatree.AddLog10(New_expr_t) };
3233         case UnaryOpcode::cos:
3234           return { 0, datatree.AddCos(New_expr_t) };
3235         case UnaryOpcode::sin:
3236           return { 0, datatree.AddSin(New_expr_t) };
3237         case UnaryOpcode::tan:
3238           return { 0, datatree.AddTan(New_expr_t) };
3239         case UnaryOpcode::acos:
3240           return { 0, datatree.AddAcos(New_expr_t) };
3241         case UnaryOpcode::asin:
3242           return { 0, datatree.AddAsin(New_expr_t) };
3243         case UnaryOpcode::atan:
3244           return { 0, datatree.AddAtan(New_expr_t) };
3245         case UnaryOpcode::cosh:
3246           return { 0, datatree.AddCosh(New_expr_t) };
3247         case UnaryOpcode::sinh:
3248           return { 0, datatree.AddSinh(New_expr_t) };
3249         case UnaryOpcode::tanh:
3250           return { 0, datatree.AddTanh(New_expr_t) };
3251         case UnaryOpcode::acosh:
3252           return { 0, datatree.AddAcosh(New_expr_t) };
3253         case UnaryOpcode::asinh:
3254           return { 0, datatree.AddAsinh(New_expr_t) };
3255         case UnaryOpcode::atanh:
3256           return { 0, datatree.AddAtanh(New_expr_t) };
3257         case UnaryOpcode::sqrt:
3258           return { 0, datatree.AddSqrt(New_expr_t) };
3259         case UnaryOpcode::cbrt:
3260           return { 0, datatree.AddCbrt(New_expr_t) };
3261         case UnaryOpcode::abs:
3262           return { 0, datatree.AddAbs(New_expr_t) };
3263         case UnaryOpcode::sign:
3264           return { 0, datatree.AddSign(New_expr_t) };
3265         case UnaryOpcode::steadyState:
3266           return { 0, datatree.AddSteadyState(New_expr_t) };
3267         case UnaryOpcode::erf:
3268           return { 0, datatree.AddErf(New_expr_t) };
3269         default:
3270           cerr << "Unary operator not handled during the normalization process" << endl;
3271           return { 2, nullptr }; // Could not be normalized
3272         }
3273     }
3274   cerr << "UnaryOpNode::normalizeEquation: impossible case" << endl;
3275   exit(EXIT_FAILURE);
3276 }
3277 
3278 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)3279 UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
3280 {
3281   expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
3282   return composeDerivatives(darg, deriv_id);
3283 }
3284 
3285 expr_t
buildSimilarUnaryOpNode(expr_t alt_arg,DataTree & alt_datatree) const3286 UnaryOpNode::buildSimilarUnaryOpNode(expr_t alt_arg, DataTree &alt_datatree) const
3287 {
3288   switch (op_code)
3289     {
3290     case UnaryOpcode::uminus:
3291       return alt_datatree.AddUMinus(alt_arg);
3292     case UnaryOpcode::exp:
3293       return alt_datatree.AddExp(alt_arg);
3294     case UnaryOpcode::log:
3295       return alt_datatree.AddLog(alt_arg);
3296     case UnaryOpcode::log10:
3297       return alt_datatree.AddLog10(alt_arg);
3298     case UnaryOpcode::cos:
3299       return alt_datatree.AddCos(alt_arg);
3300     case UnaryOpcode::sin:
3301       return alt_datatree.AddSin(alt_arg);
3302     case UnaryOpcode::tan:
3303       return alt_datatree.AddTan(alt_arg);
3304     case UnaryOpcode::acos:
3305       return alt_datatree.AddAcos(alt_arg);
3306     case UnaryOpcode::asin:
3307       return alt_datatree.AddAsin(alt_arg);
3308     case UnaryOpcode::atan:
3309       return alt_datatree.AddAtan(alt_arg);
3310     case UnaryOpcode::cosh:
3311       return alt_datatree.AddCosh(alt_arg);
3312     case UnaryOpcode::sinh:
3313       return alt_datatree.AddSinh(alt_arg);
3314     case UnaryOpcode::tanh:
3315       return alt_datatree.AddTanh(alt_arg);
3316     case UnaryOpcode::acosh:
3317       return alt_datatree.AddAcosh(alt_arg);
3318     case UnaryOpcode::asinh:
3319       return alt_datatree.AddAsinh(alt_arg);
3320     case UnaryOpcode::atanh:
3321       return alt_datatree.AddAtanh(alt_arg);
3322     case UnaryOpcode::sqrt:
3323       return alt_datatree.AddSqrt(alt_arg);
3324     case UnaryOpcode::cbrt:
3325       return alt_datatree.AddCbrt(alt_arg);
3326     case UnaryOpcode::abs:
3327       return alt_datatree.AddAbs(alt_arg);
3328     case UnaryOpcode::sign:
3329       return alt_datatree.AddSign(alt_arg);
3330     case UnaryOpcode::steadyState:
3331       return alt_datatree.AddSteadyState(alt_arg);
3332     case UnaryOpcode::steadyStateParamDeriv:
3333       cerr << "UnaryOpNode::buildSimilarUnaryOpNode: UnaryOpcode::steadyStateParamDeriv can't be translated" << endl;
3334       exit(EXIT_FAILURE);
3335     case UnaryOpcode::steadyStateParam2ndDeriv:
3336       cerr << "UnaryOpNode::buildSimilarUnaryOpNode: UnaryOpcode::steadyStateParam2ndDeriv can't be translated" << endl;
3337       exit(EXIT_FAILURE);
3338     case UnaryOpcode::expectation:
3339       return alt_datatree.AddExpectation(expectation_information_set, alt_arg);
3340     case UnaryOpcode::erf:
3341       return alt_datatree.AddErf(alt_arg);
3342     case UnaryOpcode::diff:
3343       return alt_datatree.AddDiff(alt_arg);
3344     case UnaryOpcode::adl:
3345       return alt_datatree.AddAdl(alt_arg, adl_param_name, adl_lags);
3346     }
3347   // Suppress GCC warning
3348   exit(EXIT_FAILURE);
3349 }
3350 
3351 expr_t
toStatic(DataTree & static_datatree) const3352 UnaryOpNode::toStatic(DataTree &static_datatree) const
3353 {
3354   expr_t sarg = arg->toStatic(static_datatree);
3355   return buildSimilarUnaryOpNode(sarg, static_datatree);
3356 }
3357 
3358 void
computeXrefs(EquationInfo & ei) const3359 UnaryOpNode::computeXrefs(EquationInfo &ei) const
3360 {
3361   arg->computeXrefs(ei);
3362 }
3363 
3364 expr_t
clone(DataTree & datatree) const3365 UnaryOpNode::clone(DataTree &datatree) const
3366 {
3367   expr_t substarg = arg->clone(datatree);
3368   return buildSimilarUnaryOpNode(substarg, datatree);
3369 }
3370 
3371 int
maxEndoLead() const3372 UnaryOpNode::maxEndoLead() const
3373 {
3374   return arg->maxEndoLead();
3375 }
3376 
3377 int
maxExoLead() const3378 UnaryOpNode::maxExoLead() const
3379 {
3380   return arg->maxExoLead();
3381 }
3382 
3383 int
maxEndoLag() const3384 UnaryOpNode::maxEndoLag() const
3385 {
3386   return arg->maxEndoLag();
3387 }
3388 
3389 int
maxExoLag() const3390 UnaryOpNode::maxExoLag() const
3391 {
3392   return arg->maxExoLag();
3393 }
3394 
3395 int
maxLead() const3396 UnaryOpNode::maxLead() const
3397 {
3398   return arg->maxLead();
3399 }
3400 
3401 int
maxLag() const3402 UnaryOpNode::maxLag() const
3403 {
3404   return arg->maxLag();
3405 }
3406 
3407 int
maxLagWithDiffsExpanded() const3408 UnaryOpNode::maxLagWithDiffsExpanded() const
3409 {
3410   if (op_code == UnaryOpcode::diff)
3411     return arg->maxLagWithDiffsExpanded() + 1;
3412   return arg->maxLagWithDiffsExpanded();
3413 }
3414 
3415 expr_t
undiff() const3416 UnaryOpNode::undiff() const
3417 {
3418   if (op_code == UnaryOpcode::diff)
3419     return arg;
3420   return arg->undiff();
3421 }
3422 
3423 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const3424 UnaryOpNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
3425 {
3426   auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3427   if (lhs_lag_equiv.find(lag_equiv_repr) == lhs_lag_equiv.end())
3428     return 0;
3429   return arg->maxLag();
3430 }
3431 
3432 int
VarMinLag() const3433 UnaryOpNode::VarMinLag() const
3434 {
3435   return arg->VarMinLag();
3436 }
3437 
3438 int
PacMaxLag(int lhs_symb_id) const3439 UnaryOpNode::PacMaxLag(int lhs_symb_id) const
3440 {
3441   //This will never be an UnaryOpcode::diff node
3442   return arg->PacMaxLag(lhs_symb_id);
3443 }
3444 
3445 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const3446 UnaryOpNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
3447 {
3448   return arg->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
3449 }
3450 
3451 expr_t
substituteAdl() const3452 UnaryOpNode::substituteAdl() const
3453 {
3454   if (op_code != UnaryOpcode::adl)
3455     {
3456       expr_t argsubst = arg->substituteAdl();
3457       return buildSimilarUnaryOpNode(argsubst, datatree);
3458     }
3459 
3460   expr_t arg1subst = arg->substituteAdl();
3461   expr_t retval = nullptr;
3462   ostringstream inttostr;
3463 
3464   for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
3465     if (it == adl_lags.begin())
3466       {
3467         inttostr << *it;
3468         retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
3469                                    arg1subst->decreaseLeadsLags(*it));
3470       }
3471     else
3472       {
3473         inttostr.clear();
3474         inttostr.str("");
3475         inttostr << *it;
3476         retval = datatree.AddPlus(retval,
3477                                   datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_"
3478                                                                                                      + inttostr.str()), 0),
3479                                                     arg1subst->decreaseLeadsLags(*it)));
3480       }
3481   return retval;
3482 }
3483 
3484 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const3485 UnaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
3486 {
3487   expr_t argsubst = arg->substituteVarExpectation(subst_table);
3488   return buildSimilarUnaryOpNode(argsubst, datatree);
3489 }
3490 
3491 int
countDiffs() const3492 UnaryOpNode::countDiffs() const
3493 {
3494   if (op_code == UnaryOpcode::diff)
3495     return arg->countDiffs() + 1;
3496   return arg->countDiffs();
3497 }
3498 
3499 bool
createAuxVarForUnaryOpNode() const3500 UnaryOpNode::createAuxVarForUnaryOpNode() const
3501 {
3502   switch (op_code)
3503     {
3504     case UnaryOpcode::exp:
3505     case UnaryOpcode::log:
3506     case UnaryOpcode::log10:
3507     case UnaryOpcode::cos:
3508     case UnaryOpcode::sin:
3509     case UnaryOpcode::tan:
3510     case UnaryOpcode::acos:
3511     case UnaryOpcode::asin:
3512     case UnaryOpcode::atan:
3513     case UnaryOpcode::cosh:
3514     case UnaryOpcode::sinh:
3515     case UnaryOpcode::tanh:
3516     case UnaryOpcode::acosh:
3517     case UnaryOpcode::asinh:
3518     case UnaryOpcode::atanh:
3519     case UnaryOpcode::sqrt:
3520     case UnaryOpcode::cbrt:
3521     case UnaryOpcode::abs:
3522     case UnaryOpcode::sign:
3523     case UnaryOpcode::erf:
3524       return true;
3525     default:
3526       return false;
3527     }
3528 }
3529 
3530 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const3531 UnaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
3532 {
3533   arg->findUnaryOpNodesForAuxVarCreation(nodes);
3534 
3535   if (!this->createAuxVarForUnaryOpNode())
3536     return;
3537 
3538   auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3539   nodes[lag_equiv_repr][index] = const_cast<UnaryOpNode *>(this);
3540 }
3541 
3542 void
findDiffNodes(lag_equivalence_table_t & nodes) const3543 UnaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
3544 {
3545   arg->findDiffNodes(nodes);
3546 
3547   if (op_code != UnaryOpcode::diff)
3548     return;
3549 
3550   auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3551   nodes[lag_equiv_repr][index] = const_cast<UnaryOpNode *>(this);
3552 }
3553 
3554 int
findTargetVariable(int lhs_symb_id) const3555 UnaryOpNode::findTargetVariable(int lhs_symb_id) const
3556 {
3557   return arg->findTargetVariable(lhs_symb_id);
3558 }
3559 
3560 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3561 UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
3562                             vector<BinaryOpNode *> &neweqs) const
3563 {
3564   // If this is not a diff node, then substitute recursively and return
3565   expr_t argsubst = arg->substituteDiff(nodes, subst_table, neweqs);
3566   if (op_code != UnaryOpcode::diff)
3567     return buildSimilarUnaryOpNode(argsubst, datatree);
3568 
3569   if (auto sit = subst_table.find(this);
3570       sit != subst_table.end())
3571     return const_cast<VariableNode *>(sit->second);
3572 
3573   auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3574   auto it = nodes.find(lag_equiv_repr);
3575   if (it == nodes.end() || it->second.find(index) == it->second.end()
3576       || it->second.at(index) != this)
3577     {
3578       /* diff does not appear in VAR equations, so simply create aux var and return.
3579          Once the comparison of expression nodes works, come back and remove
3580          this part, folding into the next loop. */
3581       int symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, const_cast<UnaryOpNode *>(this));
3582       VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
3583       neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
3584                                                                       datatree.AddMinus(argsubst,
3585                                                                                         argsubst->decreaseLeadsLags(1)))));
3586       subst_table[this] = dynamic_cast<VariableNode *>(aux_var);
3587       return const_cast<VariableNode *>(subst_table[this]);
3588     }
3589 
3590   /* At this point, we know that this node (and its lagged/leaded brothers)
3591      must be substituted. We create the auxiliary variable and fill the
3592      substitution table for all those similar nodes, in an iteration going from
3593      leads to lags. */
3594   int last_index = 0;
3595   VariableNode *last_aux_var = nullptr;
3596   for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
3597     {
3598       expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)->
3599         arg->substituteDiff(nodes, subst_table, neweqs);
3600       auto vn = dynamic_cast<VariableNode *>(argsubst);
3601       int symb_id;
3602       if (rit == it->second.rbegin())
3603         {
3604           if (vn)
3605             symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, rit->second, vn->symb_id, vn->lag);
3606           else
3607             symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, rit->second);
3608 
3609           // make originating aux var & equation
3610           last_index = rit->first;
3611           last_aux_var = datatree.AddVariable(symb_id, 0);
3612           //ORIG_AUX_DIFF = argsubst - argsubst(-1)
3613           neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(last_aux_var,
3614                                                                           datatree.AddMinus(argsubst,
3615                                                                                             argsubst->decreaseLeadsLags(1)))));
3616           subst_table[rit->second] = dynamic_cast<VariableNode *>(last_aux_var);
3617         }
3618       else
3619         {
3620           // just add equation of form: AUX_DIFF = LAST_AUX_VAR(-1)
3621           VariableNode *new_aux_var = nullptr;
3622           for (int i = last_index; i > rit->first; i--)
3623             {
3624               if (i == last_index)
3625                 symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(argsubst->idx, rit->second,
3626                                                                        last_aux_var->symb_id, last_aux_var->lag);
3627               else
3628                 symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(new_aux_var->idx, rit->second,
3629                                                                        last_aux_var->symb_id, last_aux_var->lag);
3630 
3631               new_aux_var = datatree.AddVariable(symb_id, 0);
3632               neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(new_aux_var,
3633                                                                               last_aux_var->decreaseLeadsLags(1))));
3634               last_aux_var = new_aux_var;
3635             }
3636           subst_table[rit->second] = dynamic_cast<VariableNode *>(new_aux_var);
3637           last_index = rit->first;
3638         }
3639     }
3640   return const_cast<VariableNode *>(subst_table[this]);
3641 }
3642 
3643 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3644 UnaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3645 {
3646   if (auto sit = subst_table.find(this);
3647       sit != subst_table.end())
3648     return const_cast<VariableNode *>(sit->second);
3649 
3650   /* If the equivalence class of this node is not marked for substitution,
3651      then substitute recursively and return. */
3652   auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3653   auto it = nodes.find(lag_equiv_repr);
3654   expr_t argsubst = arg->substituteUnaryOpNodes(nodes, subst_table, neweqs);
3655   if (it == nodes.end())
3656     return buildSimilarUnaryOpNode(argsubst, datatree);
3657 
3658   string unary_op;
3659   switch (op_code)
3660     {
3661     case UnaryOpcode::exp:
3662       unary_op = "exp";
3663       break;
3664     case UnaryOpcode::log:
3665       unary_op = "log";
3666       break;
3667     case UnaryOpcode::log10:
3668       unary_op = "log10";
3669       break;
3670     case UnaryOpcode::cos:
3671       unary_op = "cos";
3672       break;
3673     case UnaryOpcode::sin:
3674       unary_op = "sin";
3675       break;
3676     case UnaryOpcode::tan:
3677       unary_op = "tan";
3678       break;
3679     case UnaryOpcode::acos:
3680       unary_op = "acos";
3681       break;
3682     case UnaryOpcode::asin:
3683       unary_op = "asin";
3684       break;
3685     case UnaryOpcode::atan:
3686       unary_op = "atan";
3687       break;
3688     case UnaryOpcode::cosh:
3689       unary_op = "cosh";
3690       break;
3691     case UnaryOpcode::sinh:
3692       unary_op = "sinh";
3693       break;
3694     case UnaryOpcode::tanh:
3695       unary_op = "tanh";
3696       break;
3697     case UnaryOpcode::acosh:
3698       unary_op = "acosh";
3699       break;
3700     case UnaryOpcode::asinh:
3701       unary_op = "asinh";
3702       break;
3703     case UnaryOpcode::atanh:
3704       unary_op = "atanh";
3705       break;
3706     case UnaryOpcode::sqrt:
3707       unary_op = "sqrt";
3708       break;
3709     case UnaryOpcode::cbrt:
3710       unary_op = "cbrt";
3711       break;
3712     case UnaryOpcode::abs:
3713       unary_op = "abs";
3714       break;
3715     case UnaryOpcode::sign:
3716       unary_op = "sign";
3717       break;
3718     case UnaryOpcode::erf:
3719       unary_op = "erf";
3720       break;
3721     default:
3722       cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl;
3723       exit(EXIT_FAILURE);
3724     }
3725 
3726   /* At this point, we know that this node (and its lagged/leaded brothers)
3727      must be substituted. We create the auxiliary variable and fill the
3728      substitution table for all those similar nodes, in an iteration going from
3729      leads to lags. */
3730   int base_index = 0;
3731   VariableNode *aux_var = nullptr;
3732   for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
3733     if (rit == it->second.rbegin())
3734       {
3735         /* Verify that we’re not operating on a node with leads, since the
3736            transformation does take into account the expectation operator. We only
3737            need to do this for the first iteration of the loop, because we’re
3738            going from leads to lags. */
3739         if (rit->second->maxLead() > 0)
3740           {
3741             cerr << "Cannot substitute unary operations that contain leads" << endl;
3742             exit(EXIT_FAILURE);
3743           }
3744 
3745         int symb_id;
3746         auto vn = dynamic_cast<VariableNode *>(argsubst);
3747         if (!vn)
3748           symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op);
3749         else
3750           symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op,
3751                                                                  vn->symb_id, vn->lag);
3752         aux_var = datatree.AddVariable(symb_id, 0);
3753         neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
3754                                                                         dynamic_cast<UnaryOpNode *>(rit->second))));
3755         subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
3756         base_index = rit->first;
3757       }
3758     else
3759       subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(base_index - rit->first));
3760 
3761   return const_cast<VariableNode *>(subst_table.find(this)->second);
3762 }
3763 
3764 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)3765 UnaryOpNode::substitutePacExpectation(const string &name, expr_t subexpr)
3766 {
3767   expr_t argsubst = arg->substitutePacExpectation(name, subexpr);
3768   return buildSimilarUnaryOpNode(argsubst, datatree);
3769 }
3770 
3771 expr_t
decreaseLeadsLags(int n) const3772 UnaryOpNode::decreaseLeadsLags(int n) const
3773 {
3774   expr_t argsubst = arg->decreaseLeadsLags(n);
3775   return buildSimilarUnaryOpNode(argsubst, datatree);
3776 }
3777 
3778 expr_t
decreaseLeadsLagsPredeterminedVariables() const3779 UnaryOpNode::decreaseLeadsLagsPredeterminedVariables() const
3780 {
3781   expr_t argsubst = arg->decreaseLeadsLagsPredeterminedVariables();
3782   return buildSimilarUnaryOpNode(argsubst, datatree);
3783 }
3784 
3785 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const3786 UnaryOpNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
3787 {
3788   if (op_code == UnaryOpcode::uminus || deterministic_model)
3789     {
3790       expr_t argsubst = arg->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
3791       return buildSimilarUnaryOpNode(argsubst, datatree);
3792     }
3793   else
3794     {
3795       if (maxEndoLead() >= 2)
3796         return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
3797       else
3798         return const_cast<UnaryOpNode *>(this);
3799     }
3800 }
3801 
3802 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3803 UnaryOpNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3804 {
3805   expr_t argsubst = arg->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
3806   return buildSimilarUnaryOpNode(argsubst, datatree);
3807 }
3808 
3809 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const3810 UnaryOpNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
3811 {
3812   if (op_code == UnaryOpcode::uminus || deterministic_model)
3813     {
3814       expr_t argsubst = arg->substituteExoLead(subst_table, neweqs, deterministic_model);
3815       return buildSimilarUnaryOpNode(argsubst, datatree);
3816     }
3817   else
3818     {
3819       if (maxExoLead() >= 1)
3820         return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
3821       else
3822         return const_cast<UnaryOpNode *>(this);
3823     }
3824 }
3825 
3826 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3827 UnaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3828 {
3829   expr_t argsubst = arg->substituteExoLag(subst_table, neweqs);
3830   return buildSimilarUnaryOpNode(argsubst, datatree);
3831 }
3832 
3833 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const3834 UnaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
3835 {
3836   if (op_code == UnaryOpcode::expectation)
3837     {
3838       if (auto it = subst_table.find(const_cast<UnaryOpNode *>(this)); it != subst_table.end())
3839         return const_cast<VariableNode *>(it->second);
3840 
3841       //Arriving here, we need to create an auxiliary variable for this Expectation Operator:
3842       //AUX_EXPECT_(LEAD/LAG)_(period)_(arg.idx) OR
3843       //AUX_EXPECT_(info_set_name)_(arg.idx)
3844       int symb_id = datatree.symbol_table.addExpectationAuxiliaryVar(expectation_information_set, arg->idx, const_cast<UnaryOpNode *>(this));
3845       expr_t newAuxE = datatree.AddVariable(symb_id, 0);
3846 
3847       if (partial_information_model && expectation_information_set == 0)
3848         if (!dynamic_cast<VariableNode *>(arg))
3849           {
3850             cerr << "ERROR: In Partial Information models, EXPECTATION(0)(X) "
3851                  << "can only be used when X is a single variable." << endl;
3852             exit(EXIT_FAILURE);
3853           }
3854 
3855       //take care of any nested expectation operators by calling arg->substituteExpectation(.), then decreaseLeadsLags for this UnaryOpcode::expectation operator
3856       //arg(lag-period) (holds entire subtree of arg(lag-period)
3857       expr_t substexpr = (arg->substituteExpectation(subst_table, neweqs, partial_information_model))->decreaseLeadsLags(expectation_information_set);
3858       assert(substexpr);
3859       neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(newAuxE, substexpr))); //AUXE_period_arg.idx = arg(lag-period)
3860       newAuxE = datatree.AddVariable(symb_id, expectation_information_set);
3861 
3862       assert(dynamic_cast<VariableNode *>(newAuxE));
3863       subst_table[this] = dynamic_cast<VariableNode *>(newAuxE);
3864       return newAuxE;
3865     }
3866   else
3867     {
3868       expr_t argsubst = arg->substituteExpectation(subst_table, neweqs, partial_information_model);
3869       return buildSimilarUnaryOpNode(argsubst, datatree);
3870     }
3871 }
3872 
3873 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3874 UnaryOpNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3875 {
3876   expr_t argsubst = arg->differentiateForwardVars(subset, subst_table, neweqs);
3877   return buildSimilarUnaryOpNode(argsubst, datatree);
3878 }
3879 
3880 bool
isNumConstNodeEqualTo(double value) const3881 UnaryOpNode::isNumConstNodeEqualTo(double value) const
3882 {
3883   return false;
3884 }
3885 
3886 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const3887 UnaryOpNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
3888 {
3889   return false;
3890 }
3891 
3892 bool
containsPacExpectation(const string & pac_model_name) const3893 UnaryOpNode::containsPacExpectation(const string &pac_model_name) const
3894 {
3895   return arg->containsPacExpectation(pac_model_name);
3896 }
3897 
3898 bool
containsEndogenous() const3899 UnaryOpNode::containsEndogenous() const
3900 {
3901   return arg->containsEndogenous();
3902 }
3903 
3904 bool
containsExogenous() const3905 UnaryOpNode::containsExogenous() const
3906 {
3907   return arg->containsExogenous();
3908 }
3909 
3910 expr_t
replaceTrendVar() const3911 UnaryOpNode::replaceTrendVar() const
3912 {
3913   expr_t argsubst = arg->replaceTrendVar();
3914   return buildSimilarUnaryOpNode(argsubst, datatree);
3915 }
3916 
3917 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const3918 UnaryOpNode::detrend(int symb_id, bool log_trend, expr_t trend) const
3919 {
3920   expr_t argsubst = arg->detrend(symb_id, log_trend, trend);
3921   return buildSimilarUnaryOpNode(argsubst, datatree);
3922 }
3923 
3924 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const3925 UnaryOpNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
3926 {
3927   expr_t argsubst = arg->removeTrendLeadLag(trend_symbols_map);
3928   return buildSimilarUnaryOpNode(argsubst, datatree);
3929 }
3930 
3931 bool
isInStaticForm() const3932 UnaryOpNode::isInStaticForm() const
3933 {
3934   if (op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv
3935       || op_code == UnaryOpcode::steadyStateParam2ndDeriv
3936       || op_code == UnaryOpcode::expectation)
3937     return false;
3938   else
3939     return arg->isInStaticForm();
3940 }
3941 
3942 bool
isParamTimesEndogExpr() const3943 UnaryOpNode::isParamTimesEndogExpr() const
3944 {
3945   return arg->isParamTimesEndogExpr();
3946 }
3947 
3948 bool
isVarModelReferenced(const string & model_info_name) const3949 UnaryOpNode::isVarModelReferenced(const string &model_info_name) const
3950 {
3951   return arg->isVarModelReferenced(model_info_name);
3952 }
3953 
3954 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const3955 UnaryOpNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
3956 {
3957   arg->getEndosAndMaxLags(model_endos_and_lags);
3958 }
3959 
3960 expr_t
substituteStaticAuxiliaryVariable() const3961 UnaryOpNode::substituteStaticAuxiliaryVariable() const
3962 {
3963   if (op_code == UnaryOpcode::diff)
3964     return datatree.Zero;
3965 
3966   expr_t argsubst = arg->substituteStaticAuxiliaryVariable();
3967   if (op_code == UnaryOpcode::expectation)
3968     return argsubst;
3969   else
3970     return buildSimilarUnaryOpNode(argsubst, datatree);
3971 }
3972 
3973 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const3974 UnaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
3975 {
3976   arg->findConstantEquations(table);
3977 }
3978 
3979 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const3980 UnaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
3981 {
3982   expr_t argsubst = arg->replaceVarsInEquation(table);
3983   return buildSimilarUnaryOpNode(argsubst, datatree);
3984 }
3985 
BinaryOpNode(DataTree & datatree_arg,int idx_arg,const expr_t arg1_arg,BinaryOpcode op_code_arg,const expr_t arg2_arg,int powerDerivOrder_arg)3986 BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
3987                            BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg) :
3988   ExprNode{datatree_arg, idx_arg},
3989   arg1{arg1_arg},
3990   arg2{arg2_arg},
3991   op_code{op_code_arg},
3992   powerDerivOrder{powerDerivOrder_arg}
3993 {
3994   assert(powerDerivOrder >= 0);
3995 }
3996 
3997 void
prepareForDerivation()3998 BinaryOpNode::prepareForDerivation()
3999 {
4000   if (preparedForDerivation)
4001     return;
4002 
4003   preparedForDerivation = true;
4004 
4005   arg1->prepareForDerivation();
4006   arg2->prepareForDerivation();
4007 
4008   // Non-null derivatives are the union of those of the arguments
4009   // Compute set union of arg1->non_null_derivatives and arg2->non_null_derivatives
4010   set_union(arg1->non_null_derivatives.begin(),
4011             arg1->non_null_derivatives.end(),
4012             arg2->non_null_derivatives.begin(),
4013             arg2->non_null_derivatives.end(),
4014             inserter(non_null_derivatives, non_null_derivatives.begin()));
4015 }
4016 
4017 expr_t
getNonZeroPartofEquation() const4018 BinaryOpNode::getNonZeroPartofEquation() const
4019 {
4020   assert(arg1 == datatree.Zero || arg2 == datatree.Zero);
4021   if (arg1 == datatree.Zero)
4022     return arg2;
4023   return arg1;
4024 }
4025 
4026 expr_t
composeDerivatives(expr_t darg1,expr_t darg2)4027 BinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2)
4028 {
4029   expr_t t11, t12, t13, t14, t15;
4030 
4031   switch (op_code)
4032     {
4033     case BinaryOpcode::plus:
4034       return datatree.AddPlus(darg1, darg2);
4035     case BinaryOpcode::minus:
4036       return datatree.AddMinus(darg1, darg2);
4037     case BinaryOpcode::times:
4038       t11 = datatree.AddTimes(darg1, arg2);
4039       t12 = datatree.AddTimes(darg2, arg1);
4040       return datatree.AddPlus(t11, t12);
4041     case BinaryOpcode::divide:
4042       if (darg2 != datatree.Zero)
4043         {
4044           t11 = datatree.AddTimes(darg1, arg2);
4045           t12 = datatree.AddTimes(darg2, arg1);
4046           t13 = datatree.AddMinus(t11, t12);
4047           t14 = datatree.AddTimes(arg2, arg2);
4048           return datatree.AddDivide(t13, t14);
4049         }
4050       else
4051         return datatree.AddDivide(darg1, arg2);
4052     case BinaryOpcode::less:
4053     case BinaryOpcode::greater:
4054     case BinaryOpcode::lessEqual:
4055     case BinaryOpcode::greaterEqual:
4056     case BinaryOpcode::equalEqual:
4057     case BinaryOpcode::different:
4058       return datatree.Zero;
4059     case BinaryOpcode::power:
4060       if (darg2 == datatree.Zero)
4061         if (darg1 == datatree.Zero)
4062           return datatree.Zero;
4063         else
4064           if (dynamic_cast<NumConstNode *>(arg2))
4065             {
4066               t11 = datatree.AddMinus(arg2, datatree.One);
4067               t12 = datatree.AddPower(arg1, t11);
4068               t13 = datatree.AddTimes(arg2, t12);
4069               return datatree.AddTimes(darg1, t13);
4070             }
4071           else
4072             return datatree.AddTimes(darg1, datatree.AddPowerDeriv(arg1, arg2, powerDerivOrder + 1));
4073       else
4074         {
4075           t11 = datatree.AddLog(arg1);
4076           t12 = datatree.AddTimes(darg2, t11);
4077           t13 = datatree.AddTimes(darg1, arg2);
4078           t14 = datatree.AddDivide(t13, arg1);
4079           t15 = datatree.AddPlus(t12, t14);
4080           return datatree.AddTimes(t15, this);
4081         }
4082     case BinaryOpcode::powerDeriv:
4083       if (darg2 == datatree.Zero)
4084         return datatree.AddTimes(darg1, datatree.AddPowerDeriv(arg1, arg2, powerDerivOrder + 1));
4085       else
4086         {
4087           t11 = datatree.AddTimes(darg2, datatree.AddLog(arg1));
4088           t12 = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(powerDerivOrder));
4089           t13 = datatree.AddTimes(darg1, t12);
4090           t14 = datatree.AddDivide(t13, arg1);
4091           t15 = datatree.AddPlus(t11, t14);
4092           expr_t f = datatree.AddPower(arg1, t12);
4093           expr_t first_part = datatree.AddTimes(f, t15);
4094 
4095           for (int i = 0; i < powerDerivOrder; i++)
4096             first_part = datatree.AddTimes(first_part, datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(i)));
4097 
4098           t13 = datatree.Zero;
4099           for (int i = 0; i < powerDerivOrder; i++)
4100             {
4101               t11 = datatree.One;
4102               for (int j = 0; j < powerDerivOrder; j++)
4103                 if (i != j)
4104                   {
4105                     t12 = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(j));
4106                     t11 = datatree.AddTimes(t11, t12);
4107                   }
4108               t13 = datatree.AddPlus(t13, t11);
4109             }
4110           t13 = datatree.AddTimes(darg2, t13);
4111           t14 = datatree.AddTimes(f, t13);
4112           return datatree.AddPlus(first_part, t14);
4113         }
4114     case BinaryOpcode::max:
4115       t11 = datatree.AddGreater(arg1, arg2);
4116       t12 = datatree.AddTimes(t11, darg1);
4117       t13 = datatree.AddMinus(datatree.One, t11);
4118       t14 = datatree.AddTimes(t13, darg2);
4119       return datatree.AddPlus(t14, t12);
4120     case BinaryOpcode::min:
4121       t11 = datatree.AddGreater(arg2, arg1);
4122       t12 = datatree.AddTimes(t11, darg1);
4123       t13 = datatree.AddMinus(datatree.One, t11);
4124       t14 = datatree.AddTimes(t13, darg2);
4125       return datatree.AddPlus(t14, t12);
4126     case BinaryOpcode::equal:
4127       return datatree.AddMinus(darg1, darg2);
4128     }
4129   // Suppress GCC warning
4130   exit(EXIT_FAILURE);
4131 }
4132 
4133 expr_t
unpackPowerDeriv() const4134 BinaryOpNode::unpackPowerDeriv() const
4135 {
4136   if (op_code != BinaryOpcode::powerDeriv)
4137     return const_cast<BinaryOpNode *>(this);
4138 
4139   expr_t front = datatree.One;
4140   for (int i = 0; i < powerDerivOrder; i++)
4141     front = datatree.AddTimes(front,
4142                               datatree.AddMinus(arg2,
4143                                                 datatree.AddPossiblyNegativeConstant(i)));
4144   expr_t tmp = datatree.AddPower(arg1,
4145                                  datatree.AddMinus(arg2,
4146                                                    datatree.AddPossiblyNegativeConstant(powerDerivOrder)));
4147   return datatree.AddTimes(front, tmp);
4148 }
4149 
4150 expr_t
computeDerivative(int deriv_id)4151 BinaryOpNode::computeDerivative(int deriv_id)
4152 {
4153   expr_t darg1 = arg1->getDerivative(deriv_id);
4154   expr_t darg2 = arg2->getDerivative(deriv_id);
4155   return composeDerivatives(darg1, darg2);
4156 }
4157 
4158 int
precedence(ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms) const4159 BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
4160 {
4161   // A temporary term behaves as a variable
4162   if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4163     return 100;
4164 
4165   switch (op_code)
4166     {
4167     case BinaryOpcode::equal:
4168       return 0;
4169     case BinaryOpcode::equalEqual:
4170     case BinaryOpcode::different:
4171       return 1;
4172     case BinaryOpcode::lessEqual:
4173     case BinaryOpcode::greaterEqual:
4174     case BinaryOpcode::less:
4175     case BinaryOpcode::greater:
4176       return 2;
4177     case BinaryOpcode::plus:
4178     case BinaryOpcode::minus:
4179       return 3;
4180     case BinaryOpcode::times:
4181     case BinaryOpcode::divide:
4182       return 4;
4183     case BinaryOpcode::power:
4184     case BinaryOpcode::powerDeriv:
4185       if (isCOutput(output_type))
4186         // In C, power operator is of the form pow(a, b)
4187         return 100;
4188       else
4189         return 5;
4190     case BinaryOpcode::min:
4191     case BinaryOpcode::max:
4192       return 100;
4193     }
4194   // Suppress GCC warning
4195   exit(EXIT_FAILURE);
4196 }
4197 
4198 int
precedenceJson(const temporary_terms_t & temporary_terms) const4199 BinaryOpNode::precedenceJson(const temporary_terms_t &temporary_terms) const
4200 {
4201   // A temporary term behaves as a variable
4202   if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4203     return 100;
4204 
4205   switch (op_code)
4206     {
4207     case BinaryOpcode::equal:
4208       return 0;
4209     case BinaryOpcode::equalEqual:
4210     case BinaryOpcode::different:
4211       return 1;
4212     case BinaryOpcode::lessEqual:
4213     case BinaryOpcode::greaterEqual:
4214     case BinaryOpcode::less:
4215     case BinaryOpcode::greater:
4216       return 2;
4217     case BinaryOpcode::plus:
4218     case BinaryOpcode::minus:
4219       return 3;
4220     case BinaryOpcode::times:
4221     case BinaryOpcode::divide:
4222       return 4;
4223     case BinaryOpcode::power:
4224     case BinaryOpcode::powerDeriv:
4225       return 5;
4226     case BinaryOpcode::min:
4227     case BinaryOpcode::max:
4228       return 100;
4229     }
4230   // Suppress GCC warning
4231   exit(EXIT_FAILURE);
4232 }
4233 
4234 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const4235 BinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
4236 {
4237   // For a temporary term, the cost is null
4238   for (const auto &it : temp_terms_map)
4239     if (it.second.find(const_cast<BinaryOpNode *>(this)) != it.second.end())
4240       return 0;
4241 
4242   int arg_cost = arg1->cost(temp_terms_map, is_matlab) + arg2->cost(temp_terms_map, is_matlab);
4243 
4244   return cost(arg_cost, is_matlab);
4245 }
4246 
4247 int
cost(const temporary_terms_t & temporary_terms,bool is_matlab) const4248 BinaryOpNode::cost(const temporary_terms_t &temporary_terms, bool is_matlab) const
4249 {
4250   // For a temporary term, the cost is null
4251   if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4252     return 0;
4253 
4254   int arg_cost = arg1->cost(temporary_terms, is_matlab) + arg2->cost(temporary_terms, is_matlab);
4255 
4256   return cost(arg_cost, is_matlab);
4257 }
4258 
4259 int
cost(int cost,bool is_matlab) const4260 BinaryOpNode::cost(int cost, bool is_matlab) const
4261 {
4262   if (is_matlab)
4263     // Cost for Matlab files
4264     switch (op_code)
4265       {
4266       case BinaryOpcode::less:
4267       case BinaryOpcode::greater:
4268       case BinaryOpcode::lessEqual:
4269       case BinaryOpcode::greaterEqual:
4270       case BinaryOpcode::equalEqual:
4271       case BinaryOpcode::different:
4272         return cost + 60;
4273       case BinaryOpcode::plus:
4274       case BinaryOpcode::minus:
4275       case BinaryOpcode::times:
4276         return cost + 90;
4277       case BinaryOpcode::max:
4278       case BinaryOpcode::min:
4279         return cost + 110;
4280       case BinaryOpcode::divide:
4281         return cost + 990;
4282       case BinaryOpcode::power:
4283       case BinaryOpcode::powerDeriv:
4284         return cost + (min_cost_matlab/2+1);
4285       case BinaryOpcode::equal:
4286         return cost;
4287       }
4288   else
4289     // Cost for C files
4290     switch (op_code)
4291       {
4292       case BinaryOpcode::less:
4293       case BinaryOpcode::greater:
4294       case BinaryOpcode::lessEqual:
4295       case BinaryOpcode::greaterEqual:
4296       case BinaryOpcode::equalEqual:
4297       case BinaryOpcode::different:
4298         return cost + 2;
4299       case BinaryOpcode::plus:
4300       case BinaryOpcode::minus:
4301       case BinaryOpcode::times:
4302         return cost + 4;
4303       case BinaryOpcode::max:
4304       case BinaryOpcode::min:
4305         return cost + 5;
4306       case BinaryOpcode::divide:
4307         return cost + 15;
4308       case BinaryOpcode::power:
4309         return cost + 520;
4310       case BinaryOpcode::powerDeriv:
4311         return cost + (min_cost_c/2+1);;
4312       case BinaryOpcode::equal:
4313         return cost;
4314       }
4315   // Suppress GCC warning
4316   exit(EXIT_FAILURE);
4317 }
4318 
4319 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const4320 BinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
4321                                     map<pair<int, int>, temporary_terms_t> &temp_terms_map,
4322                                     map<expr_t, pair<int, pair<int, int>>> &reference_count,
4323                                     bool is_matlab) const
4324 {
4325   expr_t this2 = const_cast<BinaryOpNode *>(this);
4326   if (auto it = reference_count.find(this2);
4327       it == reference_count.end())
4328     {
4329       // If this node has never been encountered, set its ref count to one,
4330       //  and travel through its children
4331       reference_count[this2] = { 1, derivOrder };
4332       arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
4333       arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
4334     }
4335   else
4336     {
4337       /* If the node has already been encountered, increment its ref count
4338          and declare it as a temporary term if it is too costly (except if it is
4339          an equal node: we don't want them as temporary terms) */
4340       reference_count[this2] = { it->second.first + 1, it->second.second };;
4341       if (reference_count[this2].first * cost(temp_terms_map, is_matlab) > min_cost(is_matlab)
4342           && op_code != BinaryOpcode::equal)
4343         temp_terms_map[reference_count[this2].second].insert(this2);
4344     }
4345 }
4346 
4347 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const4348 BinaryOpNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
4349                                     temporary_terms_t &temporary_terms,
4350                                     map<expr_t, pair<int, int>> &first_occurence,
4351                                     int Curr_block,
4352                                     vector<vector<temporary_terms_t>> &v_temporary_terms,
4353                                     int equation) const
4354 {
4355   expr_t this2 = const_cast<BinaryOpNode *>(this);
4356   if (auto it = reference_count.find(this2);
4357       it == reference_count.end())
4358     {
4359       reference_count[this2] = 1;
4360       first_occurence[this2] = { Curr_block, equation };
4361       arg1->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
4362       arg2->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
4363     }
4364   else
4365     {
4366       reference_count[this2]++;
4367       if (reference_count[this2] * cost(temporary_terms, false) > min_cost_c
4368           && op_code != BinaryOpcode::equal)
4369         {
4370           temporary_terms.insert(this2);
4371           v_temporary_terms[first_occurence[this2].first][first_occurence[this2].second].insert(this2);
4372         }
4373     }
4374 }
4375 
4376 double
eval_opcode(double v1,BinaryOpcode op_code,double v2,int derivOrder)4377 BinaryOpNode::eval_opcode(double v1, BinaryOpcode op_code, double v2, int derivOrder) noexcept(false)
4378 {
4379   switch (op_code)
4380     {
4381     case BinaryOpcode::plus:
4382       return v1 + v2;
4383     case BinaryOpcode::minus:
4384       return v1 - v2;
4385     case BinaryOpcode::times:
4386       return v1 * v2;
4387     case BinaryOpcode::divide:
4388       return v1 / v2;
4389     case BinaryOpcode::power:
4390       return pow(v1, v2);
4391     case BinaryOpcode::powerDeriv:
4392       if (fabs(v1) < near_zero && v2 > 0
4393           && derivOrder > v2
4394           && fabs(v2-nearbyint(v2)) < near_zero)
4395         return 0.0;
4396       else
4397         {
4398           double dxp = pow(v1, v2-derivOrder);
4399           for (int i = 0; i < derivOrder; i++)
4400             dxp *= v2--;
4401           return dxp;
4402         }
4403     case BinaryOpcode::max:
4404       if (v1 < v2)
4405         return v2;
4406       else
4407         return v1;
4408     case BinaryOpcode::min:
4409       if (v1 > v2)
4410         return v2;
4411       else
4412         return v1;
4413     case BinaryOpcode::less:
4414       return v1 < v2;
4415     case BinaryOpcode::greater:
4416       return v1 > v2;
4417     case BinaryOpcode::lessEqual:
4418       return v1 <= v2;
4419     case BinaryOpcode::greaterEqual:
4420       return v1 >= v2;
4421     case BinaryOpcode::equalEqual:
4422       return v1 == v2;
4423     case BinaryOpcode::different:
4424       return v1 != v2;
4425     case BinaryOpcode::equal:
4426       throw EvalException();
4427     }
4428   // Suppress GCC warning
4429   exit(EXIT_FAILURE);
4430 }
4431 
4432 double
eval(const eval_context_t & eval_context) const4433 BinaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false)
4434 {
4435   double v1 = arg1->eval(eval_context);
4436   double v2 = arg2->eval(eval_context);
4437   return eval_opcode(v1, op_code, v2, powerDerivOrder);
4438 }
4439 
4440 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const4441 BinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
4442                       bool lhs_rhs, const temporary_terms_t &temporary_terms,
4443                       const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
4444                       const deriv_node_temp_terms_t &tef_terms) const
4445 {
4446   // If current node is a temporary term
4447   if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4448     {
4449       if (dynamic)
4450         {
4451           auto ii = map_idx.find(idx);
4452           FLDT_ fldt(ii->second);
4453           fldt.write(CompileCode, instruction_number);
4454         }
4455       else
4456         {
4457           auto ii = map_idx.find(idx);
4458           FLDST_ fldst(ii->second);
4459           fldst.write(CompileCode, instruction_number);
4460         }
4461       return;
4462     }
4463   if (op_code == BinaryOpcode::powerDeriv)
4464     {
4465       FLDC_ fldc(powerDerivOrder);
4466       fldc.write(CompileCode, instruction_number);
4467     }
4468   arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
4469   arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
4470   FBINARY_ fbinary{static_cast<int>(op_code)};
4471   fbinary.write(CompileCode, instruction_number);
4472 }
4473 
4474 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const4475 BinaryOpNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
4476 {
4477   if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4478     temporary_terms_inuse.insert(idx);
4479   else
4480     {
4481       arg1->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
4482       arg2->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
4483     }
4484 }
4485 
4486 bool
containsExternalFunction() const4487 BinaryOpNode::containsExternalFunction() const
4488 {
4489   return arg1->containsExternalFunction()
4490     || arg2->containsExternalFunction();
4491 }
4492 
4493 void
writeJsonAST(ostream & output) const4494 BinaryOpNode::writeJsonAST(ostream &output) const
4495 {
4496   output << R"({"node_type" : "BinaryOpNode",)"
4497          << R"( "op" : ")";
4498   switch (op_code)
4499     {
4500     case BinaryOpcode::plus:
4501       output << "+";
4502       break;
4503     case BinaryOpcode::minus:
4504       output << "-";
4505       break;
4506     case BinaryOpcode::times:
4507       output << "*";
4508       break;
4509     case BinaryOpcode::divide:
4510       output << "/";
4511       break;
4512     case BinaryOpcode::power:
4513       output << "^";
4514       break;
4515     case BinaryOpcode::less:
4516       output << "<";
4517       break;
4518     case BinaryOpcode::greater:
4519       output << ">";
4520       break;
4521     case BinaryOpcode::lessEqual:
4522       output << "<=";
4523       break;
4524     case BinaryOpcode::greaterEqual:
4525       output << ">=";
4526       break;
4527     case BinaryOpcode::equalEqual:
4528       output << "==";
4529       break;
4530     case BinaryOpcode::different:
4531       output << "!=";
4532       break;
4533     case BinaryOpcode::equal:
4534       output << "=";
4535       break;
4536     case BinaryOpcode::max:
4537       output << "max";
4538       break;
4539     case BinaryOpcode::min:
4540       output << "min";
4541       break;
4542     case BinaryOpcode::powerDeriv:
4543       output << "power_deriv";
4544       break;
4545     }
4546   output << R"(", "arg1" : )";
4547   arg1->writeJsonAST(output);
4548   output << R"(, "arg2" : )";
4549   arg2->writeJsonAST(output);
4550   output << "}";
4551 }
4552 
4553 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const4554 BinaryOpNode::writeJsonOutput(ostream &output,
4555                               const temporary_terms_t &temporary_terms,
4556                               const deriv_node_temp_terms_t &tef_terms,
4557                               bool isdynamic) const
4558 {
4559   // If current node is a temporary term
4560   if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4561     {
4562       output << "T" << idx;
4563       return;
4564     }
4565 
4566   if (op_code == BinaryOpcode::max || op_code == BinaryOpcode::min)
4567     {
4568       switch (op_code)
4569         {
4570         case BinaryOpcode::max:
4571           output << "max(";
4572           break;
4573         case BinaryOpcode::min:
4574           output << "min(";
4575           break;
4576         default:
4577           ;
4578         }
4579       arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4580       output << ",";
4581       arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4582       output << ")";
4583       return;
4584     }
4585 
4586   if (op_code == BinaryOpcode::powerDeriv)
4587     {
4588       output << "get_power_deriv(";
4589       arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4590       output << ",";
4591       arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4592       output << "," << powerDerivOrder << ")";
4593       return;
4594     }
4595 
4596   int prec = precedenceJson(temporary_terms);
4597 
4598   bool close_parenthesis = false;
4599 
4600   // If left argument has a lower precedence, or if current and left argument are both power operators,
4601   // add parenthesis around left argument
4602   if (auto barg1 = dynamic_cast<BinaryOpNode *>(arg1);
4603       arg1->precedenceJson(temporary_terms) < prec
4604       || (op_code == BinaryOpcode::power && barg1 && barg1->op_code == BinaryOpcode::power))
4605     {
4606       output << "(";
4607       close_parenthesis = true;
4608     }
4609 
4610   // Write left argument
4611   arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4612 
4613   if (close_parenthesis)
4614     output << ")";
4615 
4616   // Write current operator symbol
4617   switch (op_code)
4618     {
4619     case BinaryOpcode::plus:
4620       output << "+";
4621       break;
4622     case BinaryOpcode::minus:
4623       output << "-";
4624       break;
4625     case BinaryOpcode::times:
4626       output << "*";
4627       break;
4628     case BinaryOpcode::divide:
4629       output << "/";
4630       break;
4631     case BinaryOpcode::power:
4632       output << "^";
4633       break;
4634     case BinaryOpcode::less:
4635       output << "<";
4636       break;
4637     case BinaryOpcode::greater:
4638       output << ">";
4639       break;
4640     case BinaryOpcode::lessEqual:
4641       output << "<=";
4642       break;
4643     case BinaryOpcode::greaterEqual:
4644       output << ">=";
4645       break;
4646     case BinaryOpcode::equalEqual:
4647       output << "==";
4648       break;
4649     case BinaryOpcode::different:
4650       output << "!=";
4651       break;
4652     case BinaryOpcode::equal:
4653       output << "=";
4654       break;
4655     default:
4656       ;
4657     }
4658 
4659   close_parenthesis = false;
4660 
4661   /* Add parenthesis around right argument if:
4662      - its precedence is lower than those of the current node
4663      - it is a power operator and current operator is also a power operator
4664      - it is a minus operator with same precedence than current operator
4665      - it is a divide operator with same precedence than current operator */
4666   auto barg2 = dynamic_cast<BinaryOpNode *>(arg2);
4667   if (int arg2_prec = arg2->precedenceJson(temporary_terms); arg2_prec < prec
4668       || (op_code == BinaryOpcode::power && barg2 && barg2->op_code == BinaryOpcode::power)
4669       || (op_code == BinaryOpcode::minus && arg2_prec == prec)
4670       || (op_code == BinaryOpcode::divide && arg2_prec == prec))
4671     {
4672       output << "(";
4673       close_parenthesis = true;
4674     }
4675 
4676   // Write right argument
4677   arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4678 
4679   if (close_parenthesis)
4680     output << ")";
4681 }
4682 
void
BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_t &temporary_terms,
                          const temporary_terms_idxs_t &temporary_terms_idxs,
                          const deriv_node_temp_terms_t &tef_terms) const
{
  /* Write this binary expression to "output", in the syntax selected by
     "output_type". The isLatexOutput(), isCOutput(), isMatlabOutput() and
     isJuliaOutput() helpers select the language-specific syntax below.
     Parentheses around the operands are only added when precedence requires them. */

  // If checkIfTemporaryTermThenWrite() already wrote something for this node
  // (presumably a reference to a temporary term), there is nothing more to write
  if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
    return;

  // Treat derivative of Power
  if (op_code == BinaryOpcode::powerDeriv)
    {
      /* In LaTeX, write the explicit expression returned by unpackPowerDeriv().
         Otherwise, emit a call to a runtime helper (get_power_deriv in Julia,
         getPowerDeriv elsewhere) taking base, exponent and derivation order. */
      if (isLatexOutput(output_type))
        unpackPowerDeriv()->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      else
        {
          if (output_type == ExprNodeOutputType::juliaStaticModel || output_type == ExprNodeOutputType::juliaDynamicModel)
            output << "get_power_deriv(";
          else
            output << "getPowerDeriv(";
          arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << ",";
          arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << "," << powerDerivOrder << ")";
        }
      return;
    }

  // Treat special case of power operator in C, and case of max and min operators
  // (all written in function-call syntax: pow(a,b), max(a,b), min(a,b);
  // max and min are written this way whatever the output type)
  if ((op_code == BinaryOpcode::power && isCOutput(output_type)) || op_code == BinaryOpcode::max || op_code == BinaryOpcode::min)
    {
      switch (op_code)
        {
        case BinaryOpcode::power:
          output << "pow(";
          break;
        case BinaryOpcode::max:
          output << "max(";
          break;
        case BinaryOpcode::min:
          output << "min(";
          break;
        default:
          // Unreachable: the enclosing if only lets power, max and min through
          ;
        }
      arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      output << ",";
      arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      output << ")";
      return;
    }

  int prec = precedence(output_type, temporary_terms);

  bool close_parenthesis = false;

  // In LaTeX, a division becomes \frac{numerator}{denominator}: the braces delimit
  // the numerator, so no parenthesis is needed around it
  if (isLatexOutput(output_type) && op_code == BinaryOpcode::divide)
    output << R"(\frac{)";
  else
    {
      // If left argument has a lower precedence, or if current and left argument are both power operators, add parenthesis around left argument
      auto barg1 = dynamic_cast<BinaryOpNode *>(arg1);
      if (arg1->precedence(output_type, temporary_terms) < prec
          || (op_code == BinaryOpcode::power && barg1 != nullptr && barg1->op_code == BinaryOpcode::power))
        {
          output << LEFT_PAR(output_type);
          close_parenthesis = true;
        }
    }

  // Write left argument
  arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);

  if (close_parenthesis)
    output << RIGHT_PAR(output_type);

  // Close the numerator of \frac
  if (isLatexOutput(output_type) && op_code == BinaryOpcode::divide)
    output << "}";

  // Write current operator symbol
  switch (op_code)
    {
    case BinaryOpcode::plus:
      output << "+";
      break;
    case BinaryOpcode::minus:
      output << "-";
      break;
    case BinaryOpcode::times:
      if (isLatexOutput(output_type))
        output << R"(\, )";
      else
        output << "*";
      break;
    case BinaryOpcode::divide:
      // In LaTeX, the \frac{}{} construct replaces the operator symbol
      if (!isLatexOutput(output_type))
        output << "/";
      break;
    case BinaryOpcode::power:
      output << "^";
      break;
    case BinaryOpcode::less:
      output << "<";
      break;
    case BinaryOpcode::greater:
      output << ">";
      break;
    case BinaryOpcode::lessEqual:
      if (isLatexOutput(output_type))
        output << R"(\leq )";
      else
        output << "<=";
      break;
    case BinaryOpcode::greaterEqual:
      if (isLatexOutput(output_type))
        output << R"(\geq )";
      else
        output << ">=";
      break;
    case BinaryOpcode::equalEqual:
      output << "==";
      break;
    case BinaryOpcode::different:
      // MATLAB uses ~=, C and Julia use !=, any other output type (presumably
      // LaTeX) gets \neq
      if (isMatlabOutput(output_type))
        output << "~=";
      else
        {
          if (isCOutput(output_type) || isJuliaOutput(output_type))
            output << "!=";
          else
            output << R"(\neq )";
        }
      break;
    case BinaryOpcode::equal:
      output << "=";
      break;
    default:
      // powerDeriv, max and min have already been fully written (and returned) above
      ;
    }

  close_parenthesis = false;

  // In LaTeX, the exponent of a power and the denominator of \frac are delimited by braces
  if (isLatexOutput(output_type) && (op_code == BinaryOpcode::power || op_code == BinaryOpcode::divide))
    output << "{";
  else
    {
      /* Add parenthesis around right argument if:
         - its precedence is lower than those of the current node
         - it is a power operator and current operator is also a power operator
         - it is a minus operator with same precedence than current operator
         - it is a divide operator with same precedence than current operator */
      auto barg2 = dynamic_cast<BinaryOpNode *>(arg2);
      if (int arg2_prec = arg2->precedence(output_type, temporary_terms); arg2_prec < prec
          || (op_code == BinaryOpcode::power && barg2 && barg2->op_code == BinaryOpcode::power && !isLatexOutput(output_type))
          || (op_code == BinaryOpcode::minus && arg2_prec == prec)
          || (op_code == BinaryOpcode::divide && arg2_prec == prec && !isLatexOutput(output_type)))
        {
          output << LEFT_PAR(output_type);
          close_parenthesis = true;
        }
    }

  // Write right argument
  arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);

  // Close the LaTeX brace group opened above. close_parenthesis can only be true when
  // no brace was opened, so the two closings below never both apply.
  if (isLatexOutput(output_type) && (op_code == BinaryOpcode::power || op_code == BinaryOpcode::divide))
    output << "}";

  if (close_parenthesis)
    output << RIGHT_PAR(output_type);
}
4854 
4855 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const4856 BinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
4857                                           const temporary_terms_t &temporary_terms,
4858                                           const temporary_terms_idxs_t &temporary_terms_idxs,
4859                                           deriv_node_temp_terms_t &tef_terms) const
4860 {
4861   arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
4862   arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
4863 }
4864 
4865 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const4866 BinaryOpNode::writeJsonExternalFunctionOutput(vector<string> &efout,
4867                                               const temporary_terms_t &temporary_terms,
4868                                               deriv_node_temp_terms_t &tef_terms,
4869                                               bool isdynamic) const
4870 {
4871   arg1->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
4872   arg2->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
4873 }
4874 
4875 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const4876 BinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
4877                                             bool lhs_rhs, const temporary_terms_t &temporary_terms,
4878                                             const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
4879                                             deriv_node_temp_terms_t &tef_terms) const
4880 {
4881   arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
4882                                       dynamic, steady_dynamic, tef_terms);
4883   arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
4884                                       dynamic, steady_dynamic, tef_terms);
4885 }
4886 
4887 int
VarMinLag() const4888 BinaryOpNode::VarMinLag() const
4889 {
4890   return min(arg1->VarMinLag(), arg2->VarMinLag());
4891 }
4892 
4893 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const4894 BinaryOpNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
4895 {
4896   return max(arg1->VarMaxLag(lhs_lag_equiv),
4897              arg2->VarMaxLag(lhs_lag_equiv));
4898 }
4899 
4900 void
collectVARLHSVariable(set<expr_t> & result) const4901 BinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
4902 {
4903   cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
4904   exit(EXIT_FAILURE);
4905 }
4906 
4907 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const4908 BinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
4909 {
4910   arg1->collectDynamicVariables(type_arg, result);
4911   arg2->collectDynamicVariables(type_arg, result);
4912 }
4913 
4914 expr_t
Compute_RHS(expr_t arg1,expr_t arg2,int op,int op_type) const4915 BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const
4916 {
4917   temporary_terms_t temp;
4918   switch (op_type)
4919     {
4920     case 0: /*Unary Operator*/
4921       switch (static_cast<UnaryOpcode>(op))
4922         {
4923         case UnaryOpcode::uminus:
4924           return (datatree.AddUMinus(arg1));
4925           break;
4926         case UnaryOpcode::exp:
4927           return (datatree.AddExp(arg1));
4928           break;
4929         case UnaryOpcode::log:
4930           return (datatree.AddLog(arg1));
4931           break;
4932         case UnaryOpcode::log10:
4933           return (datatree.AddLog10(arg1));
4934           break;
4935         default:
4936           cerr << "BinaryOpNode::Compute_RHS: case not handled";
4937           exit(EXIT_FAILURE);
4938         }
4939       break;
4940     case 1: /*Binary Operator*/
4941       switch (static_cast<BinaryOpcode>(op))
4942         {
4943         case BinaryOpcode::plus:
4944           return (datatree.AddPlus(arg1, arg2));
4945           break;
4946         case BinaryOpcode::minus:
4947           return (datatree.AddMinus(arg1, arg2));
4948           break;
4949         case BinaryOpcode::times:
4950           return (datatree.AddTimes(arg1, arg2));
4951           break;
4952         case BinaryOpcode::divide:
4953           return (datatree.AddDivide(arg1, arg2));
4954           break;
4955         case BinaryOpcode::power:
4956           return (datatree.AddPower(arg1, arg2));
4957           break;
4958         default:
4959           cerr << "BinaryOpNode::Compute_RHS: case not handled";
4960           exit(EXIT_FAILURE);
4961         }
4962       break;
4963     }
4964   return nullptr;
4965 }
4966 
4967 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const4968 BinaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
4969 {
4970   /* Checks if the current value of the endogenous variable related to the equation
4971      is present in the arguments of the binary operator. */
4972   vector<tuple<int, expr_t, expr_t>> List_of_Op_RHS1, List_of_Op_RHS2;
4973   pair<int, expr_t> res = arg1->normalizeEquation(var_endo, List_of_Op_RHS1);
4974   int is_endogenous_present_1 = res.first;
4975   expr_t expr_t_1 = res.second;
4976 
4977   res = arg2->normalizeEquation(var_endo, List_of_Op_RHS2);
4978   int is_endogenous_present_2 = res.first;
4979   expr_t expr_t_2 = res.second;
4980 
4981   /* If the two expressions contains the current value of the endogenous variable associated to the equation
4982      the equation could not be normalized and the process is given-up.*/
4983   if (is_endogenous_present_1 == 2 || is_endogenous_present_2 == 2)
4984     return { 2, nullptr };
4985   else if (is_endogenous_present_1 && is_endogenous_present_2)
4986     return { 2, nullptr };
4987   else if (is_endogenous_present_1) /*If the current values of the endogenous variable associated to the equation
4988                                       is present only in the first operand of the expression, we try to normalize the equation*/
4989     {
4990       if (op_code == BinaryOpcode::equal) /* The end of the normalization process :
4991                                              All the operations needed to normalize the equation are applied. */
4992         while (!List_of_Op_RHS1.empty())
4993           {
4994             tuple<int, expr_t, expr_t> it = List_of_Op_RHS1.back();
4995             List_of_Op_RHS1.pop_back();
4996             if (get<1>(it) && !get<2>(it)) /*Binary operator*/
4997               expr_t_2 = Compute_RHS(expr_t_2, static_cast<BinaryOpNode *>(get<1>(it)), get<0>(it), 1);
4998             else if (get<2>(it) && !get<1>(it)) /*Binary operator*/
4999               expr_t_2 = Compute_RHS(get<2>(it), expr_t_2, get<0>(it), 1);
5000             else if (get<2>(it) && get<1>(it)) /*Binary operator*/
5001               expr_t_2 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1);
5002             else /*Unary operator*/
5003               expr_t_2 = Compute_RHS(static_cast<UnaryOpNode *>(expr_t_2), static_cast<UnaryOpNode *>(get<1>(it)), get<0>(it), 0);
5004           }
5005       else
5006         List_of_Op_RHS = List_of_Op_RHS1;
5007     }
5008   else if (is_endogenous_present_2)
5009     {
5010       if (op_code == BinaryOpcode::equal)
5011         while (!List_of_Op_RHS2.empty())
5012           {
5013             tuple<int, expr_t, expr_t> it = List_of_Op_RHS2.back();
5014             List_of_Op_RHS2.pop_back();
5015             if (get<1>(it) && !get<2>(it)) /*Binary operator*/
5016               expr_t_1 = Compute_RHS(static_cast<BinaryOpNode *>(expr_t_1), static_cast<BinaryOpNode *>(get<1>(it)), get<0>(it), 1);
5017             else if (get<2>(it) && !get<1>(it)) /*Binary operator*/
5018               expr_t_1 = Compute_RHS(static_cast<BinaryOpNode *>(get<2>(it)), static_cast<BinaryOpNode *>(expr_t_1), get<0>(it), 1);
5019             else if (get<2>(it) && get<1>(it)) /*Binary operator*/
5020               expr_t_1 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1);
5021             else
5022               expr_t_1 = Compute_RHS(static_cast<UnaryOpNode *>(expr_t_1), static_cast<UnaryOpNode *>(get<1>(it)), get<0>(it), 0);
5023           }
5024       else
5025         List_of_Op_RHS = List_of_Op_RHS2;
5026     }
5027   switch (op_code)
5028     {
5029     case BinaryOpcode::plus:
5030       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5031         return { 0, datatree.AddPlus(expr_t_1, expr_t_2) };
5032       else if (is_endogenous_present_1 && is_endogenous_present_2)
5033         return { 2, nullptr };
5034       else if (!is_endogenous_present_1 && is_endogenous_present_2)
5035         {
5036           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_1, nullptr);
5037           return { 1, expr_t_1 };
5038         }
5039       else if (is_endogenous_present_1 && !is_endogenous_present_2)
5040         {
5041           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_2, nullptr);
5042           return { 1, expr_t_2 };
5043         }
5044       break;
5045     case BinaryOpcode::minus:
5046       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5047         return { 0, datatree.AddMinus(expr_t_1, expr_t_2) };
5048       else if (is_endogenous_present_1 && is_endogenous_present_2)
5049         return { 2, nullptr };
5050       else if (!is_endogenous_present_1 && is_endogenous_present_2)
5051         {
5052           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::uminus), nullptr, nullptr);
5053           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_1, nullptr);
5054           return { 1, expr_t_1 };
5055         }
5056       else if (is_endogenous_present_1 && !is_endogenous_present_2)
5057         {
5058           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::plus), expr_t_2, nullptr);
5059           return { 1, datatree.AddUMinus(expr_t_2) };
5060         }
5061       break;
5062     case BinaryOpcode::times:
5063       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5064         return { 0, datatree.AddTimes(expr_t_1, expr_t_2) };
5065       else if (!is_endogenous_present_1 && is_endogenous_present_2)
5066         {
5067           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), expr_t_1, nullptr);
5068           return { 1, expr_t_1 };
5069         }
5070       else if (is_endogenous_present_1 && !is_endogenous_present_2)
5071         {
5072           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), expr_t_2, nullptr);
5073           return { 1, expr_t_2 };
5074         }
5075       else
5076         return { 2, nullptr };
5077       break;
5078     case BinaryOpcode::divide:
5079       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5080         return { 0, datatree.AddDivide(expr_t_1, expr_t_2) };
5081       else if (!is_endogenous_present_1 && is_endogenous_present_2)
5082         {
5083           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), nullptr, expr_t_1);
5084           return { 1, expr_t_1 };
5085         }
5086       else if (is_endogenous_present_1 && !is_endogenous_present_2)
5087         {
5088           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::times), expr_t_2, nullptr);
5089           return { 1, expr_t_2 };
5090         }
5091       else
5092         return { 2, nullptr };
5093       break;
5094     case BinaryOpcode::power:
5095       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5096         return { 0, datatree.AddPower(expr_t_1, expr_t_2) };
5097       else if (is_endogenous_present_1 && !is_endogenous_present_2)
5098         {
5099           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), datatree.AddDivide(datatree.One, expr_t_2), nullptr);
5100           return { 1, nullptr };
5101         }
5102       else if (!is_endogenous_present_1 && is_endogenous_present_2)
5103         {
5104           /* we have to nomalize a^f(X) = RHS */
5105           /* First computes the ln(RHS)*/
5106           List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::log), nullptr, nullptr);
5107           /* Second  computes f(X) = ln(RHS) / ln(a)*/
5108           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), nullptr, datatree.AddLog(expr_t_1));
5109           return { 1, nullptr };
5110         }
5111       break;
5112     case BinaryOpcode::equal:
5113       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5114         {
5115           return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.AddMinus(expr_t_2, expr_t_1)) };
5116         }
5117       else if (is_endogenous_present_1 && is_endogenous_present_2)
5118         {
5119           return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.Zero) };
5120         }
5121       else if (!is_endogenous_present_1 && is_endogenous_present_2)
5122         {
5123           return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), /*datatree.AddUMinus(expr_t_1)*/ expr_t_1) };
5124         }
5125       else if (is_endogenous_present_1 && !is_endogenous_present_2)
5126         {
5127           return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), expr_t_2) };
5128         }
5129       break;
5130     case BinaryOpcode::max:
5131       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5132         return { 0, datatree.AddMax(expr_t_1, expr_t_2) };
5133       else
5134         return { 2, nullptr };
5135       break;
5136     case BinaryOpcode::min:
5137       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5138         return { 0, datatree.AddMin(expr_t_1, expr_t_2) };
5139       else
5140         return { 2, nullptr };
5141       break;
5142     case BinaryOpcode::less:
5143       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5144         return { 0, datatree.AddLess(expr_t_1, expr_t_2) };
5145       else
5146         return { 2, nullptr };
5147       break;
5148     case BinaryOpcode::greater:
5149       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5150         return { 0, datatree.AddGreater(expr_t_1, expr_t_2) };
5151       else
5152         return { 2, nullptr };
5153       break;
5154     case BinaryOpcode::lessEqual:
5155       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5156         return { 0, datatree.AddLessEqual(expr_t_1, expr_t_2) };
5157       else
5158         return { 2, nullptr };
5159       break;
5160     case BinaryOpcode::greaterEqual:
5161       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5162         return { 0, datatree.AddGreaterEqual(expr_t_1, expr_t_2) };
5163       else
5164         return { 2, nullptr };
5165       break;
5166     case BinaryOpcode::equalEqual:
5167       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5168         return { 0, datatree.AddEqualEqual(expr_t_1, expr_t_2) };
5169       else
5170         return { 2, nullptr };
5171       break;
5172     case BinaryOpcode::different:
5173       if (!is_endogenous_present_1 && !is_endogenous_present_2)
5174         return { 0, datatree.AddDifferent(expr_t_1, expr_t_2) };
5175       else
5176         return { 2, nullptr };
5177       break;
5178     default:
5179       cerr << "Binary operator not handled during the normalization process" << endl;
5180       return { 2, nullptr }; // Could not be normalized
5181     }
5182   // Suppress GCC warning
5183   cerr << "BinaryOpNode::normalizeEquation: impossible case" << endl;
5184   exit(EXIT_FAILURE);
5185 }
5186 
5187 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)5188 BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
5189 {
5190   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
5191   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
5192   return composeDerivatives(darg1, darg2);
5193 }
5194 
5195 expr_t
buildSimilarBinaryOpNode(expr_t alt_arg1,expr_t alt_arg2,DataTree & alt_datatree) const5196 BinaryOpNode::buildSimilarBinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, DataTree &alt_datatree) const
5197 {
5198   switch (op_code)
5199     {
5200     case BinaryOpcode::plus:
5201       return alt_datatree.AddPlus(alt_arg1, alt_arg2);
5202     case BinaryOpcode::minus:
5203       return alt_datatree.AddMinus(alt_arg1, alt_arg2);
5204     case BinaryOpcode::times:
5205       return alt_datatree.AddTimes(alt_arg1, alt_arg2);
5206     case BinaryOpcode::divide:
5207       return alt_datatree.AddDivide(alt_arg1, alt_arg2);
5208     case BinaryOpcode::power:
5209       return alt_datatree.AddPower(alt_arg1, alt_arg2);
5210     case BinaryOpcode::equal:
5211       return alt_datatree.AddEqual(alt_arg1, alt_arg2);
5212     case BinaryOpcode::max:
5213       return alt_datatree.AddMax(alt_arg1, alt_arg2);
5214     case BinaryOpcode::min:
5215       return alt_datatree.AddMin(alt_arg1, alt_arg2);
5216     case BinaryOpcode::less:
5217       return alt_datatree.AddLess(alt_arg1, alt_arg2);
5218     case BinaryOpcode::greater:
5219       return alt_datatree.AddGreater(alt_arg1, alt_arg2);
5220     case BinaryOpcode::lessEqual:
5221       return alt_datatree.AddLessEqual(alt_arg1, alt_arg2);
5222     case BinaryOpcode::greaterEqual:
5223       return alt_datatree.AddGreaterEqual(alt_arg1, alt_arg2);
5224     case BinaryOpcode::equalEqual:
5225       return alt_datatree.AddEqualEqual(alt_arg1, alt_arg2);
5226     case BinaryOpcode::different:
5227       return alt_datatree.AddDifferent(alt_arg1, alt_arg2);
5228     case BinaryOpcode::powerDeriv:
5229       return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder);
5230     }
5231   // Suppress GCC warning
5232   exit(EXIT_FAILURE);
5233 }
5234 
5235 expr_t
toStatic(DataTree & static_datatree) const5236 BinaryOpNode::toStatic(DataTree &static_datatree) const
5237 {
5238   expr_t sarg1 = arg1->toStatic(static_datatree);
5239   expr_t sarg2 = arg2->toStatic(static_datatree);
5240   return buildSimilarBinaryOpNode(sarg1, sarg2, static_datatree);
5241 }
5242 
5243 void
computeXrefs(EquationInfo & ei) const5244 BinaryOpNode::computeXrefs(EquationInfo &ei) const
5245 {
5246   arg1->computeXrefs(ei);
5247   arg2->computeXrefs(ei);
5248 }
5249 
5250 expr_t
clone(DataTree & datatree) const5251 BinaryOpNode::clone(DataTree &datatree) const
5252 {
5253   expr_t substarg1 = arg1->clone(datatree);
5254   expr_t substarg2 = arg2->clone(datatree);
5255   return buildSimilarBinaryOpNode(substarg1, substarg2, datatree);
5256 }
5257 
5258 int
maxEndoLead() const5259 BinaryOpNode::maxEndoLead() const
5260 {
5261   return max(arg1->maxEndoLead(), arg2->maxEndoLead());
5262 }
5263 
5264 int
maxExoLead() const5265 BinaryOpNode::maxExoLead() const
5266 {
5267   return max(arg1->maxExoLead(), arg2->maxExoLead());
5268 }
5269 
5270 int
maxEndoLag() const5271 BinaryOpNode::maxEndoLag() const
5272 {
5273   return max(arg1->maxEndoLag(), arg2->maxEndoLag());
5274 }
5275 
5276 int
maxExoLag() const5277 BinaryOpNode::maxExoLag() const
5278 {
5279   return max(arg1->maxExoLag(), arg2->maxExoLag());
5280 }
5281 
5282 int
maxLead() const5283 BinaryOpNode::maxLead() const
5284 {
5285   return max(arg1->maxLead(), arg2->maxLead());
5286 }
5287 
5288 int
maxLag() const5289 BinaryOpNode::maxLag() const
5290 {
5291   return max(arg1->maxLag(), arg2->maxLag());
5292 }
5293 
5294 int
maxLagWithDiffsExpanded() const5295 BinaryOpNode::maxLagWithDiffsExpanded() const
5296 {
5297   return max(arg1->maxLagWithDiffsExpanded(), arg2->maxLagWithDiffsExpanded());
5298 }
5299 
5300 expr_t
undiff() const5301 BinaryOpNode::undiff() const
5302 {
5303   expr_t arg1subst = arg1->undiff();
5304   expr_t arg2subst = arg2->undiff();
5305   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5306 }
5307 
5308 int
PacMaxLag(int lhs_symb_id) const5309 BinaryOpNode::PacMaxLag(int lhs_symb_id) const
5310 {
5311   return max(arg1->PacMaxLag(lhs_symb_id), arg2->PacMaxLag(lhs_symb_id));
5312 }
5313 
5314 int
getPacTargetSymbIdHelper(int lhs_symb_id,int undiff_lhs_symb_id,const set<pair<int,int>> & endogs) const5315 BinaryOpNode::getPacTargetSymbIdHelper(int lhs_symb_id, int undiff_lhs_symb_id, const set<pair<int, int>> &endogs) const
5316 {
5317   int target_symb_id = -1;
5318   bool found_lagged_lhs = false;
5319   for (auto &it : endogs)
5320     {
5321       int id = datatree.symbol_table.getUltimateOrigSymbID(it.first);
5322       if (id == lhs_symb_id || id == undiff_lhs_symb_id)
5323         found_lagged_lhs = true;
5324       if (id != lhs_symb_id && id != undiff_lhs_symb_id)
5325         if (target_symb_id < 0)
5326           target_symb_id = it.first;
5327     }
5328   if (!found_lagged_lhs)
5329     target_symb_id = -1;
5330   return target_symb_id;
5331 }
5332 
5333 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const5334 BinaryOpNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
5335 {
5336   set<pair<int, int>> endogs;
5337   arg1->collectDynamicVariables(SymbolType::endogenous, endogs);
5338   int target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, undiff_lhs_symb_id, endogs);
5339   if (target_symb_id >= 0)
5340     return target_symb_id;
5341 
5342   endogs.clear();
5343   arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
5344   target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, undiff_lhs_symb_id, endogs);
5345 
5346   if (target_symb_id < 0)
5347     {
5348       cerr << "Error finding target variable in PAC equation" << endl;
5349       exit(EXIT_FAILURE);
5350     }
5351 
5352   return target_symb_id;
5353 }
5354 
5355 expr_t
decreaseLeadsLags(int n) const5356 BinaryOpNode::decreaseLeadsLags(int n) const
5357 {
5358   expr_t arg1subst = arg1->decreaseLeadsLags(n);
5359   expr_t arg2subst = arg2->decreaseLeadsLags(n);
5360   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5361 }
5362 
5363 expr_t
decreaseLeadsLagsPredeterminedVariables() const5364 BinaryOpNode::decreaseLeadsLagsPredeterminedVariables() const
5365 {
5366   expr_t arg1subst = arg1->decreaseLeadsLagsPredeterminedVariables();
5367   expr_t arg2subst = arg2->decreaseLeadsLagsPredeterminedVariables();
5368   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5369 }
5370 
expr_t
BinaryOpNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
  /* Return this expression with endogenous leads of two periods or more substituted.
     The substitution happens either by recursing into the operands, or by replacing
     the whole node through createEndoLeadAuxiliaryVarForMyself().
     subst_table and neweqs are only passed through to the callees (presumably they
     record substitutions and new auxiliary equations — TODO confirm). */
  expr_t arg1subst, arg2subst;
  int maxendolead1 = arg1->maxEndoLead(), maxendolead2 = arg2->maxEndoLead();

  // No operand has an endogenous lead >= 2: nothing to substitute, return the node itself
  if (maxendolead1 < 2 && maxendolead2 < 2)
    return const_cast<BinaryOpNode *>(this);
  // Deterministic model: substitute independently inside each operand that needs it
  if (deterministic_model)
    {
      arg1subst = maxendolead1 >= 2 ? arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg1;
      arg2subst = maxendolead2 >= 2 ? arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg2;
      return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
    }
  else
    switch (op_code)
      {
        // Sums, differences and the equation sign: substitute inside the operands
      case BinaryOpcode::plus:
      case BinaryOpcode::minus:
      case BinaryOpcode::equal:
        arg1subst = maxendolead1 >= 2 ? arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg1;
        arg2subst = maxendolead2 >= 2 ? arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg2;
        return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
      case BinaryOpcode::times:
      case BinaryOpcode::divide:
        // Only the left operand has leads >= 2, and the right one has neither an
        // endogenous nor an exogenous lead: substitute inside the left operand only
        if (maxendolead1 >= 2 && maxendolead2 == 0 && arg2->maxExoLead() == 0)
          {
            arg1subst = arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
            return buildSimilarBinaryOpNode(arg1subst, arg2, datatree);
          }
        // Product whose left factor has no endogenous and no exogenous lead: substitute
        // inside the right factor only (not done for a division)
        if (maxendolead1 == 0 && arg1->maxExoLead() == 0
            && maxendolead2 >= 2 && op_code == BinaryOpcode::times)
          {
            arg2subst = arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
            return buildSimilarBinaryOpNode(arg1, arg2subst, datatree);
          }
        // Any other product/division: replace the whole node by an auxiliary variable
        return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
      default:
        // Other operators: replace the whole node by an auxiliary variable
        return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
      }
}
5412 
5413 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5414 BinaryOpNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5415 {
5416   expr_t arg1subst = arg1->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
5417   expr_t arg2subst = arg2->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
5418   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5419 }
5420 
expr_t
BinaryOpNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
  /* Return this expression with exogenous leads (of one period or more) substituted.
     The substitution happens either by recursing into the operands, or by replacing
     the whole node through createExoLeadAuxiliaryVarForMyself().
     Same structure as substituteEndoLeadGreaterThanTwo(), with the roles of
     endogenous and exogenous leads swapped and a threshold of 1 instead of 2. */
  expr_t arg1subst, arg2subst;
  int maxexolead1 = arg1->maxExoLead(), maxexolead2 = arg2->maxExoLead();

  // No operand has an exogenous lead: nothing to substitute, return the node itself
  if (maxexolead1 < 1 && maxexolead2 < 1)
    return const_cast<BinaryOpNode *>(this);
  // Deterministic model: substitute independently inside each operand that needs it
  if (deterministic_model)
    {
      arg1subst = maxexolead1 >= 1 ? arg1->substituteExoLead(subst_table, neweqs, deterministic_model) : arg1;
      arg2subst = maxexolead2 >= 1 ? arg2->substituteExoLead(subst_table, neweqs, deterministic_model) : arg2;
      return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
    }
  else
    switch (op_code)
      {
        // Sums, differences and the equation sign: substitute inside the operands
      case BinaryOpcode::plus:
      case BinaryOpcode::minus:
      case BinaryOpcode::equal:
        arg1subst = maxexolead1 >= 1 ? arg1->substituteExoLead(subst_table, neweqs, deterministic_model) : arg1;
        arg2subst = maxexolead2 >= 1 ? arg2->substituteExoLead(subst_table, neweqs, deterministic_model) : arg2;
        return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
      case BinaryOpcode::times:
      case BinaryOpcode::divide:
        // Only the left operand has exogenous leads, and the right one has neither an
        // exogenous nor an endogenous lead: substitute inside the left operand only
        if (maxexolead1 >= 1 && maxexolead2 == 0 && arg2->maxEndoLead() == 0)
          {
            arg1subst = arg1->substituteExoLead(subst_table, neweqs, deterministic_model);
            return buildSimilarBinaryOpNode(arg1subst, arg2, datatree);
          }
        // Product whose left factor has no exogenous and no endogenous lead: substitute
        // inside the right factor only (not done for a division)
        if (maxexolead1 == 0 && arg1->maxEndoLead() == 0
            && maxexolead2 >= 1 && op_code == BinaryOpcode::times)
          {
            arg2subst = arg2->substituteExoLead(subst_table, neweqs, deterministic_model);
            return buildSimilarBinaryOpNode(arg1, arg2subst, datatree);
          }
        // Any other product/division: replace the whole node by an auxiliary variable
        return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
      default:
        // Other operators: replace the whole node by an auxiliary variable
        return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
      }
}
5462 
5463 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5464 BinaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5465 {
5466   expr_t arg1subst = arg1->substituteExoLag(subst_table, neweqs);
5467   expr_t arg2subst = arg2->substituteExoLag(subst_table, neweqs);
5468   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5469 }
5470 
5471 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const5472 BinaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
5473 {
5474   expr_t arg1subst = arg1->substituteExpectation(subst_table, neweqs, partial_information_model);
5475   expr_t arg2subst = arg2->substituteExpectation(subst_table, neweqs, partial_information_model);
5476   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5477 }
5478 
5479 expr_t
substituteAdl() const5480 BinaryOpNode::substituteAdl() const
5481 {
5482   expr_t arg1subst = arg1->substituteAdl();
5483   expr_t arg2subst = arg2->substituteAdl();
5484   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5485 }
5486 
5487 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const5488 BinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
5489 {
5490   expr_t arg1subst = arg1->substituteVarExpectation(subst_table);
5491   expr_t arg2subst = arg2->substituteVarExpectation(subst_table);
5492   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5493 }
5494 
5495 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const5496 BinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
5497 {
5498   arg1->findUnaryOpNodesForAuxVarCreation(nodes);
5499   arg2->findUnaryOpNodesForAuxVarCreation(nodes);
5500 }
5501 
5502 void
findDiffNodes(lag_equivalence_table_t & nodes) const5503 BinaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
5504 {
5505   arg1->findDiffNodes(nodes);
5506   arg2->findDiffNodes(nodes);
5507 }
5508 
5509 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5510 BinaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
5511                              vector<BinaryOpNode *> &neweqs) const
5512 {
5513   expr_t arg1subst = arg1->substituteDiff(nodes, subst_table, neweqs);
5514   expr_t arg2subst = arg2->substituteDiff(nodes, subst_table, neweqs);
5515   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5516 }
5517 
5518 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5519 BinaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5520 {
5521   expr_t arg1subst = arg1->substituteUnaryOpNodes(nodes, subst_table, neweqs);
5522   expr_t arg2subst = arg2->substituteUnaryOpNodes(nodes, subst_table, neweqs);
5523   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5524 }
5525 
5526 int
countDiffs() const5527 BinaryOpNode::countDiffs() const
5528 {
5529   return max(arg1->countDiffs(), arg2->countDiffs());
5530 }
5531 
5532 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)5533 BinaryOpNode::substitutePacExpectation(const string &name, expr_t subexpr)
5534 {
5535   expr_t arg1subst = arg1->substitutePacExpectation(name, subexpr);
5536   expr_t arg2subst = arg2->substitutePacExpectation(name, subexpr);
5537   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5538 }
5539 
5540 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5541 BinaryOpNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5542 {
5543   expr_t arg1subst = arg1->differentiateForwardVars(subset, subst_table, neweqs);
5544   expr_t arg2subst = arg2->differentiateForwardVars(subset, subst_table, neweqs);
5545   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5546 }
5547 
5548 expr_t
addMultipliersToConstraints(int i)5549 BinaryOpNode::addMultipliersToConstraints(int i)
5550 {
5551   int symb_id = datatree.symbol_table.addMultiplierAuxiliaryVar(i);
5552   expr_t newAuxLM = datatree.AddVariable(symb_id, 0);
5553   return datatree.AddEqual(datatree.AddTimes(newAuxLM, datatree.AddMinus(arg1, arg2)), datatree.Zero);
5554 }
5555 
5556 bool
isNumConstNodeEqualTo(double value) const5557 BinaryOpNode::isNumConstNodeEqualTo(double value) const
5558 {
5559   return false;
5560 }
5561 
5562 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const5563 BinaryOpNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
5564 {
5565   return false;
5566 }
5567 
5568 bool
containsPacExpectation(const string & pac_model_name) const5569 BinaryOpNode::containsPacExpectation(const string &pac_model_name) const
5570 {
5571   return arg1->containsPacExpectation(pac_model_name) || arg2->containsPacExpectation(pac_model_name);
5572 }
5573 
5574 bool
containsEndogenous() const5575 BinaryOpNode::containsEndogenous() const
5576 {
5577   return arg1->containsEndogenous() || arg2->containsEndogenous();
5578 }
5579 
5580 bool
containsExogenous() const5581 BinaryOpNode::containsExogenous() const
5582 {
5583   return arg1->containsExogenous() || arg2->containsExogenous();
5584 }
5585 
5586 expr_t
replaceTrendVar() const5587 BinaryOpNode::replaceTrendVar() const
5588 {
5589   expr_t arg1subst = arg1->replaceTrendVar();
5590   expr_t arg2subst = arg2->replaceTrendVar();
5591   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5592 }
5593 
5594 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const5595 BinaryOpNode::detrend(int symb_id, bool log_trend, expr_t trend) const
5596 {
5597   expr_t arg1subst = arg1->detrend(symb_id, log_trend, trend);
5598   expr_t arg2subst = arg2->detrend(symb_id, log_trend, trend);
5599   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5600 }
5601 
5602 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const5603 BinaryOpNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
5604 {
5605   expr_t arg1subst = arg1->removeTrendLeadLag(trend_symbols_map);
5606   expr_t arg2subst = arg2->removeTrendLeadLag(trend_symbols_map);
5607   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5608 }
5609 
5610 bool
isInStaticForm() const5611 BinaryOpNode::isInStaticForm() const
5612 {
5613   return arg1->isInStaticForm() && arg2->isInStaticForm();
5614 }
5615 
5616 bool
findTargetVariableHelper1(int lhs_symb_id,int rhs_symb_id) const5617 BinaryOpNode::findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const
5618 {
5619   if (lhs_symb_id == rhs_symb_id)
5620     return true;
5621 
5622   try
5623     {
5624       if (datatree.symbol_table.isAuxiliaryVariable(rhs_symb_id)
5625           && lhs_symb_id == datatree.symbol_table.getOrigSymbIdForAuxVar(rhs_symb_id))
5626         return true;
5627     }
5628   catch (...)
5629     {
5630     }
5631   return false;
5632 }
5633 
5634 int
findTargetVariableHelper(const expr_t arg1,const expr_t arg2,int lhs_symb_id) const5635 BinaryOpNode::findTargetVariableHelper(const expr_t arg1, const expr_t arg2,
5636                                        int lhs_symb_id) const
5637 {
5638   set<int> params;
5639   arg1->collectVariables(SymbolType::parameter, params);
5640   if (params.size() != 1)
5641     return -1;
5642 
5643   set<pair<int, int>> endogs;
5644   arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
5645   if (auto testarg2 = dynamic_cast<BinaryOpNode *>(arg2);
5646       endogs.size() == 2 && testarg2 && testarg2->op_code == BinaryOpcode::minus
5647       && dynamic_cast<VariableNode *>(testarg2->arg1)
5648       && dynamic_cast<VariableNode *>(testarg2->arg2))
5649     {
5650       if (findTargetVariableHelper1(lhs_symb_id, endogs.begin()->first))
5651         return endogs.rbegin()->first;
5652       else if (findTargetVariableHelper1(lhs_symb_id, endogs.rbegin()->first))
5653         return endogs.begin()->first;
5654     }
5655   return -1;
5656 }
5657 
5658 int
findTargetVariable(int lhs_symb_id) const5659 BinaryOpNode::findTargetVariable(int lhs_symb_id) const
5660 {
5661   int retval = findTargetVariableHelper(arg1, arg2, lhs_symb_id);
5662   if (retval < 0)
5663     retval = findTargetVariableHelper(arg2, arg1, lhs_symb_id);
5664   if (retval < 0)
5665     retval = arg1->findTargetVariable(lhs_symb_id);
5666   if (retval < 0)
5667     retval = arg2->findTargetVariable(lhs_symb_id);
5668   return retval;
5669 }
5670 
5671 pair<int, vector<tuple<int, bool, int>>>
getPacEC(BinaryOpNode * bopn,int lhs_symb_id,int lhs_orig_symb_id) const5672 BinaryOpNode::getPacEC(BinaryOpNode *bopn, int lhs_symb_id, int lhs_orig_symb_id) const
5673 {
5674   pair<int, vector<tuple<int, bool, int>>> ec_params_and_vars = {-1, vector<tuple<int, bool, int>>()};
5675   int optim_param_symb_id = -1;
5676   expr_t optim_part = nullptr;
5677   set<pair<int, int>> endogs;
5678   bopn->collectDynamicVariables(SymbolType::endogenous, endogs);
5679   int target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, lhs_orig_symb_id, endogs);
5680   if (target_symb_id >= 0 && bopn->isParamTimesEndogExpr())
5681     {
5682       optim_part = bopn->arg2;
5683       auto vn = dynamic_cast<VariableNode *>(bopn->arg1);
5684       if (!vn || datatree.symbol_table.getType(vn->symb_id) != SymbolType::parameter)
5685         {
5686           optim_part = bopn->arg1;
5687           vn = dynamic_cast<VariableNode *>(bopn->arg2);
5688         }
5689       if (!vn || datatree.symbol_table.getType(vn->symb_id) != SymbolType::parameter)
5690         return ec_params_and_vars;
5691       optim_param_symb_id = vn->symb_id;
5692     }
5693   if (optim_param_symb_id >= 0)
5694     {
5695       endogs.clear();
5696       optim_part->collectDynamicVariables(SymbolType::endogenous, endogs);
5697       optim_part->collectDynamicVariables(SymbolType::exogenous, endogs);
5698       if (endogs.size() != 2)
5699         {
5700           cerr << "Error getting EC part of PAC equation" << endl;
5701           exit(EXIT_FAILURE);
5702         }
5703       vector<pair<expr_t, int>> terms;
5704       vector<tuple<int, bool, int>> ordered_symb_ids;
5705       optim_part->decomposeAdditiveTerms(terms, 1);
5706       for (const auto &it : terms)
5707         {
5708           int scale = it.second;
5709           auto vn = dynamic_cast<VariableNode *>(it.first);
5710           if (!vn
5711               || !(datatree.symbol_table.getType(vn->symb_id) == SymbolType::endogenous
5712                    || datatree.symbol_table.getType(vn->symb_id) == SymbolType::exogenous))
5713             {
5714               cerr << "Problem with error component portion of PAC equation" << endl;
5715               exit(EXIT_FAILURE);
5716             }
5717           int id = vn->symb_id;
5718           int orig_id = datatree.symbol_table.getUltimateOrigSymbID(id);
5719           bool istarget = true;
5720           if (orig_id == lhs_symb_id || orig_id == lhs_orig_symb_id)
5721             istarget = false;
5722           ordered_symb_ids.emplace_back(id, istarget, scale);
5723         }
5724       ec_params_and_vars = { optim_param_symb_id, ordered_symb_ids };
5725     }
5726   return ec_params_and_vars;
5727 }
5728 
5729 void
getPacAREC(int lhs_symb_id,int lhs_orig_symb_id,pair<int,vector<tuple<int,bool,int>>> & ec_params_and_vars,set<pair<int,pair<int,int>>> & ar_params_and_vars,vector<tuple<int,int,int,double>> & additive_vars_params_and_constants) const5730 BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
5731                          pair<int, vector<tuple<int, bool, int>>> &ec_params_and_vars,
5732                          set<pair<int, pair<int, int>>> &ar_params_and_vars,
5733                          vector<tuple<int, int, int, double>> &additive_vars_params_and_constants) const
5734 {
5735   vector<pair<expr_t, int>> terms;
5736   decomposeAdditiveTerms(terms, 1);
5737   for (auto it = terms.begin(); it != terms.end(); ++it)
5738     if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn)
5739       {
5740         ec_params_and_vars = getPacEC(bopn, lhs_symb_id, lhs_orig_symb_id);
5741         if (ec_params_and_vars.first >= 0)
5742           {
5743             terms.erase(it);
5744             break;
5745           }
5746       }
5747 
5748   if (ec_params_and_vars.first < 0)
5749     {
5750       cerr << "Error finding EC part of PAC equation" << endl;
5751       exit(EXIT_FAILURE);
5752     }
5753 
5754   for (const auto &it : terms)
5755     {
5756       if (dynamic_cast<PacExpectationNode *>(it.first))
5757         continue;
5758 
5759       pair<int, vector<tuple<int, int, int, double>>> m;
5760       try
5761         {
5762           m = {-1, {it.first->matchVariableTimesConstantTimesParam()}};
5763         }
5764       catch (MatchFailureException &e)
5765         {
5766           try
5767             {
5768               m = it.first->matchParamTimesLinearCombinationOfVariables();
5769             }
5770           catch (MatchFailureException &e)
5771             {
5772               cerr << "Unsupported expression in PAC equation" << endl;
5773               exit(EXIT_FAILURE);
5774             }
5775         }
5776 
5777       for (auto &t : m.second)
5778         get<3>(t) *= it.second; // Update sign of constants
5779 
5780       int pid = get<0>(m);
5781       for (auto [vid, lag, pidtmp, constant] : m.second)
5782         {
5783           if (pid == -1)
5784             pid = pidtmp;
5785           else if (pidtmp >= 0)
5786             {
5787               cerr << "unexpected parameter found in PAC equation" << endl;
5788               exit(EXIT_FAILURE);
5789             }
5790 
5791           if (int vidorig = datatree.symbol_table.getUltimateOrigSymbID(vid);
5792               vidorig == lhs_symb_id || vidorig == lhs_orig_symb_id)
5793             {
5794               // This is an autoregressive term
5795               if (constant != 1 || pid == -1)
5796                 {
5797                   cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*lagged_variable" << endl;
5798                   exit(EXIT_FAILURE);
5799                 }
5800               ar_params_and_vars.insert({pid, { vid, lag }});
5801             }
5802           else
5803             // This is a residual additive term
5804             additive_vars_params_and_constants.emplace_back(vid, lag, pid, constant);
5805         }
5806     }
5807 }
5808 
5809 bool
isParamTimesEndogExpr() const5810 BinaryOpNode::isParamTimesEndogExpr() const
5811 {
5812   if (op_code == BinaryOpcode::times)
5813     {
5814       set<int> params;
5815       auto test_arg1 = dynamic_cast<VariableNode *>(arg1);
5816       auto test_arg2 = dynamic_cast<VariableNode *>(arg2);
5817       if (test_arg1)
5818         arg1->collectVariables(SymbolType::parameter, params);
5819       else if (test_arg2)
5820         arg2->collectVariables(SymbolType::parameter, params);
5821       else
5822         return false;
5823 
5824       if (params.size() != 1)
5825         return false;
5826 
5827       params.clear();
5828       set<pair<int, int>> endogs, exogs;
5829       if (test_arg1)
5830         {
5831           arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
5832           arg2->collectDynamicVariables(SymbolType::exogenous, exogs);
5833           arg2->collectVariables(SymbolType::parameter, params);
5834           if (params.size() == 0 && exogs.size() == 0 && endogs.size() >= 1)
5835             return true;
5836         }
5837       else
5838         {
5839           arg1->collectDynamicVariables(SymbolType::endogenous, endogs);
5840           arg1->collectDynamicVariables(SymbolType::exogenous, exogs);
5841           arg1->collectVariables(SymbolType::parameter, params);
5842           if (params.size() == 0 && exogs.size() == 0 && endogs.size() >= 1)
5843             return true;
5844         }
5845     }
5846   else if (op_code == BinaryOpcode::plus)
5847     return arg1->isParamTimesEndogExpr() || arg2->isParamTimesEndogExpr();
5848   return false;
5849 }
5850 
5851 bool
getPacNonOptimizingPartHelper(BinaryOpNode * bopn,int optim_share) const5852 BinaryOpNode::getPacNonOptimizingPartHelper(BinaryOpNode *bopn, int optim_share) const
5853 {
5854   set<int> params;
5855   bopn->collectVariables(SymbolType::parameter, params);
5856   if (params.size() == 1 && *params.begin() == optim_share)
5857     return true;
5858   return false;
5859 }
5860 
5861 expr_t
getPacNonOptimizingPart(BinaryOpNode * bopn,int optim_share) const5862 BinaryOpNode::getPacNonOptimizingPart(BinaryOpNode *bopn, int optim_share) const
5863 {
5864   auto a1 = dynamic_cast<BinaryOpNode *>(bopn->arg1);
5865   auto a2 = dynamic_cast<BinaryOpNode *>(bopn->arg2);
5866   if (!a1 && !a2)
5867     return nullptr;
5868 
5869   if (a1 && getPacNonOptimizingPartHelper(a1, optim_share))
5870     return bopn->arg2;
5871 
5872   if (a2 && getPacNonOptimizingPartHelper(a2, optim_share))
5873     return bopn->arg1;
5874 
5875   return nullptr;
5876 }
5877 
5878 pair<int, expr_t>
getPacOptimizingShareAndExprNodesHelper(BinaryOpNode * bopn,int lhs_symb_id,int lhs_orig_symb_id) const5879 BinaryOpNode::getPacOptimizingShareAndExprNodesHelper(BinaryOpNode *bopn, int lhs_symb_id, int lhs_orig_symb_id) const
5880 {
5881   int optim_param_symb_id = -1;
5882   expr_t optim_part = nullptr;
5883   set<pair<int, int>> endogs;
5884   bopn->collectDynamicVariables(SymbolType::endogenous, endogs);
5885   int target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, lhs_orig_symb_id, endogs);
5886   if (target_symb_id >= 0)
5887     {
5888       set<int> params;
5889       if (bopn->arg1->isParamTimesEndogExpr() && !bopn->arg2->isParamTimesEndogExpr())
5890         {
5891           optim_part = bopn->arg1;
5892           bopn->arg2->collectVariables(SymbolType::parameter, params);
5893           optim_param_symb_id = *(params.begin());
5894         }
5895       else if (bopn->arg2->isParamTimesEndogExpr() && !bopn->arg1->isParamTimesEndogExpr())
5896         {
5897           optim_part = bopn->arg2;
5898           bopn->arg1->collectVariables(SymbolType::parameter, params);
5899           optim_param_symb_id = *(params.begin());
5900         }
5901     }
5902   return {optim_param_symb_id, optim_part};
5903 }
5904 
5905 tuple<int, expr_t, expr_t, expr_t>
getPacOptimizingShareAndExprNodes(int lhs_symb_id,int lhs_orig_symb_id) const5906 BinaryOpNode::getPacOptimizingShareAndExprNodes(int lhs_symb_id, int lhs_orig_symb_id) const
5907 {
5908   vector<pair<expr_t, int>> terms;
5909   decomposeAdditiveTerms(terms, 1);
5910   for (auto &it : terms)
5911     if (dynamic_cast<PacExpectationNode *>(it.first))
5912       // if the pac_expectation operator is additive in the expression
5913       // there are no optimizing shares
5914       return {-1, nullptr, nullptr, nullptr};
5915 
5916   int optim_share;
5917   expr_t optim_part, non_optim_part, additive_part;
5918   optim_part = non_optim_part = additive_part = nullptr;
5919 
5920   for (auto it = terms.begin(); it != terms.end(); ++it)
5921     if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn)
5922       {
5923         tie(optim_share, optim_part)
5924           = getPacOptimizingShareAndExprNodesHelper(bopn, lhs_symb_id, lhs_orig_symb_id);
5925         if (optim_share >= 0 && optim_part)
5926           {
5927             terms.erase(it);
5928             break;
5929           }
5930       }
5931 
5932   if (!optim_part)
5933     return {-1, nullptr, nullptr, nullptr};
5934 
5935   for (auto it = terms.begin(); it != terms.end(); ++it)
5936     if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn)
5937       {
5938         non_optim_part = getPacNonOptimizingPart(bopn, optim_share);
5939         if (non_optim_part)
5940           {
5941             terms.erase(it);
5942             break;
5943           }
5944       }
5945 
5946   if (!non_optim_part)
5947     return {-1, nullptr, nullptr, nullptr};
5948   else
5949     {
5950       additive_part = datatree.Zero;
5951       for (auto it : terms)
5952         additive_part = datatree.AddPlus(additive_part, it.first);
5953       if (additive_part == datatree.Zero)
5954         additive_part = nullptr;
5955     }
5956 
5957   return {optim_share, optim_part, non_optim_part, additive_part};
5958 }
5959 
5960 void
fillAutoregressiveRow(int eqn,const vector<int> & lhs,map<tuple<int,int,int>,expr_t> & AR) const5961 BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const
5962 {
5963   vector<pair<expr_t, int>> terms;
5964   decomposeAdditiveTerms(terms, 1);
5965   for (const auto &it : terms)
5966     {
5967       int vid, lag, param_id;
5968       double constant;
5969       try
5970         {
5971           tie(vid, lag, param_id, constant) = it.first->matchVariableTimesConstantTimesParam();
5972           constant *= it.second;
5973         }
5974       catch (MatchFailureException &e)
5975         {
5976           continue;
5977         }
5978 
5979       if (datatree.symbol_table.isDiffAuxiliaryVariable(vid))
5980         {
5981           lag = -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(vid);
5982           vid = datatree.symbol_table.getOrigSymbIdForDiffAuxVar(vid);
5983         }
5984 
5985       if (find(lhs.begin(), lhs.end(), vid) == lhs.end())
5986         continue;
5987 
5988       if (AR.find({eqn, -lag, vid}) != AR.end())
5989         {
5990           cerr << "BinaryOpNode::fillAutoregressiveRow: Error filling AR matrix: "
5991                << "lag/symb_id encountered more than once in equation" << endl;
5992           exit(EXIT_FAILURE);
5993         }
5994       if (constant != 1 || param_id == -1)
5995         {
5996           cerr << "BinaryOpNode::fillAutoregressiveRow: autoregressive terms must be of the form 'parameter*lagged_variable" << endl;
5997           exit(EXIT_FAILURE);
5998         }
5999       AR[{eqn, -lag, vid}] = datatree.AddVariable(param_id);
6000     }
6001 }
6002 
6003 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const6004 BinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
6005 {
6006   if (op_code == BinaryOpcode::equal)
6007     {
6008       if (dynamic_cast<VariableNode *>(arg1) && dynamic_cast<NumConstNode *>(arg2))
6009         table[dynamic_cast<VariableNode *>(arg1)] = dynamic_cast<NumConstNode *>(arg2);
6010       else if (dynamic_cast<VariableNode *>(arg2) && dynamic_cast<NumConstNode *>(arg1))
6011         table[dynamic_cast<VariableNode *>(arg2)] = dynamic_cast<NumConstNode *>(arg1);
6012     }
6013   else
6014     {
6015       arg1->findConstantEquations(table);
6016       arg2->findConstantEquations(table);
6017     }
6018 }
6019 
6020 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const6021 BinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
6022 {
6023   if (op_code == BinaryOpcode::equal)
6024     for (auto &it : table)
6025       if ((it.first == arg1 && it.second == arg2) || (it.first == arg2 && it.second == arg1))
6026         return const_cast<BinaryOpNode *>(this);
6027   expr_t arg1subst = arg1->replaceVarsInEquation(table);
6028   expr_t arg2subst = arg2->replaceVarsInEquation(table);
6029   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
6030 }
6031 
6032 bool
isVarModelReferenced(const string & model_info_name) const6033 BinaryOpNode::isVarModelReferenced(const string &model_info_name) const
6034 {
6035   return arg1->isVarModelReferenced(model_info_name)
6036     || arg2->isVarModelReferenced(model_info_name);
6037 }
6038 
6039 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const6040 BinaryOpNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
6041 {
6042   arg1->getEndosAndMaxLags(model_endos_and_lags);
6043   arg2->getEndosAndMaxLags(model_endos_and_lags);
6044 }
6045 
6046 expr_t
substituteStaticAuxiliaryVariable() const6047 BinaryOpNode::substituteStaticAuxiliaryVariable() const
6048 {
6049   expr_t arg1subst = arg1->substituteStaticAuxiliaryVariable();
6050   expr_t arg2subst = arg2->substituteStaticAuxiliaryVariable();
6051   return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
6052 }
6053 
6054 expr_t
substituteStaticAuxiliaryDefinition() const6055 BinaryOpNode::substituteStaticAuxiliaryDefinition() const
6056 {
6057   expr_t arg2subst = arg2->substituteStaticAuxiliaryVariable();
6058   return buildSimilarBinaryOpNode(arg1, arg2subst, datatree);
6059 }
6060 
TrinaryOpNode(DataTree & datatree_arg,int idx_arg,const expr_t arg1_arg,TrinaryOpcode op_code_arg,const expr_t arg2_arg,const expr_t arg3_arg)6061 TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
6062                              TrinaryOpcode op_code_arg, const expr_t arg2_arg, const expr_t arg3_arg) :
6063   ExprNode{datatree_arg, idx_arg},
6064   arg1{arg1_arg},
6065   arg2{arg2_arg},
6066   arg3{arg3_arg},
6067   op_code{op_code_arg}
6068 {
6069 }
6070 
6071 void
prepareForDerivation()6072 TrinaryOpNode::prepareForDerivation()
6073 {
6074   if (preparedForDerivation)
6075     return;
6076 
6077   preparedForDerivation = true;
6078 
6079   arg1->prepareForDerivation();
6080   arg2->prepareForDerivation();
6081   arg3->prepareForDerivation();
6082 
6083   // Non-null derivatives are the union of those of the arguments
6084   // Compute set union of arg{1,2,3}->non_null_derivatives
6085   set<int> non_null_derivatives_tmp;
6086   set_union(arg1->non_null_derivatives.begin(),
6087             arg1->non_null_derivatives.end(),
6088             arg2->non_null_derivatives.begin(),
6089             arg2->non_null_derivatives.end(),
6090             inserter(non_null_derivatives_tmp, non_null_derivatives_tmp.begin()));
6091   set_union(non_null_derivatives_tmp.begin(),
6092             non_null_derivatives_tmp.end(),
6093             arg3->non_null_derivatives.begin(),
6094             arg3->non_null_derivatives.end(),
6095             inserter(non_null_derivatives, non_null_derivatives.begin()));
6096 }
6097 
// Chain rule for trinary operators: build the derivative of this node from the
// derivatives darg1, darg2, darg3 of its arguments arg1 = x, arg2 = mu and
// arg3 = sigma.
// NOTE(review): each datatree.Add*() call may create a new node, so the order
// of the statements below presumably determines node creation order; confirm
// before restructuring.
expr_t
TrinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3)
{

  expr_t t11, t12, t13, t14, t15;

  switch (op_code)
    {
    case TrinaryOpcode::normcdf:
      // normal pdf is inlined in the tree
      // (y has no initializer, so jumping over it to the normpdf label is legal)
      expr_t y;
      // sqrt(2*pi)
      t14 = datatree.AddSqrt(datatree.AddTimes(datatree.Two, datatree.Pi));
      // x - mu
      t12 = datatree.AddMinus(arg1, arg2);
      // y = (x-mu)/sigma
      y = datatree.AddDivide(t12, arg3);
      // (x-mu)^2/sigma^2
      t12 = datatree.AddTimes(y, y);
      // -(x-mu)^2/sigma^2
      t13 = datatree.AddUMinus(t12);
      // -((x-mu)^2/sigma^2)/2
      t12 = datatree.AddDivide(t13, datatree.Two);
      // exp(-((x-mu)^2/sigma^2)/2)
      t13 = datatree.AddExp(t12);
      // derivative of a standardized normal
      // t15 = (1/sqrt(2*pi))*exp(-y^2/2)
      t15 = datatree.AddDivide(t13, t14);
      // derivatives thru x
      t11 = datatree.AddDivide(darg1, arg3);
      // derivatives thru mu
      t12 = datatree.AddDivide(darg2, arg3);
      // intermediary sum
      t14 = datatree.AddMinus(t11, t12);
      // derivatives thru sigma
      // y/sigma = (x-mu)/sigma^2
      t11 = datatree.AddDivide(y, arg3);
      // darg3*(x-mu)/sigma^2
      t12 = datatree.AddTimes(t11, darg3);
      //intermediary sum
      t11 = datatree.AddMinus(t14, t12);
      // total derivative:
      // (darg1/sigma - darg2/sigma - darg3*(x-mu)/sigma^2) * t15
      // where t15 is the derivative of a standardized normal
      return datatree.AddTimes(t11, t15);
    case TrinaryOpcode::normpdf:
      // Assuming normpdf(x, mu, sigma) = phi(y)/sigma with y = (x-mu)/sigma,
      // and f denoting the value of this node:
      // df = (f/sigma) * (y*(dmu - dx + y*dsigma) - dsigma)
      // (x - mu)
      t11 = datatree.AddMinus(arg1, arg2);
      // (x - mu)/sigma
      t12 = datatree.AddDivide(t11, arg3);
      // darg3 * (x - mu)/sigma
      t11 = datatree.AddTimes(darg3, t12);
      // darg2 - darg1
      t13 = datatree.AddMinus(darg2, darg1);
      // darg2 - darg1 + darg3 * (x - mu)/sigma
      t14 = datatree.AddPlus(t13, t11);
      // ((x - mu)/sigma) * (darg2 - darg1 + darg3 * (x - mu)/sigma)
      t11 = datatree.AddTimes(t12, t14);
      // ((x - mu)/sigma) * (darg2 - darg1 + darg3 * (x - mu)/sigma) - darg3
      t12 = datatree.AddMinus(t11, darg3);
      // this / sigma
      t11 = datatree.AddDivide(this, arg3);
      // total derivative:
      // (this / sigma) * (((x - mu)/sigma) * (darg2 - darg1 + darg3 * (x - mu)/sigma) - darg3)
      return datatree.AddTimes(t11, t12);
    }
  // Suppress GCC warning
  // (only reached if op_code is not one of the cases handled above)
  exit(EXIT_FAILURE);
}
6165 
6166 expr_t
computeDerivative(int deriv_id)6167 TrinaryOpNode::computeDerivative(int deriv_id)
6168 {
6169   expr_t darg1 = arg1->getDerivative(deriv_id);
6170   expr_t darg2 = arg2->getDerivative(deriv_id);
6171   expr_t darg3 = arg3->getDerivative(deriv_id);
6172   return composeDerivatives(darg1, darg2, darg3);
6173 }
6174 
6175 int
precedence(ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms) const6176 TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
6177 {
6178   // A temporary term behaves as a variable
6179   if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6180     return 100;
6181 
6182   switch (op_code)
6183     {
6184     case TrinaryOpcode::normcdf:
6185     case TrinaryOpcode::normpdf:
6186       return 100;
6187     }
6188   // Suppress GCC warning
6189   exit(EXIT_FAILURE);
6190 }
6191 
6192 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const6193 TrinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
6194 {
6195   // For a temporary term, the cost is null
6196   for (const auto &it : temp_terms_map)
6197     if (it.second.find(const_cast<TrinaryOpNode *>(this)) != it.second.end())
6198       return 0;
6199 
6200   int arg_cost = arg1->cost(temp_terms_map, is_matlab)
6201     + arg2->cost(temp_terms_map, is_matlab)
6202     + arg3->cost(temp_terms_map, is_matlab);
6203 
6204   return cost(arg_cost, is_matlab);
6205 }
6206 
6207 int
cost(const temporary_terms_t & temporary_terms,bool is_matlab) const6208 TrinaryOpNode::cost(const temporary_terms_t &temporary_terms, bool is_matlab) const
6209 {
6210   // For a temporary term, the cost is null
6211   if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6212     return 0;
6213 
6214   int arg_cost = arg1->cost(temporary_terms, is_matlab)
6215     + arg2->cost(temporary_terms, is_matlab)
6216     + arg3->cost(temporary_terms, is_matlab);
6217 
6218   return cost(arg_cost, is_matlab);
6219 }
6220 
6221 int
cost(int cost,bool is_matlab) const6222 TrinaryOpNode::cost(int cost, bool is_matlab) const
6223 {
6224   if (is_matlab)
6225     // Cost for Matlab files
6226     switch (op_code)
6227       {
6228       case TrinaryOpcode::normcdf:
6229       case TrinaryOpcode::normpdf:
6230         return cost+1000;
6231       }
6232   else
6233     // Cost for C files
6234     switch (op_code)
6235       {
6236       case TrinaryOpcode::normcdf:
6237       case TrinaryOpcode::normpdf:
6238         return cost+1000;
6239       }
6240   // Suppress GCC warning
6241   exit(EXIT_FAILURE);
6242 }
6243 
6244 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const6245 TrinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
6246                                      map<pair<int, int>, temporary_terms_t> &temp_terms_map,
6247                                      map<expr_t, pair<int, pair<int, int>>> &reference_count,
6248                                      bool is_matlab) const
6249 {
6250   expr_t this2 = const_cast<TrinaryOpNode *>(this);
6251   if (auto it = reference_count.find(this2);
6252       it == reference_count.end())
6253     {
6254       // If this node has never been encountered, set its ref count to one,
6255       //  and travel through its children
6256       reference_count[this2] = { 1, derivOrder };
6257       arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
6258       arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
6259       arg3->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
6260     }
6261   else
6262     {
6263       // If the node has already been encountered, increment its ref count
6264       //  and declare it as a temporary term if it is too costly
6265       reference_count[this2] = { it->second.first + 1, it->second.second };
6266       if (reference_count[this2].first * cost(temp_terms_map, is_matlab) > min_cost(is_matlab))
6267         temp_terms_map[reference_count[this2].second].insert(this2);
6268     }
6269 }
6270 
6271 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const6272 TrinaryOpNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
6273                                      temporary_terms_t &temporary_terms,
6274                                      map<expr_t, pair<int, int>> &first_occurence,
6275                                      int Curr_block,
6276                                      vector<vector<temporary_terms_t>> &v_temporary_terms,
6277                                      int equation) const
6278 {
6279   expr_t this2 = const_cast<TrinaryOpNode *>(this);
6280   if (auto it = reference_count.find(this2);
6281       it == reference_count.end())
6282     {
6283       reference_count[this2] = 1;
6284       first_occurence[this2] = { Curr_block, equation };
6285       arg1->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
6286       arg2->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
6287       arg3->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
6288     }
6289   else
6290     {
6291       reference_count[this2]++;
6292       if (reference_count[this2] * cost(temporary_terms, false) > min_cost_c)
6293         {
6294           temporary_terms.insert(this2);
6295           v_temporary_terms[first_occurence[this2].first][first_occurence[this2].second].insert(this2);
6296         }
6297     }
6298 }
6299 
6300 double
eval_opcode(double v1,TrinaryOpcode op_code,double v2,double v3)6301 TrinaryOpNode::eval_opcode(double v1, TrinaryOpcode op_code, double v2, double v3) noexcept(false)
6302 {
6303   switch (op_code)
6304     {
6305     case TrinaryOpcode::normcdf:
6306       return (0.5*(1+erf((v1-v2)/v3/M_SQRT2)));
6307     case TrinaryOpcode::normpdf:
6308       return (1/(v3*sqrt(2*M_PI)*exp(pow((v1-v2)/v3, 2)/2)));
6309     }
6310   // Suppress GCC warning
6311   exit(EXIT_FAILURE);
6312 }
6313 
6314 double
eval(const eval_context_t & eval_context) const6315 TrinaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false)
6316 {
6317   double v1 = arg1->eval(eval_context);
6318   double v2 = arg2->eval(eval_context);
6319   double v3 = arg3->eval(eval_context);
6320 
6321   return eval_opcode(v1, op_code, v2, v3);
6322 }
6323 
6324 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const6325 TrinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
6326                        bool lhs_rhs, const temporary_terms_t &temporary_terms,
6327                        const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
6328                        const deriv_node_temp_terms_t &tef_terms) const
6329 {
6330   // If current node is a temporary term
6331   if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6332     {
6333       if (dynamic)
6334         {
6335           auto ii = map_idx.find(idx);
6336           FLDT_ fldt(ii->second);
6337           fldt.write(CompileCode, instruction_number);
6338         }
6339       else
6340         {
6341           auto ii = map_idx.find(idx);
6342           FLDST_ fldst(ii->second);
6343           fldst.write(CompileCode, instruction_number);
6344         }
6345       return;
6346     }
6347   arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
6348   arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
6349   arg3->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
6350   FTRINARY_ ftrinary{static_cast<int>(op_code)};
6351   ftrinary.write(CompileCode, instruction_number);
6352 }
6353 
6354 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const6355 TrinaryOpNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
6356 {
6357   if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6358     temporary_terms_inuse.insert(idx);
6359   else
6360     {
6361       arg1->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
6362       arg2->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
6363       arg3->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
6364     }
6365 }
6366 
6367 bool
containsExternalFunction() const6368 TrinaryOpNode::containsExternalFunction() const
6369 {
6370   return arg1->containsExternalFunction()
6371     || arg2->containsExternalFunction()
6372     || arg3->containsExternalFunction();
6373 }
6374 
6375 void
writeJsonAST(ostream & output) const6376 TrinaryOpNode::writeJsonAST(ostream &output) const
6377 {
6378   output << R"({"node_type" : "TrinaryOpNode", )"
6379          << R"("op" : ")";
6380   switch (op_code)
6381     {
6382     case TrinaryOpcode::normcdf:
6383       output << "normcdf";
6384       break;
6385     case TrinaryOpcode::normpdf:
6386       output << "normpdf";
6387       break;
6388     }
6389   output << R"(", "arg1" : )";
6390   arg1->writeJsonAST(output);
6391   output << R"(, "arg2" : )";
6392   arg2->writeJsonAST(output);
6393   output << R"(, "arg2" : )";
6394   arg3->writeJsonAST(output);
6395   output << "}";
6396 }
6397 
6398 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const6399 TrinaryOpNode::writeJsonOutput(ostream &output,
6400                                const temporary_terms_t &temporary_terms,
6401                                const deriv_node_temp_terms_t &tef_terms,
6402                                bool isdynamic) const
6403 {
6404   // If current node is a temporary term
6405   if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6406     {
6407       output << "T" << idx;
6408       return;
6409     }
6410 
6411   switch (op_code)
6412     {
6413     case TrinaryOpcode::normcdf:
6414       output << "normcdf(";
6415       break;
6416     case TrinaryOpcode::normpdf:
6417       output << "normpdf(";
6418       break;
6419     }
6420 
6421   arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
6422   output << ",";
6423   arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
6424   output << ",";
6425   arg3->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
6426   output << ")";
6427 }
6428 
6429 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const6430 TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
6431                            const temporary_terms_t &temporary_terms,
6432                            const temporary_terms_idxs_t &temporary_terms_idxs,
6433                            const deriv_node_temp_terms_t &tef_terms) const
6434 {
6435   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
6436     return;
6437 
6438   switch (op_code)
6439     {
6440     case TrinaryOpcode::normcdf:
6441       if (isCOutput(output_type))
6442         {
6443           // In C, there is no normcdf() primitive, so use erf()
6444           output << "(0.5*(1+erf(((";
6445           arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6446           output << ")-(";
6447           arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6448           output << "))/(";
6449           arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6450           output << ")/M_SQRT2)))";
6451         }
6452       else
6453         {
6454           output << "normcdf(";
6455           arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6456           output << ",";
6457           arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6458           output << ",";
6459           arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6460           output << ")";
6461         }
6462       break;
6463     case TrinaryOpcode::normpdf:
6464       if (isCOutput(output_type))
6465         {
6466           //(1/(v3*sqrt(2*M_PI)*exp(pow((v1-v2)/v3,2)/2)))
6467           output << "(1/(";
6468           arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6469           output << "*sqrt(2*M_PI)*exp(pow((";
6470           arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6471           output << "-";
6472           arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6473           output << ")/";
6474           arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6475           output << ",2)/2)))";
6476         }
6477       else
6478         {
6479           output << "normpdf(";
6480           arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6481           output << ",";
6482           arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6483           output << ",";
6484           arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6485           output << ")";
6486         }
6487       break;
6488     }
6489 }
6490 
6491 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const6492 TrinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
6493                                            const temporary_terms_t &temporary_terms,
6494                                            const temporary_terms_idxs_t &temporary_terms_idxs,
6495                                            deriv_node_temp_terms_t &tef_terms) const
6496 {
6497   arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6498   arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6499   arg3->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6500 }
6501 
6502 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const6503 TrinaryOpNode::writeJsonExternalFunctionOutput(vector<string> &efout,
6504                                                const temporary_terms_t &temporary_terms,
6505                                                deriv_node_temp_terms_t &tef_terms,
6506                                                bool isdynamic) const
6507 {
6508   arg1->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
6509   arg2->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
6510   arg3->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
6511 }
6512 
6513 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const6514 TrinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
6515                                              bool lhs_rhs, const temporary_terms_t &temporary_terms,
6516                                              const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
6517                                              deriv_node_temp_terms_t &tef_terms) const
6518 {
6519   arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
6520                                       dynamic, steady_dynamic, tef_terms);
6521   arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
6522                                       dynamic, steady_dynamic, tef_terms);
6523   arg3->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
6524                                       dynamic, steady_dynamic, tef_terms);
6525 }
6526 
6527 void
collectVARLHSVariable(set<expr_t> & result) const6528 TrinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
6529 {
6530   cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
6531   exit(EXIT_FAILURE);
6532 }
6533 
6534 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const6535 TrinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
6536 {
6537   arg1->collectDynamicVariables(type_arg, result);
6538   arg2->collectDynamicVariables(type_arg, result);
6539   arg3->collectDynamicVariables(type_arg, result);
6540 }
6541 
6542 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const6543 TrinaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
6544 {
6545   pair<int, expr_t> res = arg1->normalizeEquation(var_endo, List_of_Op_RHS);
6546   bool is_endogenous_present_1 = res.first;
6547   expr_t expr_t_1 = res.second;
6548   res = arg2->normalizeEquation(var_endo, List_of_Op_RHS);
6549   bool is_endogenous_present_2 = res.first;
6550   expr_t expr_t_2 = res.second;
6551   res = arg3->normalizeEquation(var_endo, List_of_Op_RHS);
6552   bool is_endogenous_present_3 = res.first;
6553   expr_t expr_t_3 = res.second;
6554   if (!is_endogenous_present_1 && !is_endogenous_present_2 && !is_endogenous_present_3)
6555     return { 0, datatree.AddNormcdf(expr_t_1, expr_t_2, expr_t_3) };
6556   else
6557     return { 2, nullptr };
6558 }
6559 
6560 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)6561 TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
6562 {
6563   expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
6564   expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
6565   expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
6566   return composeDerivatives(darg1, darg2, darg3);
6567 }
6568 
6569 expr_t
buildSimilarTrinaryOpNode(expr_t alt_arg1,expr_t alt_arg2,expr_t alt_arg3,DataTree & alt_datatree) const6570 TrinaryOpNode::buildSimilarTrinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, expr_t alt_arg3, DataTree &alt_datatree) const
6571 {
6572   switch (op_code)
6573     {
6574     case TrinaryOpcode::normcdf:
6575       return alt_datatree.AddNormcdf(alt_arg1, alt_arg2, alt_arg3);
6576     case TrinaryOpcode::normpdf:
6577       return alt_datatree.AddNormpdf(alt_arg1, alt_arg2, alt_arg3);
6578     }
6579   // Suppress GCC warning
6580   exit(EXIT_FAILURE);
6581 }
6582 
6583 expr_t
toStatic(DataTree & static_datatree) const6584 TrinaryOpNode::toStatic(DataTree &static_datatree) const
6585 {
6586   expr_t sarg1 = arg1->toStatic(static_datatree);
6587   expr_t sarg2 = arg2->toStatic(static_datatree);
6588   expr_t sarg3 = arg3->toStatic(static_datatree);
6589   return buildSimilarTrinaryOpNode(sarg1, sarg2, sarg3, static_datatree);
6590 }
6591 
6592 void
computeXrefs(EquationInfo & ei) const6593 TrinaryOpNode::computeXrefs(EquationInfo &ei) const
6594 {
6595   arg1->computeXrefs(ei);
6596   arg2->computeXrefs(ei);
6597   arg3->computeXrefs(ei);
6598 }
6599 
6600 expr_t
clone(DataTree & datatree) const6601 TrinaryOpNode::clone(DataTree &datatree) const
6602 {
6603   expr_t substarg1 = arg1->clone(datatree);
6604   expr_t substarg2 = arg2->clone(datatree);
6605   expr_t substarg3 = arg3->clone(datatree);
6606   return buildSimilarTrinaryOpNode(substarg1, substarg2, substarg3, datatree);
6607 }
6608 
6609 int
maxEndoLead() const6610 TrinaryOpNode::maxEndoLead() const
6611 {
6612   return max(arg1->maxEndoLead(), max(arg2->maxEndoLead(), arg3->maxEndoLead()));
6613 }
6614 
6615 int
maxExoLead() const6616 TrinaryOpNode::maxExoLead() const
6617 {
6618   return max(arg1->maxExoLead(), max(arg2->maxExoLead(), arg3->maxExoLead()));
6619 }
6620 
6621 int
maxEndoLag() const6622 TrinaryOpNode::maxEndoLag() const
6623 {
6624   return max(arg1->maxEndoLag(), max(arg2->maxEndoLag(), arg3->maxEndoLag()));
6625 }
6626 
6627 int
maxExoLag() const6628 TrinaryOpNode::maxExoLag() const
6629 {
6630   return max(arg1->maxExoLag(), max(arg2->maxExoLag(), arg3->maxExoLag()));
6631 }
6632 
6633 int
maxLead() const6634 TrinaryOpNode::maxLead() const
6635 {
6636   return max(arg1->maxLead(), max(arg2->maxLead(), arg3->maxLead()));
6637 }
6638 
6639 int
maxLag() const6640 TrinaryOpNode::maxLag() const
6641 {
6642   return max(arg1->maxLag(), max(arg2->maxLag(), arg3->maxLag()));
6643 }
6644 
6645 int
maxLagWithDiffsExpanded() const6646 TrinaryOpNode::maxLagWithDiffsExpanded() const
6647 {
6648   return max(arg1->maxLagWithDiffsExpanded(),
6649              max(arg2->maxLagWithDiffsExpanded(), arg3->maxLagWithDiffsExpanded()));
6650 }
6651 
6652 expr_t
undiff() const6653 TrinaryOpNode::undiff() const
6654 {
6655   expr_t arg1subst = arg1->undiff();
6656   expr_t arg2subst = arg2->undiff();
6657   expr_t arg3subst = arg3->undiff();
6658   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6659 }
6660 
6661 int
VarMinLag() const6662 TrinaryOpNode::VarMinLag() const
6663 {
6664   return min(min(arg1->VarMinLag(), arg2->VarMinLag()), arg3->VarMinLag());
6665 }
6666 
6667 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const6668 TrinaryOpNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
6669 {
6670   return max(arg1->VarMaxLag(lhs_lag_equiv),
6671              max(arg2->VarMaxLag(lhs_lag_equiv),
6672                  arg3->VarMaxLag(lhs_lag_equiv)));
6673 }
6674 
6675 int
PacMaxLag(int lhs_symb_id) const6676 TrinaryOpNode::PacMaxLag(int lhs_symb_id) const
6677 {
6678   return max(arg1->PacMaxLag(lhs_symb_id), max(arg2->PacMaxLag(lhs_symb_id), arg3->PacMaxLag(lhs_symb_id)));
6679 }
6680 
6681 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const6682 TrinaryOpNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
6683 {
6684   int symb_id = arg1->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
6685   if (symb_id >= 0)
6686     return symb_id;
6687 
6688   symb_id = arg2->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
6689   if (symb_id >= 0)
6690     return symb_id;
6691 
6692   return arg3->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
6693 }
6694 
6695 expr_t
decreaseLeadsLags(int n) const6696 TrinaryOpNode::decreaseLeadsLags(int n) const
6697 {
6698   expr_t arg1subst = arg1->decreaseLeadsLags(n);
6699   expr_t arg2subst = arg2->decreaseLeadsLags(n);
6700   expr_t arg3subst = arg3->decreaseLeadsLags(n);
6701   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6702 }
6703 
6704 expr_t
decreaseLeadsLagsPredeterminedVariables() const6705 TrinaryOpNode::decreaseLeadsLagsPredeterminedVariables() const
6706 {
6707   expr_t arg1subst = arg1->decreaseLeadsLagsPredeterminedVariables();
6708   expr_t arg2subst = arg2->decreaseLeadsLagsPredeterminedVariables();
6709   expr_t arg3subst = arg3->decreaseLeadsLagsPredeterminedVariables();
6710   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6711 }
6712 
6713 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const6714 TrinaryOpNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
6715 {
6716   if (maxEndoLead() < 2)
6717     return const_cast<TrinaryOpNode *>(this);
6718   else if (deterministic_model)
6719     {
6720       expr_t arg1subst = arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
6721       expr_t arg2subst = arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
6722       expr_t arg3subst = arg3->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
6723       return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6724     }
6725   else
6726     return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
6727 }
6728 
6729 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6730 TrinaryOpNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6731 {
6732   expr_t arg1subst = arg1->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
6733   expr_t arg2subst = arg2->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
6734   expr_t arg3subst = arg3->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
6735   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6736 }
6737 
6738 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const6739 TrinaryOpNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
6740 {
6741   if (maxExoLead() == 0)
6742     return const_cast<TrinaryOpNode *>(this);
6743   else if (deterministic_model)
6744     {
6745       expr_t arg1subst = arg1->substituteExoLead(subst_table, neweqs, deterministic_model);
6746       expr_t arg2subst = arg2->substituteExoLead(subst_table, neweqs, deterministic_model);
6747       expr_t arg3subst = arg3->substituteExoLead(subst_table, neweqs, deterministic_model);
6748       return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6749     }
6750   else
6751     return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
6752 }
6753 
6754 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6755 TrinaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6756 {
6757   expr_t arg1subst = arg1->substituteExoLag(subst_table, neweqs);
6758   expr_t arg2subst = arg2->substituteExoLag(subst_table, neweqs);
6759   expr_t arg3subst = arg3->substituteExoLag(subst_table, neweqs);
6760   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6761 }
6762 
6763 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const6764 TrinaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
6765 {
6766   expr_t arg1subst = arg1->substituteExpectation(subst_table, neweqs, partial_information_model);
6767   expr_t arg2subst = arg2->substituteExpectation(subst_table, neweqs, partial_information_model);
6768   expr_t arg3subst = arg3->substituteExpectation(subst_table, neweqs, partial_information_model);
6769   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6770 }
6771 
6772 expr_t
substituteAdl() const6773 TrinaryOpNode::substituteAdl() const
6774 {
6775   expr_t arg1subst = arg1->substituteAdl();
6776   expr_t arg2subst = arg2->substituteAdl();
6777   expr_t arg3subst = arg3->substituteAdl();
6778   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6779 }
6780 
6781 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const6782 TrinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
6783 {
6784   expr_t arg1subst = arg1->substituteVarExpectation(subst_table);
6785   expr_t arg2subst = arg2->substituteVarExpectation(subst_table);
6786   expr_t arg3subst = arg3->substituteVarExpectation(subst_table);
6787   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6788 }
6789 
6790 void
findDiffNodes(lag_equivalence_table_t & nodes) const6791 TrinaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
6792 {
6793   arg1->findDiffNodes(nodes);
6794   arg2->findDiffNodes(nodes);
6795   arg3->findDiffNodes(nodes);
6796 }
6797 
6798 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const6799 TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
6800 {
6801   arg1->findUnaryOpNodesForAuxVarCreation(nodes);
6802   arg2->findUnaryOpNodesForAuxVarCreation(nodes);
6803   arg3->findUnaryOpNodesForAuxVarCreation(nodes);
6804 }
6805 
6806 int
findTargetVariable(int lhs_symb_id) const6807 TrinaryOpNode::findTargetVariable(int lhs_symb_id) const
6808 {
6809   int retval = arg1->findTargetVariable(lhs_symb_id);
6810   if (retval < 0)
6811     retval = arg2->findTargetVariable(lhs_symb_id);
6812   if (retval < 0)
6813     retval = arg3->findTargetVariable(lhs_symb_id);
6814   return retval;
6815 }
6816 
6817 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6818 TrinaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
6819                               vector<BinaryOpNode *> &neweqs) const
6820 {
6821   expr_t arg1subst = arg1->substituteDiff(nodes, subst_table, neweqs);
6822   expr_t arg2subst = arg2->substituteDiff(nodes, subst_table, neweqs);
6823   expr_t arg3subst = arg3->substituteDiff(nodes, subst_table, neweqs);
6824   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6825 }
6826 
6827 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6828 TrinaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6829 {
6830   expr_t arg1subst = arg1->substituteUnaryOpNodes(nodes, subst_table, neweqs);
6831   expr_t arg2subst = arg2->substituteUnaryOpNodes(nodes, subst_table, neweqs);
6832   expr_t arg3subst = arg3->substituteUnaryOpNodes(nodes, subst_table, neweqs);
6833   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6834 }
6835 
6836 int
countDiffs() const6837 TrinaryOpNode::countDiffs() const
6838 {
6839   return max(arg1->countDiffs(), max(arg2->countDiffs(), arg3->countDiffs()));
6840 }
6841 
6842 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)6843 TrinaryOpNode::substitutePacExpectation(const string &name, expr_t subexpr)
6844 {
6845   expr_t arg1subst = arg1->substitutePacExpectation(name, subexpr);
6846   expr_t arg2subst = arg2->substitutePacExpectation(name, subexpr);
6847   expr_t arg3subst = arg3->substitutePacExpectation(name, subexpr);
6848   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6849 }
6850 
6851 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6852 TrinaryOpNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6853 {
6854   expr_t arg1subst = arg1->differentiateForwardVars(subset, subst_table, neweqs);
6855   expr_t arg2subst = arg2->differentiateForwardVars(subset, subst_table, neweqs);
6856   expr_t arg3subst = arg3->differentiateForwardVars(subset, subst_table, neweqs);
6857   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6858 }
6859 
6860 bool
isNumConstNodeEqualTo(double value) const6861 TrinaryOpNode::isNumConstNodeEqualTo(double value) const
6862 {
6863   return false;
6864 }
6865 
6866 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const6867 TrinaryOpNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
6868 {
6869   return false;
6870 }
6871 
6872 bool
containsPacExpectation(const string & pac_model_name) const6873 TrinaryOpNode::containsPacExpectation(const string &pac_model_name) const
6874 {
6875   return (arg1->containsPacExpectation(pac_model_name) || arg2->containsPacExpectation(pac_model_name) || arg3->containsPacExpectation(pac_model_name));
6876 }
6877 
6878 bool
containsEndogenous() const6879 TrinaryOpNode::containsEndogenous() const
6880 {
6881   return (arg1->containsEndogenous() || arg2->containsEndogenous() || arg3->containsEndogenous());
6882 }
6883 
6884 bool
containsExogenous() const6885 TrinaryOpNode::containsExogenous() const
6886 {
6887   return (arg1->containsExogenous() || arg2->containsExogenous() || arg3->containsExogenous());
6888 }
6889 
6890 expr_t
replaceTrendVar() const6891 TrinaryOpNode::replaceTrendVar() const
6892 {
6893   expr_t arg1subst = arg1->replaceTrendVar();
6894   expr_t arg2subst = arg2->replaceTrendVar();
6895   expr_t arg3subst = arg3->replaceTrendVar();
6896   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6897 }
6898 
6899 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const6900 TrinaryOpNode::detrend(int symb_id, bool log_trend, expr_t trend) const
6901 {
6902   expr_t arg1subst = arg1->detrend(symb_id, log_trend, trend);
6903   expr_t arg2subst = arg2->detrend(symb_id, log_trend, trend);
6904   expr_t arg3subst = arg3->detrend(symb_id, log_trend, trend);
6905   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6906 }
6907 
6908 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const6909 TrinaryOpNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
6910 {
6911   expr_t arg1subst = arg1->removeTrendLeadLag(trend_symbols_map);
6912   expr_t arg2subst = arg2->removeTrendLeadLag(trend_symbols_map);
6913   expr_t arg3subst = arg3->removeTrendLeadLag(trend_symbols_map);
6914   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6915 }
6916 
6917 bool
isInStaticForm() const6918 TrinaryOpNode::isInStaticForm() const
6919 {
6920   return arg1->isInStaticForm() && arg2->isInStaticForm() && arg3->isInStaticForm();
6921 }
6922 
6923 bool
isParamTimesEndogExpr() const6924 TrinaryOpNode::isParamTimesEndogExpr() const
6925 {
6926   return arg1->isParamTimesEndogExpr()
6927     || arg2->isParamTimesEndogExpr()
6928     || arg3->isParamTimesEndogExpr();
6929 }
6930 
6931 bool
isVarModelReferenced(const string & model_info_name) const6932 TrinaryOpNode::isVarModelReferenced(const string &model_info_name) const
6933 {
6934   return arg1->isVarModelReferenced(model_info_name)
6935     || arg2->isVarModelReferenced(model_info_name)
6936     || arg3->isVarModelReferenced(model_info_name);
6937 }
6938 
6939 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const6940 TrinaryOpNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
6941 {
6942   arg1->getEndosAndMaxLags(model_endos_and_lags);
6943   arg2->getEndosAndMaxLags(model_endos_and_lags);
6944   arg3->getEndosAndMaxLags(model_endos_and_lags);
6945 }
6946 
6947 expr_t
substituteStaticAuxiliaryVariable() const6948 TrinaryOpNode::substituteStaticAuxiliaryVariable() const
6949 {
6950   expr_t arg1subst = arg1->substituteStaticAuxiliaryVariable();
6951   expr_t arg2subst = arg2->substituteStaticAuxiliaryVariable();
6952   expr_t arg3subst = arg3->substituteStaticAuxiliaryVariable();
6953   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6954 }
6955 
6956 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const6957 TrinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
6958 {
6959   arg1->findConstantEquations(table);
6960   arg2->findConstantEquations(table);
6961   arg3->findConstantEquations(table);
6962 }
6963 
6964 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const6965 TrinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
6966 {
6967   expr_t arg1subst = arg1->replaceVarsInEquation(table);
6968   expr_t arg2subst = arg2->replaceVarsInEquation(table);
6969   expr_t arg3subst = arg3->replaceVarsInEquation(table);
6970   return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6971 }
6972 
AbstractExternalFunctionNode(DataTree & datatree_arg,int idx_arg,int symb_id_arg,vector<expr_t> arguments_arg)6973 AbstractExternalFunctionNode::AbstractExternalFunctionNode(DataTree &datatree_arg,
6974                                                            int idx_arg,
6975                                                            int symb_id_arg,
6976                                                            vector<expr_t> arguments_arg) :
6977   ExprNode{datatree_arg, idx_arg},
6978   symb_id{symb_id_arg},
6979   arguments{move(arguments_arg)}
6980 {
6981 }
6982 
6983 void
prepareForDerivation()6984 AbstractExternalFunctionNode::prepareForDerivation()
6985 {
6986   if (preparedForDerivation)
6987     return;
6988 
6989   for (auto argument : arguments)
6990     argument->prepareForDerivation();
6991 
6992   non_null_derivatives = arguments.at(0)->non_null_derivatives;
6993   for (int i = 1; i < static_cast<int>(arguments.size()); i++)
6994     set_union(non_null_derivatives.begin(),
6995               non_null_derivatives.end(),
6996               arguments.at(i)->non_null_derivatives.begin(),
6997               arguments.at(i)->non_null_derivatives.end(),
6998               inserter(non_null_derivatives, non_null_derivatives.begin()));
6999 
7000   preparedForDerivation = true;
7001 }
7002 
7003 expr_t
computeDerivative(int deriv_id)7004 AbstractExternalFunctionNode::computeDerivative(int deriv_id)
7005 {
7006   assert(datatree.external_functions_table.getNargs(symb_id) > 0);
7007   vector<expr_t> dargs;
7008   for (auto argument : arguments)
7009     dargs.push_back(argument->getDerivative(deriv_id));
7010   return composeDerivatives(dargs);
7011 }
7012 
7013 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)7014 AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
7015 {
7016   assert(datatree.external_functions_table.getNargs(symb_id) > 0);
7017   vector<expr_t> dargs;
7018   for (auto argument : arguments)
7019     dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables));
7020   return composeDerivatives(dargs);
7021 }
7022 
7023 unsigned int
compileExternalFunctionArguments(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const7024 AbstractExternalFunctionNode::compileExternalFunctionArguments(ostream &CompileCode, unsigned int &instruction_number,
7025                                                                bool lhs_rhs, const temporary_terms_t &temporary_terms,
7026                                                                const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
7027                                                                const deriv_node_temp_terms_t &tef_terms) const
7028 {
7029   for (auto argument : arguments)
7030     argument->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
7031                       dynamic, steady_dynamic, tef_terms);
7032   return (arguments.size());
7033 }
7034 
7035 void
collectVARLHSVariable(set<expr_t> & result) const7036 AbstractExternalFunctionNode::collectVARLHSVariable(set<expr_t> &result) const
7037 {
7038   cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
7039   exit(EXIT_FAILURE);
7040 }
7041 
7042 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const7043 AbstractExternalFunctionNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
7044 {
7045   for (auto argument : arguments)
7046     argument->collectDynamicVariables(type_arg, result);
7047 }
7048 
7049 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const7050 AbstractExternalFunctionNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
7051 {
7052   if (temporary_terms.find(const_cast<AbstractExternalFunctionNode *>(this)) != temporary_terms.end())
7053     temporary_terms_inuse.insert(idx);
7054   else
7055     for (auto argument : arguments)
7056       argument->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
7057 }
7058 
7059 double
eval(const eval_context_t & eval_context) const7060 AbstractExternalFunctionNode::eval(const eval_context_t &eval_context) const noexcept(false)
7061 {
7062   throw EvalExternalFunctionException();
7063 }
7064 
7065 int
maxEndoLead() const7066 AbstractExternalFunctionNode::maxEndoLead() const
7067 {
7068   int val = 0;
7069   for (auto argument : arguments)
7070     val = max(val, argument->maxEndoLead());
7071   return val;
7072 }
7073 
7074 int
maxExoLead() const7075 AbstractExternalFunctionNode::maxExoLead() const
7076 {
7077   int val = 0;
7078   for (auto argument : arguments)
7079     val = max(val, argument->maxExoLead());
7080   return val;
7081 }
7082 
7083 int
maxEndoLag() const7084 AbstractExternalFunctionNode::maxEndoLag() const
7085 {
7086   int val = 0;
7087   for (auto argument : arguments)
7088     val = max(val, argument->maxEndoLag());
7089   return val;
7090 }
7091 
7092 int
maxExoLag() const7093 AbstractExternalFunctionNode::maxExoLag() const
7094 {
7095   int val = 0;
7096   for (auto argument : arguments)
7097     val = max(val, argument->maxExoLag());
7098   return val;
7099 }
7100 
7101 int
maxLead() const7102 AbstractExternalFunctionNode::maxLead() const
7103 {
7104   int val = 0;
7105   for (auto argument : arguments)
7106     val = max(val, argument->maxLead());
7107   return val;
7108 }
7109 
7110 int
maxLag() const7111 AbstractExternalFunctionNode::maxLag() const
7112 {
7113   int val = 0;
7114   for (auto argument : arguments)
7115     val = max(val, argument->maxLag());
7116   return val;
7117 }
7118 
7119 int
maxLagWithDiffsExpanded() const7120 AbstractExternalFunctionNode::maxLagWithDiffsExpanded() const
7121 {
7122   int val = 0;
7123   for (auto argument : arguments)
7124     val = max(val, argument->maxLagWithDiffsExpanded());
7125   return val;
7126 }
7127 
7128 expr_t
undiff() const7129 AbstractExternalFunctionNode::undiff() const
7130 {
7131   vector<expr_t> arguments_subst;
7132   for (auto argument : arguments)
7133     arguments_subst.push_back(argument->undiff());
7134   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7135 }
7136 
7137 int
VarMinLag() const7138 AbstractExternalFunctionNode::VarMinLag() const
7139 {
7140   int val = 0;
7141   for (auto argument : arguments)
7142     val = min(val, argument->VarMinLag());
7143   return val;
7144 }
7145 
7146 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const7147 AbstractExternalFunctionNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
7148 {
7149   int max_lag = 0;
7150   for (auto argument : arguments)
7151     max_lag = max(max_lag, argument->VarMaxLag(lhs_lag_equiv));
7152   return max_lag;
7153 }
7154 
7155 int
PacMaxLag(int lhs_symb_id) const7156 AbstractExternalFunctionNode::PacMaxLag(int lhs_symb_id) const
7157 {
7158   int val = 0;
7159   for (auto argument : arguments)
7160     val = max(val, argument->PacMaxLag(lhs_symb_id));
7161   return val;
7162 }
7163 
7164 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const7165 AbstractExternalFunctionNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
7166 {
7167   int symb_id = -1;
7168   for (auto argument : arguments)
7169     {
7170       symb_id = argument->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
7171       if (symb_id >= 0)
7172         return symb_id;
7173     }
7174   return -1;
7175 }
7176 
7177 expr_t
decreaseLeadsLags(int n) const7178 AbstractExternalFunctionNode::decreaseLeadsLags(int n) const
7179 {
7180   vector<expr_t> arguments_subst;
7181   for (auto argument : arguments)
7182     arguments_subst.push_back(argument->decreaseLeadsLags(n));
7183   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7184 }
7185 
7186 expr_t
decreaseLeadsLagsPredeterminedVariables() const7187 AbstractExternalFunctionNode::decreaseLeadsLagsPredeterminedVariables() const
7188 {
7189   vector<expr_t> arguments_subst;
7190   for (auto argument : arguments)
7191     arguments_subst.push_back(argument->decreaseLeadsLagsPredeterminedVariables());
7192   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7193 }
7194 
7195 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const7196 AbstractExternalFunctionNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
7197 {
7198   vector<expr_t> arguments_subst;
7199   for (auto argument : arguments)
7200     arguments_subst.push_back(argument->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model));
7201   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7202 }
7203 
7204 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7205 AbstractExternalFunctionNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7206 {
7207   vector<expr_t> arguments_subst;
7208   for (auto argument : arguments)
7209     arguments_subst.push_back(argument->substituteEndoLagGreaterThanTwo(subst_table, neweqs));
7210   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7211 }
7212 
7213 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const7214 AbstractExternalFunctionNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
7215 {
7216   vector<expr_t> arguments_subst;
7217   for (auto argument : arguments)
7218     arguments_subst.push_back(argument->substituteExoLead(subst_table, neweqs, deterministic_model));
7219   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7220 }
7221 
7222 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7223 AbstractExternalFunctionNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7224 {
7225   vector<expr_t> arguments_subst;
7226   for (auto argument : arguments)
7227     arguments_subst.push_back(argument->substituteExoLag(subst_table, neweqs));
7228   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7229 }
7230 
7231 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const7232 AbstractExternalFunctionNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
7233 {
7234   vector<expr_t> arguments_subst;
7235   for (auto argument : arguments)
7236     arguments_subst.push_back(argument->substituteExpectation(subst_table, neweqs, partial_information_model));
7237   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7238 }
7239 
7240 expr_t
substituteAdl() const7241 AbstractExternalFunctionNode::substituteAdl() const
7242 {
7243   vector<expr_t> arguments_subst;
7244   for (auto argument : arguments)
7245     arguments_subst.push_back(argument->substituteAdl());
7246   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7247 }
7248 
7249 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const7250 AbstractExternalFunctionNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
7251 {
7252   vector<expr_t> arguments_subst;
7253   for (auto argument : arguments)
7254     arguments_subst.push_back(argument->substituteVarExpectation(subst_table));
7255   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7256 }
7257 
7258 void
findDiffNodes(lag_equivalence_table_t & nodes) const7259 AbstractExternalFunctionNode::findDiffNodes(lag_equivalence_table_t &nodes) const
7260 {
7261   for (auto argument : arguments)
7262     argument->findDiffNodes(nodes);
7263 }
7264 
7265 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const7266 AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
7267 {
7268   for (auto argument : arguments)
7269     argument->findUnaryOpNodesForAuxVarCreation(nodes);
7270 }
7271 
7272 int
findTargetVariable(int lhs_symb_id) const7273 AbstractExternalFunctionNode::findTargetVariable(int lhs_symb_id) const
7274 {
7275   for (auto argument : arguments)
7276     if (int retval = argument->findTargetVariable(lhs_symb_id);
7277         retval >= 0)
7278       return retval;
7279   return -1;
7280 }
7281 
7282 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7283 AbstractExternalFunctionNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
7284                                              vector<BinaryOpNode *> &neweqs) const
7285 {
7286   vector<expr_t> arguments_subst;
7287   for (auto argument : arguments)
7288     arguments_subst.push_back(argument->substituteDiff(nodes, subst_table, neweqs));
7289   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7290 }
7291 
7292 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7293 AbstractExternalFunctionNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7294 {
7295   vector<expr_t> arguments_subst;
7296   for (auto argument : arguments)
7297     arguments_subst.push_back(argument->substituteUnaryOpNodes(nodes, subst_table, neweqs));
7298   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7299 }
7300 
7301 int
countDiffs() const7302 AbstractExternalFunctionNode::countDiffs() const
7303 {
7304   int ndiffs = 0;
7305   for (auto argument : arguments)
7306     ndiffs = max(ndiffs, argument->countDiffs());
7307   return ndiffs;
7308 }
7309 
7310 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)7311 AbstractExternalFunctionNode::substitutePacExpectation(const string &name, expr_t subexpr)
7312 {
7313   vector<expr_t> arguments_subst;
7314   for (auto argument : arguments)
7315     arguments_subst.push_back(argument->substitutePacExpectation(name, subexpr));
7316   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7317 }
7318 
7319 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7320 AbstractExternalFunctionNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7321 {
7322   vector<expr_t> arguments_subst;
7323   for (auto argument : arguments)
7324     arguments_subst.push_back(argument->differentiateForwardVars(subset, subst_table, neweqs));
7325   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7326 }
7327 
7328 bool
alreadyWrittenAsTefTerm(int the_symb_id,const deriv_node_temp_terms_t & tef_terms) const7329 AbstractExternalFunctionNode::alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const
7330 {
7331   if (tef_terms.find({ the_symb_id, arguments }) != tef_terms.end())
7332     return true;
7333   return false;
7334 }
7335 
7336 int
getIndxInTefTerms(int the_symb_id,const deriv_node_temp_terms_t & tef_terms) const7337 AbstractExternalFunctionNode::getIndxInTefTerms(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const noexcept(false)
7338 {
7339   if (auto it = tef_terms.find({ the_symb_id, arguments });
7340       it != tef_terms.end())
7341     return it->second;
7342   throw UnknownFunctionNameAndArgs();
7343 }
7344 
7345 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const7346 AbstractExternalFunctionNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
7347                                                     map<pair<int, int>, temporary_terms_t> &temp_terms_map,
7348                                                     map<expr_t, pair<int, pair<int, int>>> &reference_count,
7349                                                     bool is_matlab) const
7350 {
7351   /* All external function nodes are declared as temporary terms.
7352 
7353      Given that temporary terms are separated in several functions (residuals,
7354      jacobian, …), we must make sure that all temporary terms derived from a
7355      given external function call are assigned just after that call.
7356 
7357      As a consequence, we need to “promote” some terms to a previous level (in
7358      the sense that residuals come before jacobian), if a temporary term
7359      corresponding to the same external function call is present in that
7360      previous level. */
7361 
7362   for (auto &tt : temp_terms_map)
7363     if (auto it = find_if(tt.second.cbegin(), tt.second.cend(), sameTefTermPredicate());
7364         it != tt.second.cend())
7365       {
7366         tt.second.insert(const_cast<AbstractExternalFunctionNode *>(this));
7367         return;
7368       }
7369 
7370   temp_terms_map[derivOrder].insert(const_cast<AbstractExternalFunctionNode *>(this));
7371 }
7372 
7373 bool
isNumConstNodeEqualTo(double value) const7374 AbstractExternalFunctionNode::isNumConstNodeEqualTo(double value) const
7375 {
7376   return false;
7377 }
7378 
7379 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const7380 AbstractExternalFunctionNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
7381 {
7382   return false;
7383 }
7384 
7385 bool
containsPacExpectation(const string & pac_model_name) const7386 AbstractExternalFunctionNode::containsPacExpectation(const string &pac_model_name) const
7387 {
7388   for (auto argument : arguments)
7389     if (argument->containsPacExpectation(pac_model_name))
7390       return true;
7391   return false;
7392 }
7393 
7394 bool
containsEndogenous() const7395 AbstractExternalFunctionNode::containsEndogenous() const
7396 {
7397   for (auto argument : arguments)
7398     if (argument->containsEndogenous())
7399       return true;
7400   return false;
7401 }
7402 
7403 bool
containsExogenous() const7404 AbstractExternalFunctionNode::containsExogenous() const
7405 {
7406   for (auto argument : arguments)
7407     if (argument->containsExogenous())
7408       return true;
7409   return false;
7410 }
7411 
7412 expr_t
replaceTrendVar() const7413 AbstractExternalFunctionNode::replaceTrendVar() const
7414 {
7415   vector<expr_t> arguments_subst;
7416   for (auto argument : arguments)
7417     arguments_subst.push_back(argument->replaceTrendVar());
7418   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7419 }
7420 
7421 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const7422 AbstractExternalFunctionNode::detrend(int symb_id, bool log_trend, expr_t trend) const
7423 {
7424   vector<expr_t> arguments_subst;
7425   for (auto argument : arguments)
7426     arguments_subst.push_back(argument->detrend(symb_id, log_trend, trend));
7427   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7428 }
7429 
7430 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const7431 AbstractExternalFunctionNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
7432 {
7433   vector<expr_t> arguments_subst;
7434   for (auto argument : arguments)
7435     arguments_subst.push_back(argument->removeTrendLeadLag(trend_symbols_map));
7436   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7437 }
7438 
7439 bool
isInStaticForm() const7440 AbstractExternalFunctionNode::isInStaticForm() const
7441 {
7442   for (auto argument : arguments)
7443     if (!argument->isInStaticForm())
7444       return false;
7445   return true;
7446 }
7447 
7448 bool
isParamTimesEndogExpr() const7449 AbstractExternalFunctionNode::isParamTimesEndogExpr() const
7450 {
7451   return false;
7452 }
7453 
7454 bool
isVarModelReferenced(const string & model_info_name) const7455 AbstractExternalFunctionNode::isVarModelReferenced(const string &model_info_name) const
7456 {
7457   for (auto argument : arguments)
7458     if (!argument->isVarModelReferenced(model_info_name))
7459       return true;
7460   return false;
7461 }
7462 
7463 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const7464 AbstractExternalFunctionNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
7465 {
7466   for (auto argument : arguments)
7467     argument->getEndosAndMaxLags(model_endos_and_lags);
7468 }
7469 
7470 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const7471 AbstractExternalFunctionNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
7472 {
7473   vector<pair<bool, expr_t>> V_arguments;
7474   vector<expr_t> V_expr_t;
7475   bool present = false;
7476   for (auto argument : arguments)
7477     {
7478       V_arguments.emplace_back(argument->normalizeEquation(var_endo, List_of_Op_RHS));
7479       present = present || V_arguments[V_arguments.size()-1].first;
7480       V_expr_t.push_back(V_arguments[V_arguments.size()-1].second);
7481     }
7482   if (!present)
7483     return { 0, datatree.AddExternalFunction(symb_id, V_expr_t) };
7484   else
7485     return { 2, nullptr };
7486 }
7487 
7488 void
writeExternalFunctionArguments(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const7489 AbstractExternalFunctionNode::writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type,
7490                                                              const temporary_terms_t &temporary_terms,
7491                                                              const temporary_terms_idxs_t &temporary_terms_idxs,
7492                                                              const deriv_node_temp_terms_t &tef_terms) const
7493 {
7494   for (auto it = arguments.begin(); it != arguments.end(); ++it)
7495     {
7496       if (it != arguments.begin())
7497         output << ",";
7498 
7499       (*it)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7500     }
7501 }
7502 
7503 void
writeJsonASTExternalFunctionArguments(ostream & output) const7504 AbstractExternalFunctionNode::writeJsonASTExternalFunctionArguments(ostream &output) const
7505 {
7506   int i = 0;
7507   output << "{";
7508   for (auto it = arguments.begin(); it != arguments.end(); ++it, i++)
7509     {
7510       if (it != arguments.begin())
7511         output << ",";
7512 
7513       output << R"("arg)" << i << R"(" : )";
7514       (*it)->writeJsonAST(output);
7515     }
7516   output << "}";
7517 }
7518 
7519 void
writeJsonExternalFunctionArguments(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7520 AbstractExternalFunctionNode::writeJsonExternalFunctionArguments(ostream &output,
7521                                                                  const temporary_terms_t &temporary_terms,
7522                                                                  const deriv_node_temp_terms_t &tef_terms,
7523                                                                  bool isdynamic) const
7524 {
7525   for (auto it = arguments.begin(); it != arguments.end(); ++it)
7526     {
7527       if (it != arguments.begin())
7528         output << ",";
7529 
7530       (*it)->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
7531     }
7532 }
7533 
7534 void
writePrhs(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms,const string & ending) const7535 AbstractExternalFunctionNode::writePrhs(ostream &output, ExprNodeOutputType output_type,
7536                                         const temporary_terms_t &temporary_terms,
7537                                         const temporary_terms_idxs_t &temporary_terms_idxs,
7538                                         const deriv_node_temp_terms_t &tef_terms, const string &ending) const
7539 {
7540   output << "mxArray *prhs"<< ending << "[nrhs"<< ending << "];" << endl;
7541   int i = 0;
7542   for (auto argument : arguments)
7543     {
7544       output << "prhs" << ending << "[" << i++ << "] = mxCreateDoubleScalar("; // All external_function arguments are scalars
7545       argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7546       output << ");" << endl;
7547     }
7548 }
7549 
7550 bool
containsExternalFunction() const7551 AbstractExternalFunctionNode::containsExternalFunction() const
7552 {
7553   return true;
7554 }
7555 
7556 expr_t
substituteStaticAuxiliaryVariable() const7557 AbstractExternalFunctionNode::substituteStaticAuxiliaryVariable() const
7558 {
7559   vector<expr_t> arguments_subst;
7560   for (auto argument : arguments)
7561     arguments_subst.push_back(argument->substituteStaticAuxiliaryVariable());
7562   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7563 }
7564 
7565 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const7566 AbstractExternalFunctionNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
7567 {
7568   for (auto argument : arguments)
7569     argument->findConstantEquations(table);
7570 }
7571 
7572 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const7573 AbstractExternalFunctionNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
7574 {
7575   vector<expr_t> arguments_subst;
7576   for (auto argument : arguments)
7577     arguments_subst.push_back(argument->replaceVarsInEquation(table));
7578   return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7579 }
7580 
ExternalFunctionNode(DataTree & datatree_arg,int idx_arg,int symb_id_arg,const vector<expr_t> & arguments_arg)7581 ExternalFunctionNode::ExternalFunctionNode(DataTree &datatree_arg,
7582                                            int idx_arg,
7583                                            int symb_id_arg,
7584                                            const vector<expr_t> &arguments_arg) :
7585   AbstractExternalFunctionNode{datatree_arg, idx_arg, symb_id_arg, arguments_arg}
7586 {
7587 }
7588 
7589 expr_t
composeDerivatives(const vector<expr_t> & dargs)7590 ExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
7591 {
7592   vector<expr_t> dNodes;
7593   for (int i = 0; i < static_cast<int>(dargs.size()); i++)
7594     dNodes.push_back(datatree.AddTimes(dargs.at(i),
7595                                        datatree.AddFirstDerivExternalFunction(symb_id, arguments, i+1)));
7596 
7597   expr_t theDeriv = datatree.Zero;
7598   for (auto &dNode : dNodes)
7599     theDeriv = datatree.AddPlus(theDeriv, dNode);
7600   return theDeriv;
7601 }
7602 
7603 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const7604 ExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
7605                                             temporary_terms_t &temporary_terms,
7606                                             map<expr_t, pair<int, int>> &first_occurence,
7607                                             int Curr_block,
7608                                             vector< vector<temporary_terms_t>> &v_temporary_terms,
7609                                             int equation) const
7610 {
7611   expr_t this2 = const_cast<ExternalFunctionNode *>(this);
7612   temporary_terms.insert(this2);
7613   first_occurence[this2] = { Curr_block, equation };
7614   v_temporary_terms[Curr_block][equation].insert(this2);
7615 }
7616 
7617 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const7618 ExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number,
7619                               bool lhs_rhs, const temporary_terms_t &temporary_terms,
7620                               const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
7621                               const deriv_node_temp_terms_t &tef_terms) const
7622 {
7623   if (temporary_terms.find(const_cast<ExternalFunctionNode *>(this)) != temporary_terms.end())
7624     {
7625       if (dynamic)
7626         {
7627           auto ii = map_idx.find(idx);
7628           FLDT_ fldt(ii->second);
7629           fldt.write(CompileCode, instruction_number);
7630         }
7631       else
7632         {
7633           auto ii = map_idx.find(idx);
7634           FLDST_ fldst(ii->second);
7635           fldst.write(CompileCode, instruction_number);
7636         }
7637       return;
7638     }
7639 
7640   if (!lhs_rhs)
7641     {
7642       FLDTEF_ fldtef(getIndxInTefTerms(symb_id, tef_terms));
7643       fldtef.write(CompileCode, instruction_number);
7644     }
7645   else
7646     {
7647       FSTPTEF_ fstptef(getIndxInTefTerms(symb_id, tef_terms));
7648       fstptef.write(CompileCode, instruction_number);
7649     }
7650 }
7651 
7652 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const7653 ExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
7654                                                     bool lhs_rhs, const temporary_terms_t &temporary_terms,
7655                                                     const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
7656                                                     deriv_node_temp_terms_t &tef_terms) const
7657 {
7658   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7659   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7660 
7661   for (auto argument : arguments)
7662     argument->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms,
7663                                             map_idx, dynamic, steady_dynamic, tef_terms);
7664 
7665   if (!alreadyWrittenAsTefTerm(symb_id, tef_terms))
7666     {
7667       tef_terms[{ symb_id, arguments }] = static_cast<int>(tef_terms.size());
7668       int indx = getIndxInTefTerms(symb_id, tef_terms);
7669       int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
7670       assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7671 
7672       unsigned int nb_output_arguments = 0;
7673       if (symb_id == first_deriv_symb_id
7674           && symb_id == second_deriv_symb_id)
7675         nb_output_arguments = 3;
7676       else if (symb_id == first_deriv_symb_id)
7677         nb_output_arguments = 2;
7678       else
7679         nb_output_arguments = 1;
7680       unsigned int nb_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms,
7681                                                                          map_idx, dynamic, steady_dynamic, tef_terms);
7682 
7683       FCALL_ fcall(nb_output_arguments, nb_input_arguments, datatree.symbol_table.getName(symb_id), indx);
7684       switch (nb_output_arguments)
7685         {
7686         case 1:
7687           fcall.set_function_type(ExternalFunctionType::withoutDerivative);
7688           break;
7689         case 2:
7690           fcall.set_function_type(ExternalFunctionType::withFirstDerivative);
7691           break;
7692         case 3:
7693           fcall.set_function_type(ExternalFunctionType::withFirstAndSecondDerivative);
7694           break;
7695         }
7696       fcall.write(CompileCode, instruction_number);
7697       FSTPTEF_ fstptef(indx);
7698       fstptef.write(CompileCode, instruction_number);
7699     }
7700 }
7701 
7702 void
writeJsonAST(ostream & output) const7703 ExternalFunctionNode::writeJsonAST(ostream &output) const
7704 {
7705   output << R"({"node_type" : "ExternalFunctionNode", )"
7706          << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "args" : [)";
7707   writeJsonASTExternalFunctionArguments(output);
7708   output << "]}";
7709 }
7710 
7711 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7712 ExternalFunctionNode::writeJsonOutput(ostream &output,
7713                                       const temporary_terms_t &temporary_terms,
7714                                       const deriv_node_temp_terms_t &tef_terms,
7715                                       bool isdynamic) const
7716 {
7717   if (temporary_terms.find(const_cast<ExternalFunctionNode *>(this)) != temporary_terms.end())
7718     {
7719       output << "T" << idx;
7720       return;
7721     }
7722 
7723   try
7724     {
7725       int tef_idx = getIndxInTefTerms(symb_id, tef_terms);
7726       output << "TEF_" << tef_idx;
7727     }
7728   catch (UnknownFunctionNameAndArgs &)
7729     {
7730       // When writing the JSON output at parsing pass, we don’t use TEF terms
7731       output << datatree.symbol_table.getName(symb_id) << "(";
7732       writeJsonExternalFunctionArguments(output, temporary_terms, tef_terms, isdynamic);
7733       output << ")";
7734     }
7735 }
7736 
7737 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const7738 ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
7739                                   const temporary_terms_t &temporary_terms,
7740                                   const temporary_terms_idxs_t &temporary_terms_idxs,
7741                                   const deriv_node_temp_terms_t &tef_terms) const
7742 {
7743   if (output_type == ExprNodeOutputType::matlabOutsideModel || output_type == ExprNodeOutputType::steadyStateFile
7744       || output_type == ExprNodeOutputType::juliaSteadyStateFile
7745       || output_type == ExprNodeOutputType::epilogueFile
7746       || isLatexOutput(output_type))
7747     {
7748       string name = isLatexOutput(output_type) ? datatree.symbol_table.getTeXName(symb_id)
7749         : datatree.symbol_table.getName(symb_id);
7750       output << name << "(";
7751       writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7752       output << ")";
7753       return;
7754     }
7755 
7756   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
7757     return;
7758 
7759   if (isCOutput(output_type))
7760     output << "*";
7761   output << "TEF_" << getIndxInTefTerms(symb_id, tef_terms);
7762 }
7763 
7764 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const7765 ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
7766                                                   const temporary_terms_t &temporary_terms,
7767                                                   const temporary_terms_idxs_t &temporary_terms_idxs,
7768                                                   deriv_node_temp_terms_t &tef_terms) const
7769 {
7770   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7771   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7772 
7773   for (auto argument : arguments)
7774     argument->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7775 
7776   if (!alreadyWrittenAsTefTerm(symb_id, tef_terms))
7777     {
7778       tef_terms[{ symb_id, arguments }] = static_cast<int>(tef_terms.size());
7779       int indx = getIndxInTefTerms(symb_id, tef_terms);
7780       int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
7781       assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7782 
7783       if (isCOutput(output_type))
7784         {
7785           stringstream ending;
7786           ending << "_tef_" << getIndxInTefTerms(symb_id, tef_terms);
7787           if (symb_id == first_deriv_symb_id
7788               && symb_id == second_deriv_symb_id)
7789             output << "int nlhs" << ending.str() << " = 3;" << endl
7790                    << "double *TEF_" << indx << ", "
7791                    << "*TEFD_" << indx << ", "
7792                    << "*TEFDD_" << indx << ";" << endl;
7793           else if (symb_id == first_deriv_symb_id)
7794             output << "int nlhs" << ending.str() << " = 2;" << endl
7795                    << "double *TEF_" << indx << ", "
7796                    << "*TEFD_" << indx << "; " << endl;
7797           else
7798             output << "int nlhs" << ending.str() << " = 1;" << endl
7799                    << "double *TEF_" << indx << ";" << endl;
7800 
7801           output << "mxArray *plhs" << ending.str()<< "[nlhs"<< ending.str() << "];" << endl;
7802           output << "int nrhs" << ending.str()<< " = " << arguments.size() << ";" << endl;
7803           writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
7804 
7805           output << "mexCallMATLAB("
7806                  << "nlhs" << ending.str() << ", "
7807                  << "plhs" << ending.str() << ", "
7808                  << "nrhs" << ending.str() << ", "
7809                  << "prhs" << ending.str() << R"(, ")"
7810                  << datatree.symbol_table.getName(symb_id) << R"(");)" << endl;
7811 
7812           if (symb_id == first_deriv_symb_id
7813               && symb_id == second_deriv_symb_id)
7814             output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl
7815                    << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl
7816                    << "TEFDD_" << indx << " = mxGetPr(plhs" << ending.str() << "[2]);" << endl
7817                    << "int TEFDD_" << indx << "_nrows = (int)mxGetM(plhs" << ending.str()<< "[2]);" << endl;
7818           else if (symb_id == first_deriv_symb_id)
7819             output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl
7820                    << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl;
7821           else
7822             output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
7823         }
7824       else
7825         {
7826           if (symb_id == first_deriv_symb_id
7827               && symb_id == second_deriv_symb_id)
7828             output << "[TEF_" << indx << ", TEFD_"<< indx << ", TEFDD_"<< indx << "] = ";
7829           else if (symb_id == first_deriv_symb_id)
7830             output << "[TEF_" << indx << ", TEFD_"<< indx << "] = ";
7831           else
7832             output << "TEF_" << indx << " = ";
7833 
7834           output << datatree.symbol_table.getName(symb_id) << "(";
7835           writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7836           output << ");" << endl;
7837         }
7838     }
7839 }
7840 
7841 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7842 ExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
7843                                                       const temporary_terms_t &temporary_terms,
7844                                                       deriv_node_temp_terms_t &tef_terms,
7845                                                       bool isdynamic) const
7846 {
7847   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7848   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7849 
7850   for (auto argument : arguments)
7851     argument->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
7852 
7853   if (!alreadyWrittenAsTefTerm(symb_id, tef_terms))
7854     {
7855       tef_terms[{ symb_id, arguments }] = static_cast<int>(tef_terms.size());
7856       int indx = getIndxInTefTerms(symb_id, tef_terms);
7857       int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
7858       assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7859 
7860       stringstream ef;
7861       ef << R"({"external_function": {)"
7862          << R"("external_function_term": "TEF_)" << indx << R"(")";
7863 
7864       if (symb_id == first_deriv_symb_id)
7865         ef << R"(, "external_function_term_d": "TEFD_)" << indx << R"(")";
7866 
7867       if (symb_id == second_deriv_symb_id)
7868         ef << R"(, "external_function_term_dd": "TEFDD_)" << indx << R"(")";
7869 
7870       ef << R"(, "value": ")" << datatree.symbol_table.getName(symb_id) << "(";
7871       writeJsonExternalFunctionArguments(ef, temporary_terms, tef_terms, isdynamic);
7872       ef << R"lit()"}})lit";
7873       efout.push_back(ef.str());
7874     }
7875 }
7876 
7877 expr_t
toStatic(DataTree & static_datatree) const7878 ExternalFunctionNode::toStatic(DataTree &static_datatree) const
7879 {
7880   vector<expr_t> static_arguments;
7881   for (auto argument : arguments)
7882     static_arguments.push_back(argument->toStatic(static_datatree));
7883   return static_datatree.AddExternalFunction(symb_id, static_arguments);
7884 }
7885 
7886 void
computeXrefs(EquationInfo & ei) const7887 ExternalFunctionNode::computeXrefs(EquationInfo &ei) const
7888 {
7889   vector<expr_t> dynamic_arguments;
7890   for (auto argument : arguments)
7891     argument->computeXrefs(ei);
7892 }
7893 
7894 expr_t
clone(DataTree & datatree) const7895 ExternalFunctionNode::clone(DataTree &datatree) const
7896 {
7897   vector<expr_t> dynamic_arguments;
7898   for (auto argument : arguments)
7899     dynamic_arguments.push_back(argument->clone(datatree));
7900   return datatree.AddExternalFunction(symb_id, dynamic_arguments);
7901 }
7902 
7903 expr_t
buildSimilarExternalFunctionNode(vector<expr_t> & alt_args,DataTree & alt_datatree) const7904 ExternalFunctionNode::buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, DataTree &alt_datatree) const
7905 {
7906   return alt_datatree.AddExternalFunction(symb_id, alt_args);
7907 }
7908 
7909 function<bool (expr_t)>
sameTefTermPredicate() const7910 ExternalFunctionNode::sameTefTermPredicate() const
7911 {
7912   return [this](expr_t e) {
7913            auto e2 = dynamic_cast<ExternalFunctionNode *>(e);
7914            return (e2 != nullptr && e2->symb_id == symb_id);
7915          };
7916 }
7917 
FirstDerivExternalFunctionNode(DataTree & datatree_arg,int idx_arg,int top_level_symb_id_arg,const vector<expr_t> & arguments_arg,int inputIndex_arg)7918 FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatree_arg,
7919                                                                int idx_arg,
7920                                                                int top_level_symb_id_arg,
7921                                                                const vector<expr_t> &arguments_arg,
7922                                                                int inputIndex_arg) :
7923   AbstractExternalFunctionNode{datatree_arg, idx_arg, top_level_symb_id_arg, arguments_arg},
7924   inputIndex{inputIndex_arg}
7925 {
7926 }
7927 
7928 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const7929 FirstDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
7930                                                       temporary_terms_t &temporary_terms,
7931                                                       map<expr_t, pair<int, int>> &first_occurence,
7932                                                       int Curr_block,
7933                                                       vector< vector<temporary_terms_t>> &v_temporary_terms,
7934                                                       int equation) const
7935 {
7936   expr_t this2 = const_cast<FirstDerivExternalFunctionNode *>(this);
7937   temporary_terms.insert(this2);
7938   first_occurence[this2] = { Curr_block, equation };
7939   v_temporary_terms[Curr_block][equation].insert(this2);
7940 }
7941 
7942 expr_t
composeDerivatives(const vector<expr_t> & dargs)7943 FirstDerivExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
7944 {
7945   vector<expr_t> dNodes;
7946   for (int i = 0; i < static_cast<int>(dargs.size()); i++)
7947     dNodes.push_back(datatree.AddTimes(dargs.at(i),
7948                                        datatree.AddSecondDerivExternalFunction(symb_id, arguments, inputIndex, i+1)));
7949   expr_t theDeriv = datatree.Zero;
7950   for (auto &dNode : dNodes)
7951     theDeriv = datatree.AddPlus(theDeriv, dNode);
7952   return theDeriv;
7953 }
7954 
7955 void
writeJsonAST(ostream & output) const7956 FirstDerivExternalFunctionNode::writeJsonAST(ostream &output) const
7957 {
7958   output << R"({"node_type" : "FirstDerivExternalFunctionNode", )"
7959          << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "args" : [)";
7960   writeJsonASTExternalFunctionArguments(output);
7961   output << "]}";
7962 }
7963 
7964 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7965 FirstDerivExternalFunctionNode::writeJsonOutput(ostream &output,
7966                                                 const temporary_terms_t &temporary_terms,
7967                                                 const deriv_node_temp_terms_t &tef_terms,
7968                                                 bool isdynamic) const
7969 {
7970   // If current node is a temporary term
7971   if (temporary_terms.find(const_cast<FirstDerivExternalFunctionNode *>(this)) != temporary_terms.end())
7972     {
7973       output << "T" << idx;
7974       return;
7975     }
7976 
7977   const int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7978   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7979 
7980   const int tmpIndx = inputIndex - 1;
7981 
7982   if (first_deriv_symb_id == symb_id)
7983     output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms)
7984            << "[" << tmpIndx << "]";
7985   else if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
7986     output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
7987   else
7988     output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
7989            << "[" << tmpIndx << "]";
7990 }
7991 
7992 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const7993 FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
7994                                             const temporary_terms_t &temporary_terms,
7995                                             const temporary_terms_idxs_t &temporary_terms_idxs,
7996                                             const deriv_node_temp_terms_t &tef_terms) const
7997 {
7998   assert(output_type != ExprNodeOutputType::matlabOutsideModel);
7999 
8000   if (isLatexOutput(output_type))
8001     {
8002       output << R"(\frac{\partial )" << datatree.symbol_table.getTeXName(symb_id)
8003              << R"(}{\partial )" << inputIndex << "}(";
8004       writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8005       output << ")";
8006       return;
8007     }
8008 
8009   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
8010     return;
8011 
8012   const int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8013   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8014 
8015   const int tmpIndx = inputIndex - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
8016 
8017   if (first_deriv_symb_id == symb_id)
8018     output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms)
8019            << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndx << RIGHT_ARRAY_SUBSCRIPT(output_type);
8020   else if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8021     {
8022       if (isCOutput(output_type))
8023         output << "*";
8024       output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
8025     }
8026   else
8027     output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
8028            << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndx << RIGHT_ARRAY_SUBSCRIPT(output_type);
8029 }
8030 
8031 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const8032 FirstDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number,
8033                                         bool lhs_rhs, const temporary_terms_t &temporary_terms,
8034                                         const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8035                                         const deriv_node_temp_terms_t &tef_terms) const
8036 {
8037   if (temporary_terms.find(const_cast<FirstDerivExternalFunctionNode *>(this)) != temporary_terms.end())
8038     {
8039       if (dynamic)
8040         {
8041           auto ii = map_idx.find(idx);
8042           FLDT_ fldt(ii->second);
8043           fldt.write(CompileCode, instruction_number);
8044         }
8045       else
8046         {
8047           auto ii = map_idx.find(idx);
8048           FLDST_ fldst(ii->second);
8049           fldst.write(CompileCode, instruction_number);
8050         }
8051       return;
8052     }
8053   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8054   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8055 
8056   if (!lhs_rhs)
8057     {
8058       FLDTEFD_ fldtefd(getIndxInTefTerms(symb_id, tef_terms), inputIndex);
8059       fldtefd.write(CompileCode, instruction_number);
8060     }
8061   else
8062     {
8063       FSTPTEFD_ fstptefd(getIndxInTefTerms(symb_id, tef_terms), inputIndex);
8064       fstptefd.write(CompileCode, instruction_number);
8065     }
8066 }
8067 
8068 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const8069 FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
8070                                                             const temporary_terms_t &temporary_terms,
8071                                                             const temporary_terms_idxs_t &temporary_terms_idxs,
8072                                                             deriv_node_temp_terms_t &tef_terms) const
8073 {
8074   assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8075   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8076   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8077 
8078   /* For a node with derivs provided by the user function, call the method
8079      on the non-derived node */
8080   if (first_deriv_symb_id == symb_id)
8081     {
8082       expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8083       parent->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs,
8084                                           tef_terms);
8085       return;
8086     }
8087 
8088   if (alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
8089     return;
8090 
8091   if (isCOutput(output_type))
8092     if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8093       {
8094         stringstream ending;
8095         ending << "_tefd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
8096         output << "int nlhs" << ending.str() << " = 1;" << endl
8097                << "double *TEFD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << ";" << endl
8098                << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8099                << "int nrhs" << ending.str() << " = 3;" << endl
8100                << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl
8101                << "mwSize dims" << ending.str() << "[2];" << endl;
8102 
8103         output << "dims" << ending.str() << "[0] = 1;" << endl
8104                << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl;
8105 
8106         output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
8107                << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex << ");"<< endl
8108                << "prhs" << ending.str() << "[2] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl;
8109 
8110         int i = 0;
8111         for (auto argument : arguments)
8112           {
8113             output << "mxSetCell(prhs" << ending.str() << "[2], "
8114                    << i++ << ", "
8115                    << "mxCreateDoubleScalar("; // All external_function arguments are scalars
8116             argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8117             output << "));" << endl;
8118           }
8119 
8120         output << "mexCallMATLAB("
8121                << "nlhs" << ending.str() << ", "
8122                << "plhs" << ending.str() << ", "
8123                << "nrhs" << ending.str() << ", "
8124                << "prhs" << ending.str() << R"(, ")"
8125                << R"(jacob_element");)" << endl;
8126 
8127         output << "TEFD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex
8128                << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8129       }
8130     else
8131       {
8132         tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8133         int indx = getIndxInTefTerms(first_deriv_symb_id, tef_terms);
8134         stringstream ending;
8135         ending << "_tefd_def_" << indx;
8136         output << "int nlhs" << ending.str() << " = 1;" << endl
8137                << "double *TEFD_def_" << indx << ";" << endl
8138                << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8139                << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl;
8140         writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
8141 
8142         output << "mexCallMATLAB("
8143                << "nlhs" << ending.str() << ", "
8144                << "plhs" << ending.str() << ", "
8145                << "nrhs" << ending.str() << ", "
8146                << "prhs" << ending.str() << R"(, ")"
8147                << datatree.symbol_table.getName(first_deriv_symb_id) << R"(");)" << endl;
8148 
8149         output << "TEFD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8150       }
8151   else
8152     {
8153       if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8154         output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << " = jacob_element('"
8155                << datatree.symbol_table.getName(symb_id) << "'," << inputIndex << ",{";
8156       else
8157         {
8158           tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8159           output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
8160                  << " = " << datatree.symbol_table.getName(first_deriv_symb_id) << "(";
8161         }
8162 
8163       writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8164 
8165       if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8166         output << "}";
8167       output << ");" << endl;
8168     }
8169 }
8170 
8171 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const8172 FirstDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
8173                                                                 const temporary_terms_t &temporary_terms,
8174                                                                 deriv_node_temp_terms_t &tef_terms,
8175                                                                 bool isdynamic) const
8176 {
8177   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8178   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8179 
8180   /* For a node with derivs provided by the user function, call the method
8181      on the non-derived node */
8182   if (first_deriv_symb_id == symb_id)
8183     {
8184       expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8185       parent->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
8186       return;
8187     }
8188 
8189   if (alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
8190     return;
8191 
8192   stringstream ef;
8193   if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8194     ef << R"({"first_deriv_external_function": {)"
8195        << R"("external_function_term": "TEFD_fdd_)" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << R"(")"
8196        << R"(, "analytic_derivative": false)"
8197        << R"(, "wrt": )" << inputIndex
8198        << R"(, "value": ")" << datatree.symbol_table.getName(symb_id) << "(";
8199   else
8200     {
8201       tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8202       ef << R"({"first_deriv_external_function": {)"
8203          << R"("external_function_term": "TEFD_def_)" << getIndxInTefTerms(first_deriv_symb_id, tef_terms) << R"(")"
8204          << R"(, "analytic_derivative": true)"
8205          << R"(, "value": ")" << datatree.symbol_table.getName(first_deriv_symb_id) << "(";
8206     }
8207 
8208   writeJsonExternalFunctionArguments(ef, temporary_terms, tef_terms, isdynamic);
8209   ef << R"lit()"}})lit";
8210   efout.push_back(ef.str());
8211 }
8212 
8213 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const8214 FirstDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
8215                                                               bool lhs_rhs, const temporary_terms_t &temporary_terms,
8216                                                               const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8217                                                               deriv_node_temp_terms_t &tef_terms) const
8218 {
8219   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8220   assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8221 
8222   if (first_deriv_symb_id == symb_id || alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
8223     return;
8224 
8225   unsigned int nb_add_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms,
8226                                                                          map_idx, dynamic, steady_dynamic, tef_terms);
8227   if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8228     {
8229       unsigned int nb_input_arguments = 0;
8230       unsigned int nb_output_arguments = 1;
8231       unsigned int indx = getIndxInTefTerms(symb_id, tef_terms);
8232       FCALL_ fcall(nb_output_arguments, nb_input_arguments, "jacob_element", indx);
8233       fcall.set_arg_func_name(datatree.symbol_table.getName(symb_id));
8234       fcall.set_row(inputIndex);
8235       fcall.set_nb_add_input_arguments(nb_add_input_arguments);
8236       fcall.set_function_type(ExternalFunctionType::numericalFirstDerivative);
8237       fcall.write(CompileCode, instruction_number);
8238       FSTPTEFD_ fstptefd(indx, inputIndex);
8239       fstptefd.write(CompileCode, instruction_number);
8240     }
8241   else
8242     {
8243       tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8244       int indx = getIndxInTefTerms(symb_id, tef_terms);
8245       int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8246       assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8247 
8248       unsigned int nb_output_arguments = 1;
8249 
8250       FCALL_ fcall(nb_output_arguments, nb_add_input_arguments, datatree.symbol_table.getName(first_deriv_symb_id), indx);
8251       fcall.set_function_type(ExternalFunctionType::firstDerivative);
8252       fcall.write(CompileCode, instruction_number);
8253       FSTPTEFD_ fstptefd(indx, inputIndex);
8254       fstptefd.write(CompileCode, instruction_number);
8255     }
8256 }
8257 
8258 expr_t
8259 FirstDerivExternalFunctionNode::clone(DataTree &datatree) const
8260 {
8261   vector<expr_t> dynamic_arguments;
8262   for (auto argument : arguments)
8263     dynamic_arguments.push_back(argument->clone(datatree));
8264   return datatree.AddFirstDerivExternalFunction(symb_id, dynamic_arguments,
8265                                                 inputIndex);
8266 }
8267 
8268 expr_t
8269 FirstDerivExternalFunctionNode::buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, DataTree &alt_datatree) const
8270 {
8271   return alt_datatree.AddFirstDerivExternalFunction(symb_id, alt_args, inputIndex);
8272 }
8273 
8274 expr_t
8275 FirstDerivExternalFunctionNode::toStatic(DataTree &static_datatree) const
8276 {
8277   vector<expr_t> static_arguments;
8278   for (auto argument : arguments)
8279     static_arguments.push_back(argument->toStatic(static_datatree));
8280   return static_datatree.AddFirstDerivExternalFunction(symb_id, static_arguments,
8281                                                        inputIndex);
8282 }
8283 
8284 void
8285 FirstDerivExternalFunctionNode::computeXrefs(EquationInfo &ei) const
8286 {
8287   vector<expr_t> dynamic_arguments;
8288   for (auto argument : arguments)
8289     argument->computeXrefs(ei);
8290 }
8291 
8292 function<bool (expr_t)>
8293 FirstDerivExternalFunctionNode::sameTefTermPredicate() const
8294 {
8295   int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8296   if (first_deriv_symb_id == symb_id)
8297     return [this](expr_t e) {
8298       auto e2 = dynamic_cast<ExternalFunctionNode *>(e);
8299       return (e2 && e2->symb_id == symb_id);
8300     };
8301   else
8302     return [this](expr_t e) {
8303       auto e2 = dynamic_cast<FirstDerivExternalFunctionNode *>(e);
8304       return (e2 && e2->symb_id == symb_id);
8305     };
8306 }
8307 
8308 SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datatree_arg,
8309                                                                  int idx_arg,
8310                                                                  int top_level_symb_id_arg,
8311                                                                  const vector<expr_t> &arguments_arg,
8312                                                                  int inputIndex1_arg,
8313                                                                  int inputIndex2_arg) :
8314   AbstractExternalFunctionNode{datatree_arg, idx_arg, top_level_symb_id_arg, arguments_arg},
8315   inputIndex1{inputIndex1_arg},
8316   inputIndex2{inputIndex2_arg}
8317 {
8318 }
8319 
8320 void
8321 SecondDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
8322                                                        temporary_terms_t &temporary_terms,
8323                                                        map<expr_t, pair<int, int>> &first_occurence,
8324                                                        int Curr_block,
8325                                                        vector< vector<temporary_terms_t>> &v_temporary_terms,
8326                                                        int equation) const
8327 {
8328   expr_t this2 = const_cast<SecondDerivExternalFunctionNode *>(this);
8329   temporary_terms.insert(this2);
8330   first_occurence[this2] = { Curr_block, equation };
8331   v_temporary_terms[Curr_block][equation].insert(this2);
8332 }
8333 
8334 expr_t
8335 SecondDerivExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
8336 {
8337   cerr << "ERROR: third order derivatives of external functions are not implemented" << endl;
8338   exit(EXIT_FAILURE);
8339 }
8340 
8341 void
8342 SecondDerivExternalFunctionNode::writeJsonAST(ostream &output) const
8343 {
8344   output << R"({"node_type" : "SecondDerivExternalFunctionNode", )"
8345          << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "args" : [)";
8346   writeJsonASTExternalFunctionArguments(output);
8347   output << "]}";
8348 }
8349 
8350 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const8351 SecondDerivExternalFunctionNode::writeJsonOutput(ostream &output,
8352                                                  const temporary_terms_t &temporary_terms,
8353                                                  const deriv_node_temp_terms_t &tef_terms,
8354                                                  bool isdynamic) const
8355 {
8356   // If current node is a temporary term
8357   if (temporary_terms.find(const_cast<SecondDerivExternalFunctionNode *>(this)) != temporary_terms.end())
8358     {
8359       output << "T" << idx;
8360       return;
8361     }
8362 
8363   const int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8364   assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8365 
8366   const int tmpIndex1 = inputIndex1 - 1;
8367   const int tmpIndex2 = inputIndex2 - 1;
8368 
8369   if (second_deriv_symb_id == symb_id)
8370     output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms)
8371            << "[" << tmpIndex1 << "," << tmpIndex2 << "]";
8372   else if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8373     output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
8374   else
8375     output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8376            << "[" << tmpIndex1 << "," << tmpIndex2 << "]";
8377 }
8378 
8379 void
8380 SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
8381                                              const temporary_terms_t &temporary_terms,
8382                                              const temporary_terms_idxs_t &temporary_terms_idxs,
8383                                              const deriv_node_temp_terms_t &tef_terms) const
8384 {
8385   assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8386 
8387   if (isLatexOutput(output_type))
8388     {
8389       output << R"(\frac{\partial^2 )" << datatree.symbol_table.getTeXName(symb_id)
8390              << R"(}{\partial )" << inputIndex1 << R"(\partial )" << inputIndex2 << "}(";
8391       writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8392       output << ")";
8393       return;
8394     }
8395 
8396   if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
8397     return;
8398 
8399   const int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8400   assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8401 
8402   const int tmpIndex1 = inputIndex1 - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
8403   const int tmpIndex2 = inputIndex2 - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
8404 
8405   int indx = getIndxInTefTerms(symb_id, tef_terms);
8406   if (second_deriv_symb_id == symb_id)
8407     if (isCOutput(output_type))
8408       output << "TEFDD_" << indx
8409              << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << " * TEFDD_" << indx << "_nrows + "
8410              << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8411     else
8412       output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms)
8413              << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << "," << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8414   else if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8415     {
8416       if (isCOutput(output_type))
8417         output << "*";
8418       output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
8419     }
8420   else
8421     if (isCOutput(output_type))
8422       output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8423              << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << " * PROBLEM_" << indx << "_nrows"
8424              << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8425     else
8426       output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8427              << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << "," << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8428 }
8429 
8430 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const8431 SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
8432                                                              const temporary_terms_t &temporary_terms,
8433                                                              const temporary_terms_idxs_t &temporary_terms_idxs,
8434                                                              deriv_node_temp_terms_t &tef_terms) const
8435 {
8436   assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8437   int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8438   assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8439 
8440   /* For a node with derivs provided by the user function, call the method
8441      on the non-derived node */
8442   if (second_deriv_symb_id == symb_id)
8443     {
8444       expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8445       parent->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs,
8446                                           tef_terms);
8447       return;
8448     }
8449 
8450   if (alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms))
8451     return;
8452 
8453   if (isCOutput(output_type))
8454     if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8455       {
8456         stringstream ending;
8457         ending << "_tefdd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
8458         output << "int nlhs" << ending.str() << " = 1;" << endl
8459                << "double *TEFDD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << ";" << endl
8460                << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8461                << "int nrhs" << ending.str() << " = 4;" << endl
8462                << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl
8463                << "mwSize dims" << ending.str() << "[2];" << endl;
8464 
8465         output << "dims" << ending.str() << "[0] = 1;" << endl
8466                << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl;
8467 
8468         output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
8469                << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex1 << ");"<< endl
8470                << "prhs" << ending.str() << "[2] = mxCreateDoubleScalar(" << inputIndex2 << ");"<< endl
8471                << "prhs" << ending.str() << "[3] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl;
8472 
8473         int i = 0;
8474         for (auto argument : arguments)
8475           {
8476             output << "mxSetCell(prhs" << ending.str() << "[3], "
8477                    << i++ << ", "
8478                    << "mxCreateDoubleScalar("; // All external_function arguments are scalars
8479             argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8480             output << "));" << endl;
8481           }
8482 
8483         output << "mexCallMATLAB("
8484                << "nlhs" << ending.str() << ", "
8485                << "plhs" << ending.str() << ", "
8486                << "nrhs" << ending.str() << ", "
8487                << "prhs" << ending.str() << R"(, ")"
8488                << R"(hess_element");)" << endl;
8489 
8490         output << "TEFDD_fdd_" <<  getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
8491                << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8492       }
8493     else
8494       {
8495         tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8496         int indx = getIndxInTefTerms(second_deriv_symb_id, tef_terms);
8497         stringstream ending;
8498         ending << "_tefdd_def_" << indx;
8499 
8500         output << "int nlhs" << ending.str() << " = 1;" << endl
8501                << "double *TEFDD_def_" << indx << ";" << endl
8502                << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8503                << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl;
8504         writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
8505 
8506         output << "mexCallMATLAB("
8507                << "nlhs" << ending.str() << ", "
8508                << "plhs" << ending.str() << ", "
8509                << "nrhs" << ending.str() << ", "
8510                << "prhs" << ending.str() << R"(, ")"
8511                << datatree.symbol_table.getName(second_deriv_symb_id) << R"(");)" << endl;
8512 
8513         output << "TEFDD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8514       }
8515   else
8516     {
8517       if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8518         output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
8519                << " = hess_element('" << datatree.symbol_table.getName(symb_id) << "',"
8520                << inputIndex1 << "," << inputIndex2 << ",{";
8521       else
8522         {
8523           tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8524           output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8525                  << " = " << datatree.symbol_table.getName(second_deriv_symb_id) << "(";
8526         }
8527 
8528       writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8529 
8530       if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8531         output << "}";
8532       output << ");" << endl;
8533     }
8534 }
8535 
8536 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const8537 SecondDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
8538                                                                  const temporary_terms_t &temporary_terms,
8539                                                                  deriv_node_temp_terms_t &tef_terms,
8540                                                                  bool isdynamic) const
8541 {
8542   int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8543   assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8544 
8545   /* For a node with derivs provided by the user function, call the method
8546      on the non-derived node */
8547   if (second_deriv_symb_id == symb_id)
8548     {
8549       expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8550       parent->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
8551       return;
8552     }
8553 
8554   if (alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms))
8555     return;
8556 
8557   stringstream ef;
8558   if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8559     ef << R"({"second_deriv_external_function": {)"
8560        << R"("external_function_term": "TEFDD_fdd_)" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << R"(")"
8561        << R"(, "analytic_derivative": false)"
8562        << R"(, "wrt1": )" << inputIndex1
8563        << R"(, "wrt2": )" << inputIndex2
8564        << R"(, "value": ")" << datatree.symbol_table.getName(symb_id) << "(";
8565   else
8566     {
8567       tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8568       ef << R"({"second_deriv_external_function": {)"
8569          << R"("external_function_term": "TEFDD_def_)" << getIndxInTefTerms(second_deriv_symb_id, tef_terms) << R"(")"
8570          << R"(, "analytic_derivative": true)"
8571          << R"(, "value": ")" << datatree.symbol_table.getName(second_deriv_symb_id) << "(";
8572     }
8573 
8574   writeJsonExternalFunctionArguments(ef, temporary_terms, tef_terms, isdynamic);
8575   ef << R"lit()"}})lit" << endl;
8576   efout.push_back(ef.str());
8577 }
8578 
8579 expr_t
8580 SecondDerivExternalFunctionNode::clone(DataTree &datatree) const
8581 {
8582   vector<expr_t> dynamic_arguments;
8583   for (auto argument : arguments)
8584     dynamic_arguments.push_back(argument->clone(datatree));
8585   return datatree.AddSecondDerivExternalFunction(symb_id, dynamic_arguments,
8586                                                  inputIndex1, inputIndex2);
8587 }
8588 
8589 expr_t
8590 SecondDerivExternalFunctionNode::buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, DataTree &alt_datatree) const
8591 {
8592   return alt_datatree.AddSecondDerivExternalFunction(symb_id, alt_args, inputIndex1, inputIndex2);
8593 }
8594 
8595 expr_t
8596 SecondDerivExternalFunctionNode::toStatic(DataTree &static_datatree) const
8597 {
8598   vector<expr_t> static_arguments;
8599   for (auto argument : arguments)
8600     static_arguments.push_back(argument->toStatic(static_datatree));
8601   return static_datatree.AddSecondDerivExternalFunction(symb_id, static_arguments,
8602                                                         inputIndex1, inputIndex2);
8603 }
8604 
8605 void
8606 SecondDerivExternalFunctionNode::computeXrefs(EquationInfo &ei) const
8607 {
8608   vector<expr_t> dynamic_arguments;
8609   for (auto argument : arguments)
8610     argument->computeXrefs(ei);
8611 }
8612 
8613 void
8614 SecondDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number,
8615                                          bool lhs_rhs, const temporary_terms_t &temporary_terms,
8616                                          const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8617                                          const deriv_node_temp_terms_t &tef_terms) const
8618 {
8619   cerr << "SecondDerivExternalFunctionNode::compile: not implemented." << endl;
8620   exit(EXIT_FAILURE);
8621 }
8622 
8623 void
8624 SecondDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
8625                                                                bool lhs_rhs, const temporary_terms_t &temporary_terms,
8626                                                                const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8627                                                                deriv_node_temp_terms_t &tef_terms) const
8628 {
8629   cerr << "SecondDerivExternalFunctionNode::compileExternalFunctionOutput: not implemented." << endl;
8630   exit(EXIT_FAILURE);
8631 }
8632 
8633 function<bool (expr_t)>
8634 SecondDerivExternalFunctionNode::sameTefTermPredicate() const
8635 {
8636   int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8637   if (second_deriv_symb_id == symb_id)
8638     return [this](expr_t e) {
8639       auto e2 = dynamic_cast<ExternalFunctionNode *>(e);
8640       return (e2 && e2->symb_id == symb_id);
8641     };
8642   else
8643     return [this](expr_t e) {
8644       auto e2 = dynamic_cast<SecondDerivExternalFunctionNode *>(e);
8645       return (e2 && e2->symb_id == symb_id);
8646     };
8647 }
8648 
8649 VarExpectationNode::VarExpectationNode(DataTree &datatree_arg,
8650                                        int idx_arg,
8651                                        string model_name_arg) :
8652   ExprNode{datatree_arg, idx_arg},
8653   model_name{move(model_name_arg)}
8654 {
8655 }
8656 
8657 void
8658 VarExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
8659                                           map<pair<int, int>, temporary_terms_t> &temp_terms_map,
8660                                           map<expr_t, pair<int, pair<int, int>>> &reference_count,
8661                                           bool is_matlab) const
8662 {
8663   cerr << "VarExpectationNode::computeTemporaryTerms not implemented." << endl;
8664   exit(EXIT_FAILURE);
8665 }
8666 
8667 void
8668 VarExpectationNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
8669                                           temporary_terms_t &temporary_terms,
8670                                           map<expr_t, pair<int, int>> &first_occurence,
8671                                           int Curr_block,
8672                                           vector< vector<temporary_terms_t>> &v_temporary_terms,
8673                                           int equation) const
8674 {
8675   cerr << "VarExpectationNode::computeTemporaryTerms not implemented." << endl;
8676   exit(EXIT_FAILURE);
8677 }
8678 
8679 expr_t
8680 VarExpectationNode::toStatic(DataTree &static_datatree) const
8681 {
8682   cerr << "VarExpectationNode::toStatic not implemented." << endl;
8683   exit(EXIT_FAILURE);
8684 }
8685 
8686 expr_t
8687 VarExpectationNode::clone(DataTree &datatree) const
8688 {
8689   return datatree.AddVarExpectation(model_name);
8690 }
8691 
8692 void
8693 VarExpectationNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
8694                                 const temporary_terms_t &temporary_terms,
8695                                 const temporary_terms_idxs_t &temporary_terms_idxs,
8696                                 const deriv_node_temp_terms_t &tef_terms) const
8697 {
8698   assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8699 
8700   if (isLatexOutput(output_type))
8701     {
8702       output << "VAR_EXPECTATION(" << model_name << ')';
8703       return;
8704     }
8705 
8706   cerr << "VarExpectationNode::writeOutput not implemented for non-LaTeX." << endl;
8707   exit(EXIT_FAILURE);
8708 }
8709 
8710 int
8711 VarExpectationNode::maxEndoLead() const
8712 {
8713   cerr << "VarExpectationNode::maxEndoLead not implemented." << endl;
8714   exit(EXIT_FAILURE);
8715 }
8716 
8717 int
8718 VarExpectationNode::maxExoLead() const
8719 {
8720   cerr << "VarExpectationNode::maxExoLead not implemented." << endl;
8721   exit(EXIT_FAILURE);
8722 }
8723 
8724 int
8725 VarExpectationNode::maxEndoLag() const
8726 {
8727   cerr << "VarExpectationNode::maxEndoLead not implemented." << endl;
8728   exit(EXIT_FAILURE);
8729 }
8730 
8731 int
8732 VarExpectationNode::maxExoLag() const
8733 {
8734   cerr << "VarExpectationNode::maxExoLead not implemented." << endl;
8735   exit(EXIT_FAILURE);
8736 }
8737 
8738 int
8739 VarExpectationNode::maxLead() const
8740 {
8741   cerr << "VarExpectationNode::maxLead not implemented." << endl;
8742   exit(EXIT_FAILURE);
8743 }
8744 
8745 int
8746 VarExpectationNode::maxLag() const
8747 {
8748   cerr << "VarExpectationNode::maxLag not implemented." << endl;
8749   exit(EXIT_FAILURE);
8750 }
8751 
8752 int
8753 VarExpectationNode::maxLagWithDiffsExpanded() const
8754 {
8755   /* This node will be substituted by lagged variables, so in theory we should
8756      return a strictly positive value. But from here this value is not easy to
8757      compute.
8758      We return 0, because currently this function is only called from
8759      DynamicModel::setLeadsLagsOrig(), and the maximum lag will nevertheless be
8760      correctly computed because the maximum lag of the VAR will be taken into
8761      account via the corresponding equations. */
8762   return 0;
8763 }
8764 
8765 expr_t
8766 VarExpectationNode::undiff() const
8767 {
8768   cerr << "VarExpectationNode::undiff not implemented." << endl;
8769   exit(EXIT_FAILURE);
8770 }
8771 
8772 int
8773 VarExpectationNode::VarMinLag() const
8774 {
8775   cerr << "VarExpectationNode::VarMinLag not implemented." << endl;
8776   exit(EXIT_FAILURE);
8777 }
8778 
8779 int
8780 VarExpectationNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
8781 {
8782   cerr << "VarExpectationNode::VarMaxLag not implemented." << endl;
8783   exit(EXIT_FAILURE);
8784 }
8785 
8786 int
8787 VarExpectationNode::PacMaxLag(int lhs_symb_id) const
8788 {
8789   cerr << "VarExpectationNode::PacMaxLag not implemented." << endl;
8790   exit(EXIT_FAILURE);
8791 }
8792 
8793 int
8794 VarExpectationNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
8795 {
8796   return -1;
8797 }
8798 
8799 expr_t
8800 VarExpectationNode::decreaseLeadsLags(int n) const
8801 {
8802   cerr << "VarExpectationNode::decreaseLeadsLags not implemented." << endl;
8803   exit(EXIT_FAILURE);
8804 }
8805 
8806 void
8807 VarExpectationNode::prepareForDerivation()
8808 {
8809   cerr << "VarExpectationNode::prepareForDerivation not implemented." << endl;
8810   exit(EXIT_FAILURE);
8811 }
8812 
8813 expr_t
8814 VarExpectationNode::computeDerivative(int deriv_id)
8815 {
8816   cerr << "VarExpectationNode::computeDerivative not implemented." << endl;
8817   exit(EXIT_FAILURE);
8818 }
8819 
8820 expr_t
8821 VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
8822 {
8823   cerr << "VarExpectationNode::getChainRuleDerivative not implemented." << endl;
8824   exit(EXIT_FAILURE);
8825 }
8826 
8827 bool
8828 VarExpectationNode::containsExternalFunction() const
8829 {
8830   return false;
8831 }
8832 
8833 double
8834 VarExpectationNode::eval(const eval_context_t &eval_context) const noexcept(false)
8835 {
8836   throw EvalException();
8837 }
8838 
8839 int
8840 VarExpectationNode::countDiffs() const
8841 {
8842   cerr << "VarExpectationNode::countDiffs not implemented." << endl;
8843   exit(EXIT_FAILURE);
8844 }
8845 
8846 void
8847 VarExpectationNode::computeXrefs(EquationInfo &ei) const
8848 {
8849 }
8850 
8851 void
8852 VarExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
8853 {
8854   cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
8855   exit(EXIT_FAILURE);
8856 }
8857 
8858 void
8859 VarExpectationNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
8860 {
8861 }
8862 
8863 void
8864 VarExpectationNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
8865 {
8866   cerr << "VarExpectationNode::collectTemporary_terms not implemented." << endl;
8867   exit(EXIT_FAILURE);
8868 }
8869 
8870 void
8871 VarExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_number,
8872                             bool lhs_rhs, const temporary_terms_t &temporary_terms,
8873                             const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8874                             const deriv_node_temp_terms_t &tef_terms) const
8875 {
8876   cerr << "VarExpectationNode::compile not implemented." << endl;
8877   exit(EXIT_FAILURE);
8878 }
8879 
// Equation normalization through a var_expectation node is not supported:
// prints an error and aborts the program.
pair<int, expr_t>
VarExpectationNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
{
  cerr << "VarExpectationNode::normalizeEquation not implemented." << endl;
  exit(EXIT_FAILURE);
}
8886 
// Not supported on var_expectation nodes: prints an error and aborts.
// NOTE(review): presumably this node is expected to be replaced (see
// substituteVarExpectation) before lead substitution runs — confirm.
expr_t
VarExpectationNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
  cerr << "VarExpectationNode::substituteEndoLeadGreaterThanTwo not implemented." << endl;
  exit(EXIT_FAILURE);
}
8893 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  cerr << "VarExpectationNode::substituteEndoLagGreaterThanTwo not implemented." << endl;
  exit(EXIT_FAILURE);
}
8900 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
  cerr << "VarExpectationNode::substituteExoLead not implemented." << endl;
  exit(EXIT_FAILURE);
}
8907 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  cerr << "VarExpectationNode::substituteExoLag not implemented." << endl;
  exit(EXIT_FAILURE);
}
8914 
// No expectation operator to substitute inside this node: returns the node
// itself, unchanged.
expr_t
VarExpectationNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
{
  return const_cast<VarExpectationNode *>(this);
}
8920 
// No adl() operator inside this node: returns the node itself, unchanged.
expr_t
VarExpectationNode::substituteAdl() const
{
  return const_cast<VarExpectationNode *>(this);
}
8926 
8927 expr_t
8928 VarExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
8929 {
8930   auto it = subst_table.find(model_name);
8931   if (it == subst_table.end())
8932     {
8933       cerr << "ERROR: unknown model '" << model_name << "' used in var_expectation expression" << endl;
8934       exit(EXIT_FAILURE);
8935     }
8936   return it->second;
8937 }
8938 
// Intentionally a no-op: no diff() node is recorded from this node.
void
VarExpectationNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
}
8943 
// Intentionally a no-op: no unary-op node is recorded from this node.
void
VarExpectationNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
}
8948 
// No target variable can be found in this node: always returns -1.
int
VarExpectationNode::findTargetVariable(int lhs_symb_id) const
{
  return -1;
}
8954 
// Nothing to substitute: returns the node itself, unchanged.
expr_t
VarExpectationNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
                                   vector<BinaryOpNode *> &neweqs) const
{
  return const_cast<VarExpectationNode *>(this);
}
8961 
// Nothing to substitute: returns the node itself, unchanged.
expr_t
VarExpectationNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  return const_cast<VarExpectationNode *>(this);
}
8967 
8968 expr_t
8969 VarExpectationNode::substitutePacExpectation(const string &name, expr_t subexpr)
8970 {
8971   return const_cast<VarExpectationNode *>(this);
8972 }
8973 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  cerr << "VarExpectationNode::differentiateForwardVars not implemented." << endl;
  exit(EXIT_FAILURE);
}
8980 
// A var_expectation node never contains a pac_expectation, whatever the
// requested PAC model name.
bool
VarExpectationNode::containsPacExpectation(const string &pac_model_name) const
{
  return false;
}
8986 
// Not supported on var_expectation nodes: prints an error and aborts.
bool
VarExpectationNode::containsEndogenous() const
{
  cerr << "VarExpectationNode::containsEndogenous not implemented." << endl;
  exit(EXIT_FAILURE);
}
8993 
// Not supported on var_expectation nodes: prints an error and aborts.
bool
VarExpectationNode::containsExogenous() const
{
  cerr << "VarExpectationNode::containsExogenous not implemented." << endl;
  exit(EXIT_FAILURE);
}
9000 
// This node is not a numeric constant: always returns false.
bool
VarExpectationNode::isNumConstNodeEqualTo(double value) const
{
  return false;
}
9006 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::decreaseLeadsLagsPredeterminedVariables() const
{
  cerr << "VarExpectationNode::decreaseLeadsLagsPredeterminedVariables not implemented." << endl;
  exit(EXIT_FAILURE);
}
9013 
// This node is not a variable node: always returns false.
bool
VarExpectationNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
  return false;
}
9019 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::replaceTrendVar() const
{
  cerr << "VarExpectationNode::replaceTrendVar not implemented." << endl;
  exit(EXIT_FAILURE);
}
9026 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{
  cerr << "VarExpectationNode::detrend not implemented." << endl;
  exit(EXIT_FAILURE);
}
9033 
// Not supported on var_expectation nodes: prints an error and aborts.
expr_t
VarExpectationNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
{
  cerr << "VarExpectationNode::removeTrendLeadLag not implemented." << endl;
  exit(EXIT_FAILURE);
}
9040 
// Not supported on var_expectation nodes: prints an error and aborts.
bool
VarExpectationNode::isInStaticForm() const
{
  cerr << "VarExpectationNode::isInStaticForm not implemented." << endl;
  exit(EXIT_FAILURE);
}
9047 
// Currently always returns false, i.e. this node is never reported as
// referencing `model_info_name` (see the TODO below).
bool
VarExpectationNode::isVarModelReferenced(const string &model_info_name) const
{
  /* TODO: should check here whether the var_expectation_model is equal to the
     argument; we probably need a VarModelTable class to do that elegantly */
  return false;
}
9055 
// Intentionally a no-op: this node adds no entry to `model_endos_and_lags`.
void
VarExpectationNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{
}
9060 
// A var_expectation node is never a "parameter × endogenous" expression.
bool
VarExpectationNode::isParamTimesEndogExpr() const
{
  return false;
}
9066 
// No auxiliary variable inside this node: returns the node itself, unchanged.
expr_t
VarExpectationNode::substituteStaticAuxiliaryVariable() const
{
  return const_cast<VarExpectationNode *>(this);
}
9072 
// Intentionally a no-op: this node contributes no entry to `table`.
void
VarExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
  return;
}
9078 
// No variable inside this node to replace: returns the node itself, unchanged.
expr_t
VarExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
  return const_cast<VarExpectationNode *>(this);
}
9084 
9085 void
9086 VarExpectationNode::writeJsonAST(ostream &output) const
9087 {
9088   output << R"({"node_type" : "VarExpectationNode", )"
9089          << R"("name" : ")" << model_name << R"("})";
9090 }
9091 
9092 void
9093 VarExpectationNode::writeJsonOutput(ostream &output,
9094                                     const temporary_terms_t &temporary_terms,
9095                                     const deriv_node_temp_terms_t &tef_terms,
9096                                     bool isdynamic) const
9097 {
9098   output << "var_expectation("
9099          << "model_name = " << model_name
9100          << ")";
9101 }
9102 
// Builds a pac_expectation node for the PAC model named `model_name_arg`,
// owned by `datatree_arg` with index `idx_arg`. The name is moved into the node.
PacExpectationNode::PacExpectationNode(DataTree &datatree_arg,
                                       int idx_arg,
                                       string model_name_arg) :
  ExprNode{datatree_arg, idx_arg},
  model_name{move(model_name_arg)}
{
}
9110 
// Unconditionally registers this node as a temporary term for `derivOrder`.
// `reference_count` and `is_matlab` are not used.
void
PacExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
                                          map<pair<int, int>, temporary_terms_t> &temp_terms_map,
                                          map<expr_t, pair<int, pair<int, int>>> &reference_count,
                                          bool is_matlab) const
{
  temp_terms_map[derivOrder].insert(const_cast<PacExpectationNode *>(this));
}
9119 
9120 void
9121 PacExpectationNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
9122                                           temporary_terms_t &temporary_terms,
9123                                           map<expr_t, pair<int, int>> &first_occurence,
9124                                           int Curr_block,
9125                                           vector< vector<temporary_terms_t>> &v_temporary_terms,
9126                                           int equation) const
9127 {
9128   expr_t this2 = const_cast<PacExpectationNode *>(this);
9129   temporary_terms.insert(this2);
9130   first_occurence[this2] = { Curr_block, equation };
9131   v_temporary_terms[Curr_block][equation].insert(this2);
9132 }
9133 
// Returns the pac_expectation node for the same model, as provided by
// `static_datatree.AddPacExpectation`.
// NOTE(review): whether an existing node is reused is decided by DataTree.
expr_t
PacExpectationNode::toStatic(DataTree &static_datatree) const
{
  return static_datatree.AddPacExpectation(string(model_name));
}
9139 
// Returns the pac_expectation node for the same model in `datatree`, as
// provided by `datatree.AddPacExpectation`.
expr_t
PacExpectationNode::clone(DataTree &datatree) const
{
  return datatree.AddPacExpectation(string(model_name));
}
9145 
9146 void
9147 PacExpectationNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
9148                                 const temporary_terms_t &temporary_terms,
9149                                 const temporary_terms_idxs_t &temporary_terms_idxs,
9150                                 const deriv_node_temp_terms_t &tef_terms) const
9151 {
9152   assert(output_type != ExprNodeOutputType::matlabOutsideModel);
9153   if (isLatexOutput(output_type))
9154     {
9155       output << "PAC_EXPECTATION" << LEFT_PAR(output_type) << model_name << RIGHT_PAR(output_type);
9156       return;
9157     }
9158 }
9159 
// Always 0: no endogenous lead is reported for this node.
int
PacExpectationNode::maxEndoLead() const
{
  return 0;
}
9165 
// Always 0: no exogenous lead is reported for this node.
int
PacExpectationNode::maxExoLead() const
{
  return 0;
}
9171 
// Always 0: no endogenous lag is reported for this node.
int
PacExpectationNode::maxEndoLag() const
{
  return 0;
}
9177 
// Always 0: no exogenous lag is reported for this node.
int
PacExpectationNode::maxExoLag() const
{
  return 0;
}
9183 
// Always 0: no lead of any kind is reported for this node.
int
PacExpectationNode::maxLead() const
{
  return 0;
}
9189 
// Always 0: no lag of any kind is reported for this node.
int
PacExpectationNode::maxLag() const
{
  return 0;
}
9195 
// Always 0: no lag is reported for this node, even with diff() operators expanded.
int
PacExpectationNode::maxLagWithDiffsExpanded() const
{
  // Same comment as in VarExpectationNode::maxLagWithDiffsExpanded()
  return 0;
}
9202 
// No diff() to remove: returns the node itself, unchanged.
expr_t
PacExpectationNode::undiff() const
{
  return const_cast<PacExpectationNode *>(this);
}
9208 
// Always returns 1.
// NOTE(review): presumably the neutral value for the minimum-lag computation
// over VAR equations — confirm against the other node types.
int
PacExpectationNode::VarMinLag() const
{
  return 1;
}
9214 
// Always 0, whatever `lhs_lag_equiv` contains.
int
PacExpectationNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
{
  return 0;
}
9220 
// Always 0, whatever `lhs_symb_id` is.
int
PacExpectationNode::PacMaxLag(int lhs_symb_id) const
{
  return 0;
}
9226 
// No PAC target symbol can be found in this node: always returns -1.
int
PacExpectationNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
{
  return -1;
}
9232 
// Shifting leads/lags by `n` leaves this node unchanged: returns the node itself.
expr_t
PacExpectationNode::decreaseLeadsLags(int n) const
{
  return const_cast<PacExpectationNode *>(this);
}
9238 
// Differentiation must never reach a pac_expectation node: prints an error
// and aborts the program.
void
PacExpectationNode::prepareForDerivation()
{
  cerr << "PacExpectationNode::prepareForDerivation: shouldn't arrive here." << endl;
  exit(EXIT_FAILURE);
}
9245 
// Differentiation must never reach a pac_expectation node: prints an error
// and aborts the program.
expr_t
PacExpectationNode::computeDerivative(int deriv_id)
{
  cerr << "PacExpectationNode::computeDerivative: shouldn't arrive here." << endl;
  exit(EXIT_FAILURE);
}
9252 
// Chain-rule differentiation must never reach a pac_expectation node: prints
// an error and aborts the program.
expr_t
PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
{
  cerr << "PacExpectationNode::getChainRuleDerivative: shouldn't arrive here." << endl;
  exit(EXIT_FAILURE);
}
9259 
// A pac_expectation node never contains a call to an external function.
bool
PacExpectationNode::containsExternalFunction() const
{
  return false;
}
9265 
// A pac_expectation node has no numerical value in an evaluation context:
// evaluation always fails by throwing EvalException.
double
PacExpectationNode::eval(const eval_context_t &eval_context) const noexcept(false)
{
  throw EvalException();
}
9271 
// Intentionally a no-op: this node adds nothing to the equation's
// cross-reference information.
void
PacExpectationNode::computeXrefs(EquationInfo &ei) const
{
}
9276 
// A pac_expectation expression may not appear on the left-hand side of a VAR
// equation: prints an error and aborts the program.
void
PacExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
{
  cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
  exit(EXIT_FAILURE);
}
9283 
// Intentionally a no-op: this node adds no (symbol, lag) pair to `result`.
void
PacExpectationNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
{
}
9288 
// If this node was registered in `temporary_terms`, records its index `idx`
// in `temporary_terms_inuse`. `Curr_Block` is not used.
void
PacExpectationNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
{
  if (temporary_terms.find(const_cast<PacExpectationNode *>(this)) != temporary_terms.end())
    temporary_terms_inuse.insert(idx);
}
9295 
// Writing this node to the CompileCode stream is not supported: prints an
// error and aborts the program.
void
PacExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_number,
                            bool lhs_rhs, const temporary_terms_t &temporary_terms,
                            const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
                            const deriv_node_temp_terms_t &tef_terms) const
{
  cerr << "PacExpectationNode::compile not implemented." << endl;
  exit(EXIT_FAILURE);
}
9305 
// This node contains no diff() operator: always returns 0.
int
PacExpectationNode::countDiffs() const
{
  return 0;
}
9311 
// Equation normalization through a pac_expectation node is not supported:
// prints an error and aborts the program.
pair<int, expr_t>
PacExpectationNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
{
  cerr << "PacExpectationNode::normalizeEquation not implemented." << endl;
  exit(EXIT_FAILURE);
}
9318 
// Nothing to substitute: returns the node itself; no equation is added to `neweqs`.
expr_t
PacExpectationNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
  return const_cast<PacExpectationNode *>(this);
}
9324 
// Nothing to substitute: returns the node itself; no equation is added to `neweqs`.
expr_t
PacExpectationNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  return const_cast<PacExpectationNode *>(this);
}
9330 
// Nothing to substitute: returns the node itself; no equation is added to `neweqs`.
expr_t
PacExpectationNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
{
  return const_cast<PacExpectationNode *>(this);
}
9336 
// Nothing to substitute: returns the node itself; no equation is added to `neweqs`.
expr_t
PacExpectationNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  return const_cast<PacExpectationNode *>(this);
}
9342 
// No expectation operator to substitute: returns the node itself, unchanged.
expr_t
PacExpectationNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
{
  return const_cast<PacExpectationNode *>(this);
}
9348 
// No adl() operator inside this node: returns the node itself, unchanged.
expr_t
PacExpectationNode::substituteAdl() const
{
  return const_cast<PacExpectationNode *>(this);
}
9354 
// No var_expectation inside this node: returns the node itself, unchanged.
expr_t
PacExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
{
  return const_cast<PacExpectationNode *>(this);
}
9360 
// Intentionally a no-op: no diff() node is recorded from this node.
void
PacExpectationNode::findDiffNodes(lag_equivalence_table_t &nodes) const
{
}
9365 
// Intentionally a no-op: no unary-op node is recorded from this node.
void
PacExpectationNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
{
}
9370 
// No target variable can be found in this node: always returns -1.
int
PacExpectationNode::findTargetVariable(int lhs_symb_id) const
{
  return -1;
}
9376 
// Nothing to substitute: returns the node itself, unchanged.
expr_t
PacExpectationNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
                                   vector<BinaryOpNode *> &neweqs) const
{
  return const_cast<PacExpectationNode *>(this);
}
9383 
// Nothing to substitute: returns the node itself, unchanged.
expr_t
PacExpectationNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  return const_cast<PacExpectationNode *>(this);
}
9389 
// No forward variable to differentiate: returns the node itself, unchanged.
expr_t
PacExpectationNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
{
  return const_cast<PacExpectationNode *>(this);
}
9395 
9396 bool
9397 PacExpectationNode::containsPacExpectation(const string &pac_model_name) const
9398 {
9399   if (pac_model_name.empty())
9400     return true;
9401   else
9402     return pac_model_name == model_name;
9403 }
9404 
// Always reports that this node contains endogenous variables.
// NOTE(review): presumably because the expectation stands for an expression
// in the PAC model's endogenous variables — confirm.
bool
PacExpectationNode::containsEndogenous() const
{
  return true;
}
9410 
// Always reports that this node contains no exogenous variable.
bool
PacExpectationNode::containsExogenous() const
{
  return false;
}
9416 
// This node is not a numeric constant: always returns false.
bool
PacExpectationNode::isNumConstNodeEqualTo(double value) const
{
  return false;
}
9422 
// No predetermined variable inside this node: returns the node itself, unchanged.
expr_t
PacExpectationNode::decreaseLeadsLagsPredeterminedVariables() const
{
  return const_cast<PacExpectationNode *>(this);
}
9428 
// This node is not a variable node: always returns false.
bool
PacExpectationNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
{
  return false;
}
9434 
// No trend variable inside this node: returns the node itself, unchanged.
expr_t
PacExpectationNode::replaceTrendVar() const
{
  return const_cast<PacExpectationNode *>(this);
}
9440 
// Detrending leaves this node unchanged: returns the node itself.
expr_t
PacExpectationNode::detrend(int symb_id, bool log_trend, expr_t trend) const
{
  return const_cast<PacExpectationNode *>(this);
}
9446 
// No trend lead/lag to remove: returns the node itself, unchanged.
expr_t
PacExpectationNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
{
  return const_cast<PacExpectationNode *>(this);
}
9452 
// A pac_expectation node is never considered to be in static form.
bool
PacExpectationNode::isInStaticForm() const
{
  return false;
}
9458 
// True iff this node refers to the model named `model_info_name`.
bool
PacExpectationNode::isVarModelReferenced(const string &model_info_name) const
{
  return model_name == model_info_name;
}
9464 
// Intentionally a no-op: this node adds no entry to `model_endos_and_lags`.
void
PacExpectationNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
{
}
9469 
// No auxiliary variable inside this node: returns the node itself, unchanged.
expr_t
PacExpectationNode::substituteStaticAuxiliaryVariable() const
{
  return const_cast<PacExpectationNode *>(this);
}
9475 
// Intentionally a no-op: this node contributes no entry to `table`.
void
PacExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
{
  return;
}
9481 
// No variable inside this node to replace: returns the node itself, unchanged.
expr_t
PacExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
{
  return const_cast<PacExpectationNode *>(this);
}
9487 
9488 void
9489 PacExpectationNode::writeJsonAST(ostream &output) const
9490 {
9491   output << R"({"node_type" : "PacExpectationNode", )"
9492          << R"("name" : ")" << model_name << R"("})";
9493 }
9494 
9495 void
9496 PacExpectationNode::writeJsonOutput(ostream &output,
9497                                     const temporary_terms_t &temporary_terms,
9498                                     const deriv_node_temp_terms_t &tef_terms,
9499                                     bool isdynamic) const
9500 {
9501   output << "pac_expectation("
9502          << "model_name = " << model_name
9503          << ")";
9504 }
9505 
// A pac_expectation node is never a "parameter × endogenous" expression.
bool
PacExpectationNode::isParamTimesEndogExpr() const
{
  return false;
}
9511 
9512 expr_t
9513 PacExpectationNode::substitutePacExpectation(const string &name, expr_t subexpr)
9514 {
9515   if (model_name != name)
9516     return const_cast<PacExpectationNode *>(this);
9517   return subexpr;
9518 }
9519 
// Default case of the additive decomposition: the whole node is one term,
// appended to `terms` together with its sign `current_sign`.
// Subclasses (e.g. unary minus, binary plus/minus) recurse into their arguments.
void
ExprNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const
{
  terms.emplace_back(const_cast<ExprNode *>(this), current_sign);
}
9525 
9526 void
9527 UnaryOpNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const
9528 {
9529   if (op_code == UnaryOpcode::uminus)
9530     arg->decomposeAdditiveTerms(terms, -current_sign);
9531   else
9532     ExprNode::decomposeAdditiveTerms(terms, current_sign);
9533 }
9534 
9535 void
9536 BinaryOpNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const
9537 {
9538   if (op_code == BinaryOpcode::plus || op_code == BinaryOpcode::minus)
9539     {
9540       arg1->decomposeAdditiveTerms(terms, current_sign);
9541       if (op_code == BinaryOpcode::plus)
9542         arg2->decomposeAdditiveTerms(terms, current_sign);
9543       else
9544         arg2->decomposeAdditiveTerms(terms, -current_sign);
9545     }
9546   else
9547     ExprNode::decomposeAdditiveTerms(terms, current_sign);
9548 }
9549 
9550 tuple<int, int, int, double>
9551 ExprNode::matchVariableTimesConstantTimesParam(bool variable_obligatory) const
9552 {
9553   int variable_id = -1, lag = 0, param_id = -1;
9554   double constant = 1.0;
9555   matchVTCTPHelper(variable_id, lag, param_id, constant, false);
9556   if (variable_obligatory && variable_id == -1)
9557     throw MatchFailureException{"No variable in this expression"};
9558   return {variable_id, lag, param_id, constant};
9559 }
9560 
// Default case of the "variable × constant × parameter" matcher: node types
// that do not override this helper cannot appear in such an expression, so a
// MatchFailureException is always thrown.
void
ExprNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const
{
  throw MatchFailureException{"Expression not allowed in linear combination of variables"};
}
9566 
9567 void
9568 NumConstNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const
9569 {
9570   double myvalue = eval({});
9571   if (at_denominator)
9572     constant /= myvalue;
9573   else
9574     constant *= myvalue;
9575 }
9576 
9577 void
9578 VariableNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const
9579 {
9580   if (at_denominator)
9581     throw MatchFailureException{"A variable or parameter cannot appear at denominator"};
9582 
9583   SymbolType type = get_type();
9584   if (type == SymbolType::endogenous || type == SymbolType::exogenous)
9585     {
9586       if (var_id != -1)
9587         throw MatchFailureException{"More than one variable in this expression"};
9588       var_id = symb_id;
9589       lag = this->lag;
9590     }
9591   else if (type == SymbolType::parameter)
9592     {
9593       if (param_id != -1)
9594         throw MatchFailureException{"More than one parameter in this expression"};
9595       param_id = symb_id;
9596     }
9597   else
9598     throw MatchFailureException{"Symbol " + datatree.symbol_table.getName(symb_id) + " not allowed here"};
9599 }
9600 
9601 void
9602 UnaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const
9603 {
9604   if (op_code == UnaryOpcode::uminus)
9605     {
9606       constant = -constant;
9607       arg->matchVTCTPHelper(var_id, lag, param_id, constant, at_denominator);
9608     }
9609   else
9610     throw MatchFailureException{"Operator not allowed in this expression"};
9611 }
9612 
9613 void
9614 BinaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int &param_id, double &constant, bool at_denominator) const
9615 {
9616   if (op_code == BinaryOpcode::times || op_code == BinaryOpcode::divide)
9617     {
9618       arg1->matchVTCTPHelper(var_id, lag, param_id, constant, at_denominator);
9619       if (op_code == BinaryOpcode::times)
9620         arg2->matchVTCTPHelper(var_id, lag, param_id, constant, at_denominator);
9621       else
9622         arg2->matchVTCTPHelper(var_id, lag, param_id, constant, !at_denominator);
9623     }
9624   else
9625     throw MatchFailureException{"Operator not allowed in this expression"};
9626 }
9627 
9628 vector<tuple<int, int, int, double>>
9629 ExprNode::matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term) const
9630 {
9631   vector<pair<expr_t, int>> terms;
9632   decomposeAdditiveTerms(terms);
9633 
9634   vector<tuple<int, int, int, double>> result;
9635 
9636   for (const auto &it : terms)
9637     {
9638       expr_t term = it.first;
9639       int sign = it.second;
9640       auto m = term->matchVariableTimesConstantTimesParam(variable_obligatory_in_each_term);
9641       get<3>(m) *= sign;
9642       result.push_back(m);
9643     }
9644   return result;
9645 }
9646 
9647 pair<int, vector<tuple<int, int, int, double>>>
9648 ExprNode::matchParamTimesLinearCombinationOfVariables() const
9649 {
9650   auto bopn = dynamic_cast<const BinaryOpNode *>(this);
9651   if (!bopn || bopn->op_code != BinaryOpcode::times)
9652     throw MatchFailureException{"Not a multiplicative expression"};
9653 
9654   expr_t param = bopn->arg1, lincomb = bopn->arg2;
9655 
9656   auto is_param = [](expr_t e) {
9657                     auto vn = dynamic_cast<VariableNode *>(e);
9658                     return vn && vn->get_type() == SymbolType::parameter;
9659                   };
9660 
9661   if (!is_param(param))
9662     {
9663       swap(param, lincomb);
9664       if (!is_param(param))
9665         throw MatchFailureException{"No parameter on either side of the multiplication"};
9666     }
9667 
9668   return { dynamic_cast<VariableNode *>(param)->symb_id, lincomb->matchLinearCombinationOfVariables() };
9669 }
9670