1 /*
2 * Copyright © 2007-2021 Dynare Team
3 *
4 * This file is part of Dynare.
5 *
6 * Dynare is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation, either version 3 of the License, or
9 * (at your option) any later version.
10 *
11 * Dynare is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with Dynare. If not, see <http://www.gnu.org/licenses/>.
18 */
19
20 #include <iostream>
21 #include <algorithm>
22 #include <cassert>
23 #include <cmath>
24 #include <utility>
25 #include <limits>
26
27 #include "ExprNode.hh"
28 #include "DataTree.hh"
29 #include "ModFile.hh"
30
ExprNode(DataTree & datatree_arg,int idx_arg)31 ExprNode::ExprNode(DataTree &datatree_arg, int idx_arg) : datatree{datatree_arg}, idx{idx_arg}
32 {
33 }
34
35 expr_t
getDerivative(int deriv_id)36 ExprNode::getDerivative(int deriv_id)
37 {
38 if (!preparedForDerivation)
39 prepareForDerivation();
40
41 // Return zero if derivative is necessarily null (using symbolic a priori)
42 if (auto it = non_null_derivatives.find(deriv_id); it == non_null_derivatives.end())
43 return datatree.Zero;
44
45 // If derivative is stored in cache, use the cached value, otherwise compute it (and cache it)
46 if (auto it2 = derivatives.find(deriv_id); it2 != derivatives.end())
47 return it2->second;
48 else
49 {
50 expr_t d = computeDerivative(deriv_id);
51 derivatives[deriv_id] = d;
52 return d;
53 }
54 }
55
56 int
precedence(ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms) const57 ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
58 {
59 // For a constant, a variable, or a unary op, the precedence is maximal
60 return 100;
61 }
62
63 int
precedenceJson(const temporary_terms_t & temporary_terms) const64 ExprNode::precedenceJson(const temporary_terms_t &temporary_terms) const
65 {
66 // For a constant, a variable, or a unary op, the precedence is maximal
67 return 100;
68 }
69
70 int
cost(int cost,bool is_matlab) const71 ExprNode::cost(int cost, bool is_matlab) const
72 {
73 // For a terminal node, the cost is null
74 return 0;
75 }
76
77 int
cost(const temporary_terms_t & temp_terms_map,bool is_matlab) const78 ExprNode::cost(const temporary_terms_t &temp_terms_map, bool is_matlab) const
79 {
80 // For a terminal node, the cost is null
81 return 0;
82 }
83
84 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const85 ExprNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
86 {
87 // For a terminal node, the cost is null
88 return 0;
89 }
90
91 bool
checkIfTemporaryTermThenWrite(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs) const92 ExprNode::checkIfTemporaryTermThenWrite(ostream &output, ExprNodeOutputType output_type,
93 const temporary_terms_t &temporary_terms,
94 const temporary_terms_idxs_t &temporary_terms_idxs) const
95 {
96 if (auto it = temporary_terms.find(const_cast<ExprNode *>(this)); it == temporary_terms.end())
97 return false;
98
99 if (output_type == ExprNodeOutputType::matlabDynamicModelSparse)
100 output << "T" << idx << "(it_)";
101 else
102 if (output_type == ExprNodeOutputType::matlabStaticModelSparse)
103 output << "T" << idx;
104 else
105 {
106 auto it2 = temporary_terms_idxs.find(const_cast<ExprNode *>(this));
107 // It is the responsibility of the caller to ensure that all temporary terms have their index
108 assert(it2 != temporary_terms_idxs.end());
109 output << "T" << LEFT_ARRAY_SUBSCRIPT(output_type)
110 << it2->second + ARRAY_SUBSCRIPT_OFFSET(output_type)
111 << RIGHT_ARRAY_SUBSCRIPT(output_type);
112 }
113 return true;
114 }
115
116 pair<expr_t, int>
getLagEquivalenceClass() const117 ExprNode::getLagEquivalenceClass() const
118 {
119 int index = maxLead();
120
121 if (index == numeric_limits<int>::min())
122 index = 0; // If no variable in the expression, the equivalence class has size 1
123
124 return { decreaseLeadsLags(index), index };
125 }
126
127 void
collectVariables(SymbolType type,set<int> & result) const128 ExprNode::collectVariables(SymbolType type, set<int> &result) const
129 {
130 set<pair<int, int>> symbs_lags;
131 collectDynamicVariables(type, symbs_lags);
132 transform(symbs_lags.begin(), symbs_lags.end(), inserter(result, result.begin()),
133 [](auto x) { return x.first; });
134 }
135
136 void
collectEndogenous(set<pair<int,int>> & result) const137 ExprNode::collectEndogenous(set<pair<int, int>> &result) const
138 {
139 set<pair<int, int>> symb_ids;
140 collectDynamicVariables(SymbolType::endogenous, symb_ids);
141 for (const auto &symb_id : symb_ids)
142 result.emplace(datatree.symbol_table.getTypeSpecificID(symb_id.first), symb_id.second);
143 }
144
145 void
collectExogenous(set<pair<int,int>> & result) const146 ExprNode::collectExogenous(set<pair<int, int>> &result) const
147 {
148 set<pair<int, int>> symb_ids;
149 collectDynamicVariables(SymbolType::exogenous, symb_ids);
150 for (const auto &symb_id : symb_ids)
151 result.emplace(datatree.symbol_table.getTypeSpecificID(symb_id.first), symb_id.second);
152 }
153
154 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const155 ExprNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
156 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
157 map<expr_t, pair<int, pair<int, int>>> &reference_count,
158 bool is_matlab) const
159 {
160 // Nothing to do for a terminal node
161 }
162
163 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const164 ExprNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
165 temporary_terms_t &temporary_terms,
166 map<expr_t, pair<int, int>> &first_occurence,
167 int Curr_block,
168 vector<vector<temporary_terms_t>> &v_temporary_terms,
169 int equation) const
170 {
171 // Nothing to do for a terminal node
172 }
173
174 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const175 ExprNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
176 {
177 /* nothing to do */
178 return { 0, nullptr };
179 }
180
181 void
writeOutput(ostream & output) const182 ExprNode::writeOutput(ostream &output) const
183 {
184 writeOutput(output, ExprNodeOutputType::matlabOutsideModel, {}, {});
185 }
186
187 void
writeOutput(ostream & output,ExprNodeOutputType output_type) const188 ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type) const
189 {
190 writeOutput(output, output_type, {}, {});
191 }
192
193 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs) const194 ExprNode::writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs) const
195 {
196 writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, {});
197 }
198
199 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic) const200 ExprNode::compile(ostream &CompileCode, unsigned int &instruction_number,
201 bool lhs_rhs, const temporary_terms_t &temporary_terms,
202 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic) const
203 {
204 compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, {});
205 }
206
207 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const208 ExprNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
209 const temporary_terms_t &temporary_terms,
210 const temporary_terms_idxs_t &temporary_terms_idxs,
211 deriv_node_temp_terms_t &tef_terms) const
212 {
213 // Nothing to do
214 }
215
216 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const217 ExprNode::writeJsonExternalFunctionOutput(vector<string> &efout,
218 const temporary_terms_t &temporary_terms,
219 deriv_node_temp_terms_t &tef_terms,
220 bool isdynamic) const
221 {
222 // Nothing to do
223 }
224
225 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const226 ExprNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
227 bool lhs_rhs, const temporary_terms_t &temporary_terms,
228 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
229 deriv_node_temp_terms_t &tef_terms) const
230 {
231 // Nothing to do
232 }
233
234 VariableNode *
createEndoLeadAuxiliaryVarForMyself(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const235 ExprNode::createEndoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
236 {
237 int n = maxEndoLead();
238 assert(n >= 2);
239
240 if (auto it = subst_table.find(this);
241 it != subst_table.end())
242 return const_cast<VariableNode *>(it->second);
243
244 expr_t substexpr = decreaseLeadsLags(n-1);
245 int lag = n-2;
246
247 // Each iteration tries to create an auxvar such that auxvar(+1)=expr(-lag)
248 // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to expr(-lag-1) (resp. expr(-lag))
249 while (lag >= 0)
250 {
251 expr_t orig_expr = decreaseLeadsLags(lag);
252 if (auto it = subst_table.find(orig_expr); it == subst_table.end())
253 {
254 int symb_id = datatree.symbol_table.addEndoLeadAuxiliaryVar(orig_expr->idx, substexpr);
255 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr)));
256 substexpr = datatree.AddVariable(symb_id, +1);
257 assert(dynamic_cast<VariableNode *>(substexpr));
258 subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr);
259 }
260 else
261 substexpr = const_cast<VariableNode *>(it->second);
262
263 lag--;
264 }
265
266 return dynamic_cast<VariableNode *>(substexpr);
267 }
268
269 VariableNode *
createExoLeadAuxiliaryVarForMyself(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const270 ExprNode::createExoLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
271 {
272 int n = maxExoLead();
273 assert(n >= 1);
274
275 if (auto it = subst_table.find(this);
276 it != subst_table.end())
277 return const_cast<VariableNode *>(it->second);
278
279 expr_t substexpr = decreaseLeadsLags(n);
280 int lag = n-1;
281
282 // Each iteration tries to create an auxvar such that auxvar(+1)=expr(-lag)
283 // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to expr(-lag-1) (resp. expr(-lag))
284 while (lag >= 0)
285 {
286 expr_t orig_expr = decreaseLeadsLags(lag);
287 if (auto it = subst_table.find(orig_expr); it == subst_table.end())
288 {
289 int symb_id = datatree.symbol_table.addExoLeadAuxiliaryVar(orig_expr->idx, substexpr);
290 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(symb_id, 0), substexpr)));
291 substexpr = datatree.AddVariable(symb_id, +1);
292 assert(dynamic_cast<VariableNode *>(substexpr));
293 subst_table[orig_expr] = dynamic_cast<VariableNode *>(substexpr);
294 }
295 else
296 substexpr = const_cast<VariableNode *>(it->second);
297
298 lag--;
299 }
300
301 return dynamic_cast<VariableNode *>(substexpr);
302 }
303
304 bool
isNumConstNodeEqualTo(double value) const305 ExprNode::isNumConstNodeEqualTo(double value) const
306 {
307 return false;
308 }
309
310 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const311 ExprNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
312 {
313 return false;
314 }
315
316 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const317 ExprNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
318 {
319 }
320
321 void
fillErrorCorrectionRow(int eqn,const vector<int> & nontarget_lhs,const vector<int> & target_lhs,map<tuple<int,int,int>,expr_t> & A0,map<tuple<int,int,int>,expr_t> & A0star) const322 ExprNode::fillErrorCorrectionRow(int eqn,
323 const vector<int> &nontarget_lhs,
324 const vector<int> &target_lhs,
325 map<tuple<int, int, int>, expr_t> &A0,
326 map<tuple<int, int, int>, expr_t> &A0star) const
327 {
328 vector<pair<expr_t, int>> terms;
329 decomposeAdditiveTerms(terms, 1);
330
331 for (const auto &it : terms)
332 {
333 pair<int, vector<tuple<int, int, int, double>>> m;
334 try
335 {
336 m = it.first->matchParamTimesLinearCombinationOfVariables();
337 for (auto &t : m.second)
338 get<3>(t) *= it.second; // Update sign of constants
339 }
340 catch (MatchFailureException &e)
341 {
342 /* FIXME: we should not just skip them, but rather verify that they are
343 autoregressive terms or residuals (probably by merging the two "fill" procedures) */
344 continue;
345 }
346
347 // Helper function
348 auto one_step_orig = [this](int symb_id) {
349 return datatree.symbol_table.isAuxiliaryVariable(symb_id) ?
350 datatree.symbol_table.getOrigSymbIdForDiffAuxVar(symb_id) : symb_id;
351 };
352
353 /* Verify that all variables belong to the error-correction term.
354 FIXME: same remark as above about skipping terms. */
355 bool not_ec = false;
356 for (const auto &t : m.second)
357 {
358 int vid = one_step_orig(get<0>(t));
359 not_ec = not_ec || (find(target_lhs.begin(), target_lhs.end(), vid) == target_lhs.end()
360 && find(nontarget_lhs.begin(), nontarget_lhs.end(), vid) == nontarget_lhs.end());
361 }
362 if (not_ec)
363 continue;
364
365 // Now fill the matrices
366 for (auto [var_id, lag, param_id, constant] : m.second)
367 {
368 int orig_vid = one_step_orig(var_id);
369 int orig_lag = datatree.symbol_table.isAuxiliaryVariable(var_id) ? -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(var_id) : lag;
370 if (find(target_lhs.begin(), target_lhs.end(), orig_vid) == target_lhs.end())
371 {
372 // This an LHS variable, so fill A0
373 if (constant != 1)
374 {
375 cerr << "ERROR in trend component model: LHS variable should not appear with a multiplicative constant in error correction term" << endl;
376 exit(EXIT_FAILURE);
377 }
378 if (param_id != -1)
379 {
380 cerr << "ERROR in trend component model: spurious parameter in error correction term" << endl;
381 exit(EXIT_FAILURE);
382 }
383 int colidx = static_cast<int>(distance(nontarget_lhs.begin(), find(nontarget_lhs.begin(), nontarget_lhs.end(), orig_vid)));
384 if (A0.find({eqn, -orig_lag, colidx}) != A0.end())
385 {
386 cerr << "ExprNode::fillErrorCorrection: Error filling A0 matrix: "
387 << "lag/symb_id encountered more than once in equation" << endl;
388 exit(EXIT_FAILURE);
389 }
390 A0[{eqn, -orig_lag, colidx}] = datatree.AddVariable(m.first);
391 }
392 else
393 {
394 // This is a target, so fill A0star
395 int colidx = static_cast<int>(distance(target_lhs.begin(), find(target_lhs.begin(), target_lhs.end(), orig_vid)));
396 expr_t e = datatree.AddTimes(datatree.AddVariable(m.first), datatree.AddPossiblyNegativeConstant(-constant));
397 if (param_id != -1)
398 e = datatree.AddTimes(e, datatree.AddVariable(param_id));
399 if (auto coor = tuple(eqn, -orig_lag, colidx); A0star.find(coor) == A0star.end())
400 A0star[coor] = e;
401 else
402 A0star[coor] = datatree.AddPlus(e, A0star[coor]);
403 }
404 }
405 }
406 }
407
NumConstNode(DataTree & datatree_arg,int idx_arg,int id_arg)408 NumConstNode::NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg) :
409 ExprNode{datatree_arg, idx_arg},
410 id{id_arg}
411 {
412 }
413
414 int
countDiffs() const415 NumConstNode::countDiffs() const
416 {
417 return 0;
418 }
419
420 void
prepareForDerivation()421 NumConstNode::prepareForDerivation()
422 {
423 preparedForDerivation = true;
424 // All derivatives are null, so non_null_derivatives is left empty
425 }
426
427 expr_t
computeDerivative(int deriv_id)428 NumConstNode::computeDerivative(int deriv_id)
429 {
430 return datatree.Zero;
431 }
432
433 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const434 NumConstNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
435 {
436 if (temporary_terms.find(const_cast<NumConstNode *>(this)) != temporary_terms.end())
437 temporary_terms_inuse.insert(idx);
438 }
439
440 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const441 NumConstNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
442 const temporary_terms_t &temporary_terms,
443 const temporary_terms_idxs_t &temporary_terms_idxs,
444 const deriv_node_temp_terms_t &tef_terms) const
445 {
446 if (!checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
447 output << datatree.num_constants.get(id);
448 }
449
450 void
writeJsonAST(ostream & output) const451 NumConstNode::writeJsonAST(ostream &output) const
452 {
453 output << R"({"node_type" : "NumConstNode", "value" : )";
454 output << std::stof(datatree.num_constants.get(id)) << "}";
455 }
456
457 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const458 NumConstNode::writeJsonOutput(ostream &output,
459 const temporary_terms_t &temporary_terms,
460 const deriv_node_temp_terms_t &tef_terms,
461 bool isdynamic) const
462 {
463 output << datatree.num_constants.get(id);
464 }
465
466 bool
containsExternalFunction() const467 NumConstNode::containsExternalFunction() const
468 {
469 return false;
470 }
471
472 double
eval(const eval_context_t & eval_context) const473 NumConstNode::eval(const eval_context_t &eval_context) const noexcept(false)
474 {
475 return datatree.num_constants.getDouble(id);
476 }
477
478 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const479 NumConstNode::compile(ostream &CompileCode, unsigned int &instruction_number,
480 bool lhs_rhs, const temporary_terms_t &temporary_terms,
481 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
482 const deriv_node_temp_terms_t &tef_terms) const
483 {
484 FLDC_ fldc(datatree.num_constants.getDouble(id));
485 fldc.write(CompileCode, instruction_number);
486 }
487
488 void
collectVARLHSVariable(set<expr_t> & result) const489 NumConstNode::collectVARLHSVariable(set<expr_t> &result) const
490 {
491 cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
492 exit(EXIT_FAILURE);
493 }
494
495 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const496 NumConstNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
497 {
498 }
499
500 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const501 NumConstNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
502 {
503 /* return the numercial constant */
504 return { 0, datatree.AddNonNegativeConstant(datatree.num_constants.get(id)) };
505 }
506
507 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)508 NumConstNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
509 {
510 return datatree.Zero;
511 }
512
513 expr_t
toStatic(DataTree & static_datatree) const514 NumConstNode::toStatic(DataTree &static_datatree) const
515 {
516 return static_datatree.AddNonNegativeConstant(datatree.num_constants.get(id));
517 }
518
519 void
computeXrefs(EquationInfo & ei) const520 NumConstNode::computeXrefs(EquationInfo &ei) const
521 {
522 }
523
524 expr_t
clone(DataTree & datatree) const525 NumConstNode::clone(DataTree &datatree) const
526 {
527 return datatree.AddNonNegativeConstant(datatree.num_constants.get(id));
528 }
529
530 int
maxEndoLead() const531 NumConstNode::maxEndoLead() const
532 {
533 return 0;
534 }
535
536 int
maxExoLead() const537 NumConstNode::maxExoLead() const
538 {
539 return 0;
540 }
541
542 int
maxEndoLag() const543 NumConstNode::maxEndoLag() const
544 {
545 return 0;
546 }
547
548 int
maxExoLag() const549 NumConstNode::maxExoLag() const
550 {
551 return 0;
552 }
553
554 int
maxLead() const555 NumConstNode::maxLead() const
556 {
557 return numeric_limits<int>::min();
558 }
559
560 int
maxLag() const561 NumConstNode::maxLag() const
562 {
563 return numeric_limits<int>::min();
564 }
565
566 int
maxLagWithDiffsExpanded() const567 NumConstNode::maxLagWithDiffsExpanded() const
568 {
569 return numeric_limits<int>::min();
570 }
571
572 expr_t
undiff() const573 NumConstNode::undiff() const
574 {
575 return const_cast<NumConstNode *>(this);
576 }
577
578 int
VarMinLag() const579 NumConstNode::VarMinLag() const
580 {
581 return 1;
582 }
583
584 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const585 NumConstNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
586 {
587 return 0;
588 }
589
590 int
PacMaxLag(int lhs_symb_id) const591 NumConstNode::PacMaxLag(int lhs_symb_id) const
592 {
593 return 0;
594 }
595
596 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const597 NumConstNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
598 {
599 return -1;
600 }
601
602 expr_t
decreaseLeadsLags(int n) const603 NumConstNode::decreaseLeadsLags(int n) const
604 {
605 return const_cast<NumConstNode *>(this);
606 }
607
608 expr_t
decreaseLeadsLagsPredeterminedVariables() const609 NumConstNode::decreaseLeadsLagsPredeterminedVariables() const
610 {
611 return const_cast<NumConstNode *>(this);
612 }
613
614 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const615 NumConstNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
616 {
617 return const_cast<NumConstNode *>(this);
618 }
619
620 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const621 NumConstNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
622 {
623 return const_cast<NumConstNode *>(this);
624 }
625
626 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const627 NumConstNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
628 {
629 return const_cast<NumConstNode *>(this);
630 }
631
632 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const633 NumConstNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
634 {
635 return const_cast<NumConstNode *>(this);
636 }
637
638 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const639 NumConstNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
640 {
641 return const_cast<NumConstNode *>(this);
642 }
643
644 expr_t
substituteAdl() const645 NumConstNode::substituteAdl() const
646 {
647 return const_cast<NumConstNode *>(this);
648 }
649
650 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const651 NumConstNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
652 {
653 return const_cast<NumConstNode *>(this);
654 }
655
656 void
findDiffNodes(lag_equivalence_table_t & nodes) const657 NumConstNode::findDiffNodes(lag_equivalence_table_t &nodes) const
658 {
659 }
660
661 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const662 NumConstNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
663 {
664 }
665
666 int
findTargetVariable(int lhs_symb_id) const667 NumConstNode::findTargetVariable(int lhs_symb_id) const
668 {
669 return -1;
670 }
671
672 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const673 NumConstNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
674 {
675 return const_cast<NumConstNode *>(this);
676 }
677
678 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const679 NumConstNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
680 {
681 return const_cast<NumConstNode *>(this);
682 }
683
684 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)685 NumConstNode::substitutePacExpectation(const string &name, expr_t subexpr)
686 {
687 return const_cast<NumConstNode *>(this);
688 }
689
690 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const691 NumConstNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
692 {
693 return const_cast<NumConstNode *>(this);
694 }
695
696 bool
isNumConstNodeEqualTo(double value) const697 NumConstNode::isNumConstNodeEqualTo(double value) const
698 {
699 if (datatree.num_constants.getDouble(id) == value)
700 return true;
701 else
702 return false;
703 }
704
705 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const706 NumConstNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
707 {
708 return false;
709 }
710
711 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const712 NumConstNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
713 {
714 }
715
716 bool
containsPacExpectation(const string & pac_model_name) const717 NumConstNode::containsPacExpectation(const string &pac_model_name) const
718 {
719 return false;
720 }
721
722 bool
containsEndogenous() const723 NumConstNode::containsEndogenous() const
724 {
725 return false;
726 }
727
728 bool
containsExogenous() const729 NumConstNode::containsExogenous() const
730 {
731 return false;
732 }
733
734 expr_t
replaceTrendVar() const735 NumConstNode::replaceTrendVar() const
736 {
737 return const_cast<NumConstNode *>(this);
738 }
739
740 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const741 NumConstNode::detrend(int symb_id, bool log_trend, expr_t trend) const
742 {
743 return const_cast<NumConstNode *>(this);
744 }
745
746 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const747 NumConstNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
748 {
749 return const_cast<NumConstNode *>(this);
750 }
751
752 bool
isInStaticForm() const753 NumConstNode::isInStaticForm() const
754 {
755 return true;
756 }
757
758 bool
isParamTimesEndogExpr() const759 NumConstNode::isParamTimesEndogExpr() const
760 {
761 return false;
762 }
763
764 bool
isVarModelReferenced(const string & model_info_name) const765 NumConstNode::isVarModelReferenced(const string &model_info_name) const
766 {
767 return false;
768 }
769
770 expr_t
substituteStaticAuxiliaryVariable() const771 NumConstNode::substituteStaticAuxiliaryVariable() const
772 {
773 return const_cast<NumConstNode *>(this);
774 }
775
776 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const777 NumConstNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
778 {
779 return;
780 }
781
782 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const783 NumConstNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
784 {
785 return const_cast<NumConstNode *>(this);
786 }
787
VariableNode(DataTree & datatree_arg,int idx_arg,int symb_id_arg,int lag_arg)788 VariableNode::VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg) :
789 ExprNode{datatree_arg, idx_arg},
790 symb_id{symb_id_arg},
791 lag{lag_arg}
792 {
793 // It makes sense to allow a lead/lag on parameters: during steady state calibration, endogenous and parameters can be swapped
794 assert(get_type() != SymbolType::externalFunction
795 && (lag == 0 || (get_type() != SymbolType::modelLocalVariable && get_type() != SymbolType::modFileLocalVariable)));
796 }
797
798 void
prepareForDerivation()799 VariableNode::prepareForDerivation()
800 {
801 if (preparedForDerivation)
802 return;
803
804 preparedForDerivation = true;
805
806 // Fill in non_null_derivatives
807 switch (get_type())
808 {
809 case SymbolType::endogenous:
810 case SymbolType::exogenous:
811 case SymbolType::exogenousDet:
812 case SymbolType::parameter:
813 case SymbolType::trend:
814 case SymbolType::logTrend:
815 // For a variable or a parameter, the only non-null derivative is with respect to itself
816 non_null_derivatives.insert(datatree.getDerivID(symb_id, lag));
817 break;
818 case SymbolType::modelLocalVariable:
819 datatree.getLocalVariable(symb_id)->prepareForDerivation();
820 // Non null derivatives are those of the value of the local parameter
821 non_null_derivatives = datatree.getLocalVariable(symb_id)->non_null_derivatives;
822 break;
823 case SymbolType::modFileLocalVariable:
824 case SymbolType::statementDeclaredVariable:
825 case SymbolType::unusedEndogenous:
826 // Such a variable is never derived
827 break;
828 case SymbolType::externalFunction:
829 case SymbolType::endogenousVAR:
830 case SymbolType::epilogue:
831 cerr << "VariableNode::prepareForDerivation: impossible case" << endl;
832 exit(EXIT_FAILURE);
833 case SymbolType::excludedVariable:
834 cerr << "VariableNode::prepareForDerivation: impossible case: "
835 << "You are trying to derive a variable that has been excluded via include_eqs/exclude_eqs: "
836 << datatree.symbol_table.getName(symb_id) << endl;
837 exit(EXIT_FAILURE);
838 }
839 }
840
841 expr_t
computeDerivative(int deriv_id)842 VariableNode::computeDerivative(int deriv_id)
843 {
844 switch (get_type())
845 {
846 case SymbolType::endogenous:
847 case SymbolType::exogenous:
848 case SymbolType::exogenousDet:
849 case SymbolType::parameter:
850 case SymbolType::trend:
851 case SymbolType::logTrend:
852 if (deriv_id == datatree.getDerivID(symb_id, lag))
853 return datatree.One;
854 else
855 return datatree.Zero;
856 case SymbolType::modelLocalVariable:
857 return datatree.getLocalVariable(symb_id)->getDerivative(deriv_id);
858 case SymbolType::modFileLocalVariable:
859 cerr << "modFileLocalVariable is not derivable" << endl;
860 exit(EXIT_FAILURE);
861 case SymbolType::statementDeclaredVariable:
862 cerr << "statementDeclaredVariable is not derivable" << endl;
863 exit(EXIT_FAILURE);
864 case SymbolType::unusedEndogenous:
865 cerr << "unusedEndogenous is not derivable" << endl;
866 exit(EXIT_FAILURE);
867 case SymbolType::externalFunction:
868 case SymbolType::endogenousVAR:
869 case SymbolType::epilogue:
870 case SymbolType::excludedVariable:
871 cerr << "VariableNode::computeDerivative: Impossible case!" << endl;
872 exit(EXIT_FAILURE);
873 }
874 // Suppress GCC warning
875 exit(EXIT_FAILURE);
876 }
877
878 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const879 VariableNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
880 {
881 if (temporary_terms.find(const_cast<VariableNode *>(this)) != temporary_terms.end())
882 temporary_terms_inuse.insert(idx);
883 if (get_type() == SymbolType::modelLocalVariable)
884 datatree.getLocalVariable(symb_id)->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
885 }
886
887 bool
containsExternalFunction() const888 VariableNode::containsExternalFunction() const
889 {
890 if (get_type() == SymbolType::modelLocalVariable)
891 return datatree.getLocalVariable(symb_id)->containsExternalFunction();
892
893 return false;
894 }
895
896 void
writeJsonAST(ostream & output) const897 VariableNode::writeJsonAST(ostream &output) const
898 {
899 output << R"({"node_type" : "VariableNode", )"
900 << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "type" : ")";
901 switch (get_type())
902 {
903 case SymbolType::endogenous:
904 output << "endogenous";
905 break;
906 case SymbolType::exogenous:
907 output << "exogenous";
908 break;
909 case SymbolType::exogenousDet:
910 output << "exogenousDet";
911 break;
912 case SymbolType::parameter:
913 output << "parameter";
914 break;
915 case SymbolType::modelLocalVariable:
916 output << "modelLocalVariable";
917 break;
918 case SymbolType::modFileLocalVariable:
919 output << "modFileLocalVariable";
920 break;
921 case SymbolType::externalFunction:
922 output << "externalFunction";
923 break;
924 case SymbolType::trend:
925 output << "trend";
926 break;
927 case SymbolType::statementDeclaredVariable:
928 output << "statementDeclaredVariable";
929 break;
930 case SymbolType::logTrend:
931 output << "logTrend:";
932 break;
933 case SymbolType::unusedEndogenous:
934 output << "unusedEndogenous";
935 break;
936 case SymbolType::endogenousVAR:
937 output << "endogenousVAR";
938 break;
939 case SymbolType::epilogue:
940 output << "epilogue";
941 break;
942 case SymbolType::excludedVariable:
943 cerr << "VariableNode::computeDerivative: Impossible case!" << endl;
944 exit(EXIT_FAILURE);
945 }
946 output << R"(", "lag" : )" << lag << "}";
947 }
948
949 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const950 VariableNode::writeJsonOutput(ostream &output,
951 const temporary_terms_t &temporary_terms,
952 const deriv_node_temp_terms_t &tef_terms,
953 bool isdynamic) const
954 {
955 if (temporary_terms.find(const_cast<VariableNode *>(this)) != temporary_terms.end())
956 {
957 output << "T" << idx;
958 return;
959 }
960
961 output << datatree.symbol_table.getName(symb_id);
962 if (isdynamic && lag != 0)
963 output << "(" << lag << ")";
964 }
965
/* Writes this variable in the syntax of the target given by output_type
   (MATLAB, Julia, C, LaTeX, steady-state file, dseries, epilogue…).
   Temporary terms are written by name and nothing else is emitted; LaTeX
   targets use the symbol's TeX name; otherwise the output depends on the
   symbol type and on output_type. Combinations that are not handled print
   an error message and exit with EXIT_FAILURE. */
void
VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_t &temporary_terms,
                          const temporary_terms_idxs_t &temporary_terms_idxs,
                          const deriv_node_temp_terms_t &tef_terms) const
{
  auto type = get_type();
  // If this node is a temporary term, the helper has already written it
  if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
    return;

  /* LaTeX: {TeX name}, prefixed by \bar for the steady state operator, and
     with a _{t}, _{t+k} or _{t-k} subscript for time-varying symbols in the
     dynamic model */
  if (isLatexOutput(output_type))
    {
      if (output_type == ExprNodeOutputType::latexDynamicSteadyStateOperator)
        output << R"(\bar)";
      output << "{" << datatree.symbol_table.getTeXName(symb_id) << "}";
      if (output_type == ExprNodeOutputType::latexDynamicModel
          && (type == SymbolType::endogenous || type == SymbolType::exogenous || type == SymbolType::exogenousDet || type == SymbolType::trend || type == SymbolType::logTrend))
        {
          output << "_{t";
          if (lag != 0)
            {
              // A negative lag already carries its "-" sign
              if (lag > 0)
                output << "+";
              output << lag;
            }
          output << "}";
        }
      return;
    }

  /* i: array index written in the target language (computed per branch).
     tsid: type-specific ID of the symbol; it appears to be 0-based, since
     MATLAB outputs add 1 to it while C outputs use it unchanged. */
  int i;
  switch (int tsid = datatree.symbol_table.getTypeSpecificID(symb_id); type)
    {
    case SymbolType::parameter:
      if (output_type == ExprNodeOutputType::matlabOutsideModel)
        output << "M_.params" << "(" << tsid + 1 << ")";
      else
        output << "params" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + ARRAY_SUBSCRIPT_OFFSET(output_type) << RIGHT_ARRAY_SUBSCRIPT(output_type);
      break;

    case SymbolType::modelLocalVariable:
      // For these targets the MLV definition is inlined, between parentheses
      if (output_type == ExprNodeOutputType::matlabDynamicModelSparse || output_type == ExprNodeOutputType::matlabStaticModelSparse
          || output_type == ExprNodeOutputType::matlabDynamicSteadyStateOperator || output_type == ExprNodeOutputType::matlabDynamicSparseSteadyStateOperator
          || output_type == ExprNodeOutputType::CDynamicSteadyStateOperator)
        {
          output << "(";
          datatree.getLocalVariable(symb_id)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << ")";
        }
      else
        /* We append underscores to avoid name clashes with "g1" or "oo_".
           But we probably never arrive here because MLV are temporary terms… */
        output << datatree.symbol_table.getName(symb_id) << "__";
      break;

    case SymbolType::modFileLocalVariable:
      output << datatree.symbol_table.getName(symb_id);
      break;

    case SymbolType::endogenous:
      switch (output_type)
        {
        // Dynamic model: the y vector is indexed by the dynamic Jacobian column of (symbol, lag)
        case ExprNodeOutputType::juliaDynamicModel:
        case ExprNodeOutputType::matlabDynamicModel:
        case ExprNodeOutputType::CDynamicModel:
          i = datatree.getDynJacobianCol(datatree.getDerivID(symb_id, lag)) + ARRAY_SUBSCRIPT_OFFSET(output_type);
          output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        // Static model: y is indexed by the type-specific ID
        case ExprNodeOutputType::CStaticModel:
        case ExprNodeOutputType::juliaStaticModel:
        case ExprNodeOutputType::matlabStaticModel:
        case ExprNodeOutputType::matlabStaticModelSparse:
          i = tsid + ARRAY_SUBSCRIPT_OFFSET(output_type);
          output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        // Sparse dynamic model: y(it_[+/-lag], i), i.e. row = period, column = variable
        case ExprNodeOutputType::matlabDynamicModelSparse:
          i = tsid + ARRAY_SUBSCRIPT_OFFSET(output_type);
          if (lag > 0)
            output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else if (lag < 0)
            output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else
            output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabOutsideModel:
          output << "oo_.steady_state(" << tsid + 1 << ")";
          break;
        // Steady state operator: 1-based steady_state vector, except in C (0-based, below)
        case ExprNodeOutputType::juliaDynamicSteadyStateOperator:
        case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
        case ExprNodeOutputType::matlabDynamicSparseSteadyStateOperator:
          output << "steady_state" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + 1 << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::CDynamicSteadyStateOperator:
          output << "steady_state[" << tsid << "]";
          break;
        case ExprNodeOutputType::juliaSteadyStateFile:
        case ExprNodeOutputType::steadyStateFile:
          output << "ys_" << LEFT_ARRAY_SUBSCRIPT(output_type) << tsid + 1 << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        // dseries: ds.<name>, followed by the lag subscript only when non-zero
        case ExprNodeOutputType::matlabDseries:
          output << "ds." << datatree.symbol_table.getName(symb_id);
          if (lag != 0)
            output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        // Epilogue: ds.<name> indexed by t, shifted by the lag when non-zero
        case ExprNodeOutputType::epilogueFile:
          output << "ds." << datatree.symbol_table.getName(symb_id);
          output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
          if (lag != 0)
            output << lag;
          output << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        default:
          cerr << "VariableNode::writeOutput: should not reach this point" << endl;
          exit(EXIT_FAILURE);
        }
      break;

    /* NOTE(review): for exogenous and exogenousDet, juliaDynamicSteadyStateOperator,
       matlabDynamicSparseSteadyStateOperator and CDynamicSteadyStateOperator are not
       handled and fall into the default error branch — confirm this is intended. */
    case SymbolType::exogenous:
      i = tsid + ARRAY_SUBSCRIPT_OFFSET(output_type);
      switch (output_type)
        {
        // x(it_[+/-lag], i): row = period, column = exogenous variable
        case ExprNodeOutputType::juliaDynamicModel:
        case ExprNodeOutputType::matlabDynamicModel:
        case ExprNodeOutputType::matlabDynamicModelSparse:
          if (lag > 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else if (lag < 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        // C: x is a flat array of nb_row_x rows, indexed as period + i*nb_row_x
        case ExprNodeOutputType::CDynamicModel:
          if (lag == 0)
            output << "x[it_+" << i << "*nb_row_x]";
          else if (lag > 0)
            output << "x[it_+" << lag << "+" << i << "*nb_row_x]";
          else
            output << "x[it_" << lag << "+" << i << "*nb_row_x]";
          break;
        case ExprNodeOutputType::CStaticModel:
        case ExprNodeOutputType::juliaStaticModel:
        case ExprNodeOutputType::matlabStaticModel:
        case ExprNodeOutputType::matlabStaticModelSparse:
          output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabOutsideModel:
          assert(lag == 0);
          output << "oo_.exo_steady_state(" << i << ")";
          break;
        case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
          output << "oo_.exo_steady_state(" << i << ")";
          break;
        case ExprNodeOutputType::juliaSteadyStateFile:
        case ExprNodeOutputType::steadyStateFile:
          output << "exo_" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabDseries:
          output << "ds." << datatree.symbol_table.getName(symb_id);
          if (lag != 0)
            output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::epilogueFile:
          output << "ds." << datatree.symbol_table.getName(symb_id);
          output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
          if (lag != 0)
            output << lag;
          output << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        default:
          cerr << "VariableNode::writeOutput: should not reach this point" << endl;
          exit(EXIT_FAILURE);
        }
      break;

    case SymbolType::exogenousDet:
      // In the x arrays, deterministic exogenous are placed after the exo_nbr() stochastic ones
      i = tsid + datatree.symbol_table.exo_nbr() + ARRAY_SUBSCRIPT_OFFSET(output_type);
      switch (output_type)
        {
        case ExprNodeOutputType::juliaDynamicModel:
        case ExprNodeOutputType::matlabDynamicModel:
        case ExprNodeOutputType::matlabDynamicModelSparse:
          if (lag > 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else if (lag < 0)
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else
            output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << i
                   << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::CDynamicModel:
          if (lag == 0)
            output << "x[it_+" << i << "*nb_row_x]";
          else if (lag > 0)
            output << "x[it_+" << lag << "+" << i << "*nb_row_x]";
          else
            output << "x[it_" << lag << "+" << i << "*nb_row_x]";
          break;
        case ExprNodeOutputType::CStaticModel:
        case ExprNodeOutputType::juliaStaticModel:
        case ExprNodeOutputType::matlabStaticModel:
        case ExprNodeOutputType::matlabStaticModelSparse:
          output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        // The separate oo_.exo_det_steady_state vector uses the unshifted tsid
        case ExprNodeOutputType::matlabOutsideModel:
          assert(lag == 0);
          output << "oo_.exo_det_steady_state(" << tsid + 1 << ")";
          break;
        case ExprNodeOutputType::matlabDynamicSteadyStateOperator:
          output << "oo_.exo_det_steady_state(" << tsid + 1 << ")";
          break;
        case ExprNodeOutputType::juliaSteadyStateFile:
        case ExprNodeOutputType::steadyStateFile:
          output << "exo_" << LEFT_ARRAY_SUBSCRIPT(output_type) << i << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::matlabDseries:
          output << "ds." << datatree.symbol_table.getName(symb_id);
          if (lag != 0)
            output << LEFT_ARRAY_SUBSCRIPT(output_type) << lag << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case ExprNodeOutputType::epilogueFile:
          output << "ds." << datatree.symbol_table.getName(symb_id);
          output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
          if (lag != 0)
            output << lag;
          output << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        default:
          cerr << "VariableNode::writeOutput: should not reach this point" << endl;
          exit(EXIT_FAILURE);
        }
      break;
    // Epilogue variables can only be written in the epilogue file or as dseries
    case SymbolType::epilogue:
      if (output_type == ExprNodeOutputType::epilogueFile)
        {
          output << "ds." << datatree.symbol_table.getName(symb_id);
          output << LEFT_ARRAY_SUBSCRIPT(output_type) << "t";
          if (lag != 0)
            output << lag;
          output << RIGHT_ARRAY_SUBSCRIPT(output_type);
        }
      else if (output_type == ExprNodeOutputType::matlabDseries)
        // Only writing dseries for epilogue_static, hence no need to check lag
        output << "ds." << datatree.symbol_table.getName(symb_id);
      else
        {
          cerr << "VariableNode::writeOutput: Impossible case" << endl;
          exit(EXIT_FAILURE);
        }
      break;
    case SymbolType::unusedEndogenous:
      cerr << "ERROR: You cannot use an endogenous variable in an expression if that variable has not been used in the model block." << endl;
      exit(EXIT_FAILURE);
    case SymbolType::externalFunction:
    case SymbolType::trend:
    case SymbolType::logTrend:
    case SymbolType::statementDeclaredVariable:
    case SymbolType::endogenousVAR:
    case SymbolType::excludedVariable:
      cerr << "VariableNode::writeOutput: Impossible case" << endl;
      exit(EXIT_FAILURE);
    }
}
1233
1234 expr_t
substituteStaticAuxiliaryVariable() const1235 VariableNode::substituteStaticAuxiliaryVariable() const
1236 {
1237 if (get_type() == SymbolType::endogenous)
1238 try
1239 {
1240 return datatree.symbol_table.getAuxiliaryVarsExprNode(symb_id)->substituteStaticAuxiliaryVariable();
1241 }
1242 catch (SymbolTable::SearchFailedException &e)
1243 {
1244 }
1245 return const_cast<VariableNode *>(this);
1246 }
1247
1248 double
eval(const eval_context_t & eval_context) const1249 VariableNode::eval(const eval_context_t &eval_context) const noexcept(false)
1250 {
1251 if (get_type() == SymbolType::modelLocalVariable)
1252 return datatree.getLocalVariable(symb_id)->eval(eval_context);
1253
1254 auto it = eval_context.find(symb_id);
1255 if (it == eval_context.end())
1256 throw EvalException();
1257
1258 return it->second;
1259 }
1260
/* Emits the bytecode for this variable into CompileCode, incrementing
   instruction_number through the instructions' write() calls.
   - Model-local and mod-file-local variables are compiled by compiling the
     expression returned by getLocalVariable() in their place.
   - Otherwise, FLD* instructions are emitted when lhs_rhs is false and
     FSTP* instructions when it is true (presumably load vs. store — TODO
     confirm against the bytecode definitions). The dynamic and steady_dynamic
     flags select the dynamic / steady-state-in-dynamic / static variants.
   - Parameters are emitted without a lag. */
void
VariableNode::compile(ostream &CompileCode, unsigned int &instruction_number,
                      bool lhs_rhs, const temporary_terms_t &temporary_terms,
                      const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
                      const deriv_node_temp_terms_t &tef_terms) const
{
  auto type = get_type();
  // NOTE(review): confirm that getLocalVariable() also knows modFileLocalVariable symbols
  if (type == SymbolType::modelLocalVariable || type == SymbolType::modFileLocalVariable)
    datatree.getLocalVariable(symb_id)->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
  else
    {
      int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
      // Deterministic exogenous are numbered after the exo_nbr() stochastic exogenous
      if (type == SymbolType::exogenousDet)
        tsid += datatree.symbol_table.exo_nbr();
      if (!lhs_rhs)
        {
          if (dynamic)
            {
              if (steady_dynamic) // steady state values in a dynamic model
                {
                  FLDVS_ fldvs{static_cast<uint8_t>(type), static_cast<unsigned int>(tsid)};
                  fldvs.write(CompileCode, instruction_number);
                }
              else
                {
                  if (type == SymbolType::parameter)
                    {
                      // Parameters: the instruction is built without a lag argument
                      FLDV_ fldv{static_cast<int>(type), static_cast<unsigned int>(tsid)};
                      fldv.write(CompileCode, instruction_number);
                    }
                  else
                    {
                      FLDV_ fldv{static_cast<int>(type), static_cast<unsigned int>(tsid), lag};
                      fldv.write(CompileCode, instruction_number);
                    }
                }
            }
          else
            {
              // Static model: no lag information
              FLDSV_ fldsv{static_cast<uint8_t>(type), static_cast<unsigned int>(tsid)};
              fldsv.write(CompileCode, instruction_number);
            }
        }
      else
        {
          if (dynamic)
            {
              /* NOTE(review): this branch is reached when lhs_rhs is true,
                 although the message mentions the "rhs" — confirm wording */
              if (steady_dynamic) // steady state values in a dynamic model
                {
                  cerr << "Impossible case: steady_state in rhs of equation" << endl;
                  exit(EXIT_FAILURE);
                }
              else
                {
                  if (type == SymbolType::parameter)
                    {
                      FSTPV_ fstpv{static_cast<int>(type), static_cast<unsigned int>(tsid)};
                      fstpv.write(CompileCode, instruction_number);
                    }
                  else
                    {
                      FSTPV_ fstpv{static_cast<int>(type), static_cast<unsigned int>(tsid), lag};
                      fstpv.write(CompileCode, instruction_number);
                    }
                }
            }
          else
            {
              FSTPSV_ fstpsv{static_cast<uint8_t>(type), static_cast<unsigned int>(tsid)};
              fstpsv.write(CompileCode, instruction_number);
            }
        }
    }
}
1335
1336 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const1337 VariableNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
1338 temporary_terms_t &temporary_terms,
1339 map<expr_t, pair<int, int>> &first_occurence,
1340 int Curr_block,
1341 vector<vector<temporary_terms_t>> &v_temporary_terms,
1342 int equation) const
1343 {
1344 if (get_type() == SymbolType::modelLocalVariable)
1345 datatree.getLocalVariable(symb_id)->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
1346 }
1347
1348 void
collectVARLHSVariable(set<expr_t> & result) const1349 VariableNode::collectVARLHSVariable(set<expr_t> &result) const
1350 {
1351 if (get_type() == SymbolType::endogenous && lag == 0)
1352 result.insert(const_cast<VariableNode *>(this));
1353 else
1354 {
1355 cerr << "ERROR: you can only have endogenous variables or unary ops on LHS of VAR" << endl;
1356 exit(EXIT_FAILURE);
1357 }
1358 }
1359
1360 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const1361 VariableNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
1362 {
1363 if (get_type() == type_arg)
1364 result.emplace(symb_id, lag);
1365 if (get_type() == SymbolType::modelLocalVariable)
1366 datatree.getLocalVariable(symb_id)->collectDynamicVariables(type_arg, result);
1367 }
1368
1369 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const1370 VariableNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
1371 {
1372 /* The equation has to be normalized with respect to the current endogenous variable ascribed to it.
1373 The two input arguments are :
1374 - The ID of the endogenous variable associated to the equation.
1375 - The list of operators and operands needed to normalize the equation*
1376
1377 The pair returned by NormalizeEquation is composed of
1378 - a flag indicating if the expression returned contains (flag = 1) or not (flag = 0)
1379 the endogenous variable related to the equation.
1380 If the expression contains more than one occurence of the associated endogenous variable,
1381 the flag is equal to 2.
1382 - an expression equal to the RHS if flag = 0 and equal to NULL elsewhere
1383 */
1384 if (get_type() == SymbolType::endogenous)
1385 {
1386 if (datatree.symbol_table.getTypeSpecificID(symb_id) == var_endo && lag == 0)
1387 /* the endogenous variable */
1388 return { 1, nullptr };
1389 else
1390 return { 0, datatree.AddVariable(symb_id, lag) };
1391 }
1392 else
1393 {
1394 if (get_type() == SymbolType::parameter)
1395 return { 0, datatree.AddVariable(symb_id, 0) };
1396 else
1397 return { 0, datatree.AddVariable(symb_id, lag) };
1398 }
1399 }
1400
1401 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)1402 VariableNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
1403 {
1404 switch (get_type())
1405 {
1406 case SymbolType::endogenous:
1407 case SymbolType::exogenous:
1408 case SymbolType::exogenousDet:
1409 case SymbolType::parameter:
1410 case SymbolType::trend:
1411 case SymbolType::logTrend:
1412 if (deriv_id == datatree.getDerivID(symb_id, lag))
1413 return datatree.One;
1414 // If there is in the equation a recursive variable we could use a chaine rule derivation
1415 else if (auto it = recursive_variables.find(datatree.getDerivID(symb_id, lag));
1416 it != recursive_variables.end())
1417 {
1418 map<int, expr_t> recursive_vars2(recursive_variables);
1419 recursive_vars2.erase(it->first);
1420 return datatree.AddUMinus(it->second->getChainRuleDerivative(deriv_id, recursive_vars2));
1421 }
1422 else
1423 return datatree.Zero;
1424
1425 case SymbolType::modelLocalVariable:
1426 return datatree.getLocalVariable(symb_id)->getChainRuleDerivative(deriv_id, recursive_variables);
1427 case SymbolType::modFileLocalVariable:
1428 cerr << "modFileLocalVariable is not derivable" << endl;
1429 exit(EXIT_FAILURE);
1430 case SymbolType::statementDeclaredVariable:
1431 cerr << "statementDeclaredVariable is not derivable" << endl;
1432 exit(EXIT_FAILURE);
1433 case SymbolType::unusedEndogenous:
1434 cerr << "unusedEndogenous is not derivable" << endl;
1435 exit(EXIT_FAILURE);
1436 case SymbolType::externalFunction:
1437 case SymbolType::endogenousVAR:
1438 case SymbolType::epilogue:
1439 case SymbolType::excludedVariable:
1440 cerr << "VariableNode::getChainRuleDerivative: Impossible case" << endl;
1441 exit(EXIT_FAILURE);
1442 }
1443 // Suppress GCC warning
1444 exit(EXIT_FAILURE);
1445 }
1446
1447 expr_t
toStatic(DataTree & static_datatree) const1448 VariableNode::toStatic(DataTree &static_datatree) const
1449 {
1450 return static_datatree.AddVariable(symb_id);
1451 }
1452
1453 void
computeXrefs(EquationInfo & ei) const1454 VariableNode::computeXrefs(EquationInfo &ei) const
1455 {
1456 switch (get_type())
1457 {
1458 case SymbolType::endogenous:
1459 ei.endo.emplace(symb_id, lag);
1460 break;
1461 case SymbolType::exogenous:
1462 ei.exo.emplace(symb_id, lag);
1463 break;
1464 case SymbolType::exogenousDet:
1465 ei.exo_det.emplace(symb_id, lag);
1466 break;
1467 case SymbolType::parameter:
1468 ei.param.emplace(symb_id, 0);
1469 break;
1470 case SymbolType::modFileLocalVariable:
1471 datatree.getLocalVariable(symb_id)->computeXrefs(ei);
1472 break;
1473 case SymbolType::trend:
1474 case SymbolType::logTrend:
1475 case SymbolType::modelLocalVariable:
1476 case SymbolType::statementDeclaredVariable:
1477 case SymbolType::unusedEndogenous:
1478 case SymbolType::externalFunction:
1479 case SymbolType::endogenousVAR:
1480 case SymbolType::epilogue:
1481 case SymbolType::excludedVariable:
1482 break;
1483 }
1484 }
1485
1486 SymbolType
get_type() const1487 VariableNode::get_type() const
1488 {
1489 return datatree.symbol_table.getType(symb_id);
1490 }
1491
1492 expr_t
clone(DataTree & datatree) const1493 VariableNode::clone(DataTree &datatree) const
1494 {
1495 return datatree.AddVariable(symb_id, lag);
1496 }
1497
1498 int
maxEndoLead() const1499 VariableNode::maxEndoLead() const
1500 {
1501 switch (get_type())
1502 {
1503 case SymbolType::endogenous:
1504 return max(lag, 0);
1505 case SymbolType::modelLocalVariable:
1506 return datatree.getLocalVariable(symb_id)->maxEndoLead();
1507 default:
1508 return 0;
1509 }
1510 }
1511
1512 int
maxExoLead() const1513 VariableNode::maxExoLead() const
1514 {
1515 switch (get_type())
1516 {
1517 case SymbolType::exogenous:
1518 return max(lag, 0);
1519 case SymbolType::modelLocalVariable:
1520 return datatree.getLocalVariable(symb_id)->maxExoLead();
1521 default:
1522 return 0;
1523 }
1524 }
1525
1526 int
maxEndoLag() const1527 VariableNode::maxEndoLag() const
1528 {
1529 switch (get_type())
1530 {
1531 case SymbolType::endogenous:
1532 return max(-lag, 0);
1533 case SymbolType::modelLocalVariable:
1534 return datatree.getLocalVariable(symb_id)->maxEndoLag();
1535 default:
1536 return 0;
1537 }
1538 }
1539
1540 int
maxExoLag() const1541 VariableNode::maxExoLag() const
1542 {
1543 switch (get_type())
1544 {
1545 case SymbolType::exogenous:
1546 return max(-lag, 0);
1547 case SymbolType::modelLocalVariable:
1548 return datatree.getLocalVariable(symb_id)->maxExoLag();
1549 default:
1550 return 0;
1551 }
1552 }
1553
1554 int
maxLead() const1555 VariableNode::maxLead() const
1556 {
1557 switch (get_type())
1558 {
1559 case SymbolType::endogenous:
1560 case SymbolType::exogenous:
1561 case SymbolType::exogenousDet:
1562 return lag;
1563 case SymbolType::modelLocalVariable:
1564 return datatree.getLocalVariable(symb_id)->maxLead();
1565 default:
1566 return 0;
1567 }
1568 }
1569
1570 int
VarMinLag() const1571 VariableNode::VarMinLag() const
1572 {
1573 switch (get_type())
1574 {
1575 case SymbolType::endogenous:
1576 return -lag;
1577 case SymbolType::exogenous:
1578 if (lag > 0)
1579 return -lag;
1580 else
1581 return 1; // Can have contemporaneus exog in VAR
1582 case SymbolType::modelLocalVariable:
1583 return datatree.getLocalVariable(symb_id)->VarMinLag();
1584 default:
1585 return 1;
1586 }
1587 }
1588
1589 int
maxLag() const1590 VariableNode::maxLag() const
1591 {
1592 switch (get_type())
1593 {
1594 case SymbolType::endogenous:
1595 case SymbolType::exogenous:
1596 case SymbolType::exogenousDet:
1597 return -lag;
1598 case SymbolType::modelLocalVariable:
1599 return datatree.getLocalVariable(symb_id)->maxLag();
1600 default:
1601 return 0;
1602 }
1603 }
1604
1605 int
maxLagWithDiffsExpanded() const1606 VariableNode::maxLagWithDiffsExpanded() const
1607 {
1608 switch (get_type())
1609 {
1610 case SymbolType::endogenous:
1611 case SymbolType::exogenous:
1612 case SymbolType::exogenousDet:
1613 case SymbolType::epilogue:
1614 return -lag;
1615 case SymbolType::modelLocalVariable:
1616 return datatree.getLocalVariable(symb_id)->maxLagWithDiffsExpanded();
1617 default:
1618 return 0;
1619 }
1620 }
1621
1622 expr_t
undiff() const1623 VariableNode::undiff() const
1624 {
1625 return const_cast<VariableNode *>(this);
1626 }
1627
1628 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const1629 VariableNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
1630 {
1631 auto [lag_equiv_repr, index] = getLagEquivalenceClass();
1632 if (lhs_lag_equiv.find(lag_equiv_repr) == lhs_lag_equiv.end())
1633 return 0;
1634 return maxLag();
1635 }
1636
1637 int
PacMaxLag(int lhs_symb_id) const1638 VariableNode::PacMaxLag(int lhs_symb_id) const
1639 {
1640 if (get_type() == SymbolType::modelLocalVariable)
1641 return datatree.getLocalVariable(symb_id)->PacMaxLag(lhs_symb_id);
1642
1643 if (lhs_symb_id == symb_id)
1644 return -lag;
1645 return 0;
1646 }
1647
1648 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const1649 VariableNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
1650 {
1651 return -1;
1652 }
1653
1654 expr_t
substituteAdl() const1655 VariableNode::substituteAdl() const
1656 {
1657 /* Do not recurse into model-local variables definition, rather do it at the
1658 DynamicModel method level (see the comment there) */
1659 return const_cast<VariableNode *>(this);
1660 }
1661
1662 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const1663 VariableNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
1664 {
1665 if (get_type() == SymbolType::modelLocalVariable)
1666 return datatree.getLocalVariable(symb_id)->substituteVarExpectation(subst_table);
1667
1668 return const_cast<VariableNode *>(this);
1669 }
1670
1671 void
findDiffNodes(lag_equivalence_table_t & nodes) const1672 VariableNode::findDiffNodes(lag_equivalence_table_t &nodes) const
1673 {
1674 if (get_type() == SymbolType::modelLocalVariable)
1675 datatree.getLocalVariable(symb_id)->findDiffNodes(nodes);
1676 }
1677
1678 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const1679 VariableNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
1680 {
1681 if (get_type() == SymbolType::modelLocalVariable)
1682 datatree.getLocalVariable(symb_id)->findUnaryOpNodesForAuxVarCreation(nodes);
1683 }
1684
1685 int
findTargetVariable(int lhs_symb_id) const1686 VariableNode::findTargetVariable(int lhs_symb_id) const
1687 {
1688 if (get_type() == SymbolType::modelLocalVariable)
1689 return datatree.getLocalVariable(symb_id)->findTargetVariable(lhs_symb_id);
1690
1691 return -1;
1692 }
1693
1694 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1695 VariableNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
1696 vector<BinaryOpNode *> &neweqs) const
1697 {
1698 if (get_type() == SymbolType::modelLocalVariable)
1699 return datatree.getLocalVariable(symb_id)->substituteDiff(nodes, subst_table, neweqs);
1700
1701 return const_cast<VariableNode *>(this);
1702 }
1703
1704 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1705 VariableNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1706 {
1707 if (get_type() == SymbolType::modelLocalVariable)
1708 return datatree.getLocalVariable(symb_id)->substituteUnaryOpNodes(nodes, subst_table, neweqs);
1709
1710 return const_cast<VariableNode *>(this);
1711 }
1712
1713 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)1714 VariableNode::substitutePacExpectation(const string &name, expr_t subexpr)
1715 {
1716 if (get_type() == SymbolType::modelLocalVariable)
1717 return datatree.getLocalVariable(symb_id)->substitutePacExpectation(name, subexpr);
1718
1719 return const_cast<VariableNode *>(this);
1720 }
1721
1722 expr_t
decreaseLeadsLags(int n) const1723 VariableNode::decreaseLeadsLags(int n) const
1724 {
1725 switch (get_type())
1726 {
1727 case SymbolType::endogenous:
1728 case SymbolType::exogenous:
1729 case SymbolType::exogenousDet:
1730 case SymbolType::trend:
1731 case SymbolType::logTrend:
1732 return datatree.AddVariable(symb_id, lag-n);
1733 case SymbolType::modelLocalVariable:
1734 return datatree.getLocalVariable(symb_id)->decreaseLeadsLags(n);
1735 default:
1736 return const_cast<VariableNode *>(this);
1737 }
1738 }
1739
1740 expr_t
decreaseLeadsLagsPredeterminedVariables() const1741 VariableNode::decreaseLeadsLagsPredeterminedVariables() const
1742 {
1743 /* Do not recurse into model-local variables definitions, since MLVs are
1744 already handled by DynamicModel::transformPredeterminedVariables().
1745 This is also necessary because of #65. */
1746 if (datatree.symbol_table.isPredetermined(symb_id))
1747 return decreaseLeadsLags(1);
1748 else
1749 return const_cast<VariableNode *>(this);
1750 }
1751
1752 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const1753 VariableNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
1754 {
1755 switch (get_type())
1756 {
1757 case SymbolType::endogenous:
1758 if (lag <= 1)
1759 return const_cast<VariableNode *>(this);
1760 else
1761 return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
1762 case SymbolType::modelLocalVariable:
1763 if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLead() <= 1)
1764 return const_cast<VariableNode *>(this);
1765 else
1766 return value->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
1767 default:
1768 return const_cast<VariableNode *>(this);
1769 }
1770 }
1771
1772 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1773 VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1774 {
1775 VariableNode *substexpr;
1776 int cur_lag;
1777 switch (get_type())
1778 {
1779 case SymbolType::endogenous:
1780 if (lag >= -1)
1781 return const_cast<VariableNode *>(this);
1782
1783 if (auto it = subst_table.find(this); it != subst_table.end())
1784 return const_cast<VariableNode *>(it->second);
1785
1786 substexpr = datatree.AddVariable(symb_id, -1);
1787 cur_lag = -2;
1788
1789 // Each iteration tries to create an auxvar such that auxvar(-1)=curvar(cur_lag)
1790 // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to curvar(cur_lag+1) (resp. curvar(cur_lag))
1791 while (cur_lag >= lag)
1792 {
1793 VariableNode *orig_expr = datatree.AddVariable(symb_id, cur_lag);
1794 if (auto it = subst_table.find(orig_expr); it == subst_table.end())
1795 {
1796 int aux_symb_id = datatree.symbol_table.addEndoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr);
1797 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr)));
1798 substexpr = datatree.AddVariable(aux_symb_id, -1);
1799 subst_table[orig_expr] = substexpr;
1800 }
1801 else
1802 substexpr = const_cast<VariableNode *>(it->second);
1803
1804 cur_lag--;
1805 }
1806 return substexpr;
1807
1808 case SymbolType::modelLocalVariable:
1809 if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLag() <= 1)
1810 return const_cast<VariableNode *>(this);
1811 else
1812 return value->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
1813 default:
1814 return const_cast<VariableNode *>(this);
1815 }
1816 }
1817
1818 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const1819 VariableNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
1820 {
1821 switch (get_type())
1822 {
1823 case SymbolType::exogenous:
1824 if (lag <= 0)
1825 return const_cast<VariableNode *>(this);
1826 else
1827 return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
1828 case SymbolType::modelLocalVariable:
1829 if (expr_t value = datatree.getLocalVariable(symb_id); value->maxExoLead() == 0)
1830 return const_cast<VariableNode *>(this);
1831 else
1832 return value->substituteExoLead(subst_table, neweqs, deterministic_model);
1833 default:
1834 return const_cast<VariableNode *>(this);
1835 }
1836 }
1837
1838 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1839 VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1840 {
1841 VariableNode *substexpr;
1842 int cur_lag;
1843 switch (get_type())
1844 {
1845 case SymbolType::exogenous:
1846 if (lag >= 0)
1847 return const_cast<VariableNode *>(this);
1848
1849 if (auto it = subst_table.find(this); it != subst_table.end())
1850 return const_cast<VariableNode *>(it->second);
1851
1852 substexpr = datatree.AddVariable(symb_id, 0);
1853 cur_lag = -1;
1854
1855 // Each iteration tries to create an auxvar such that auxvar(-1)=curvar(cur_lag)
1856 // At the beginning (resp. end) of each iteration, substexpr is an expression (possibly an auxvar) equivalent to curvar(cur_lag+1) (resp. curvar(cur_lag))
1857 while (cur_lag >= lag)
1858 {
1859 VariableNode *orig_expr = datatree.AddVariable(symb_id, cur_lag);
1860 if (auto it = subst_table.find(orig_expr); it == subst_table.end())
1861 {
1862 int aux_symb_id = datatree.symbol_table.addExoLagAuxiliaryVar(symb_id, cur_lag+1, substexpr);
1863 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), substexpr)));
1864 substexpr = datatree.AddVariable(aux_symb_id, -1);
1865 subst_table[orig_expr] = substexpr;
1866 }
1867 else
1868 substexpr = const_cast<VariableNode *>(it->second);
1869
1870 cur_lag--;
1871 }
1872 return substexpr;
1873
1874 case SymbolType::modelLocalVariable:
1875 if (expr_t value = datatree.getLocalVariable(symb_id); value->maxExoLag() == 0)
1876 return const_cast<VariableNode *>(this);
1877 else
1878 return value->substituteExoLag(subst_table, neweqs);
1879 default:
1880 return const_cast<VariableNode *>(this);
1881 }
1882 }
1883
1884 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const1885 VariableNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
1886 {
1887 if (get_type() == SymbolType::modelLocalVariable)
1888 return datatree.getLocalVariable(symb_id)->substituteExpectation(subst_table, neweqs, partial_information_model);
1889
1890 return const_cast<VariableNode *>(this);
1891 }
1892
1893 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const1894 VariableNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
1895 {
1896 switch (get_type())
1897 {
1898 case SymbolType::endogenous:
1899 assert(lag <= 1);
1900 if (lag <= 0
1901 || (subset.size() > 0
1902 && find(subset.begin(), subset.end(), datatree.symbol_table.getName(symb_id)) == subset.end()))
1903 return const_cast<VariableNode *>(this);
1904 else
1905 {
1906 VariableNode *diffvar;
1907 if (auto it = subst_table.find(this); it != subst_table.end())
1908 diffvar = const_cast<VariableNode *>(it->second);
1909 else
1910 {
1911 int aux_symb_id = datatree.symbol_table.addDiffForwardAuxiliaryVar(symb_id, datatree.AddMinus(datatree.AddVariable(symb_id, 0),
1912 datatree.AddVariable(symb_id, -1)));
1913 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(datatree.AddVariable(aux_symb_id, 0), datatree.AddMinus(datatree.AddVariable(symb_id, 0),
1914 datatree.AddVariable(symb_id, -1)))));
1915 diffvar = datatree.AddVariable(aux_symb_id, 1);
1916 subst_table[this] = diffvar;
1917 }
1918 return datatree.AddPlus(datatree.AddVariable(symb_id, 0), diffvar);
1919 }
1920 case SymbolType::modelLocalVariable:
1921 if (expr_t value = datatree.getLocalVariable(symb_id); value->maxEndoLead() <= 0)
1922 return const_cast<VariableNode *>(this);
1923 else
1924 return value->differentiateForwardVars(subset, subst_table, neweqs);
1925 default:
1926 return const_cast<VariableNode *>(this);
1927 }
1928 }
1929
1930 bool
isNumConstNodeEqualTo(double value) const1931 VariableNode::isNumConstNodeEqualTo(double value) const
1932 {
1933 return false;
1934 }
1935
1936 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const1937 VariableNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
1938 {
1939 if (get_type() == type_arg && datatree.symbol_table.getTypeSpecificID(symb_id) == variable_id && lag == lag_arg)
1940 return true;
1941 else
1942 return false;
1943 }
1944
1945 bool
containsPacExpectation(const string & pac_model_name) const1946 VariableNode::containsPacExpectation(const string &pac_model_name) const
1947 {
1948 if (get_type() == SymbolType::modelLocalVariable)
1949 return datatree.getLocalVariable(symb_id)->containsPacExpectation(pac_model_name);
1950
1951 return false;
1952 }
1953
1954 bool
containsEndogenous() const1955 VariableNode::containsEndogenous() const
1956 {
1957 if (get_type() == SymbolType::modelLocalVariable)
1958 return datatree.getLocalVariable(symb_id)->containsEndogenous();
1959
1960 if (get_type() == SymbolType::endogenous)
1961 return true;
1962 else
1963 return false;
1964 }
1965
1966 bool
containsExogenous() const1967 VariableNode::containsExogenous() const
1968 {
1969 if (get_type() == SymbolType::modelLocalVariable)
1970 return datatree.getLocalVariable(symb_id)->containsExogenous();
1971
1972 return get_type() == SymbolType::exogenous || get_type() == SymbolType::exogenousDet;
1973 }
1974
1975 expr_t
replaceTrendVar() const1976 VariableNode::replaceTrendVar() const
1977 {
1978 if (get_type() == SymbolType::modelLocalVariable)
1979 return datatree.getLocalVariable(symb_id)->replaceTrendVar();
1980
1981 if (get_type() == SymbolType::trend)
1982 return datatree.One;
1983 else if (get_type() == SymbolType::logTrend)
1984 return datatree.Zero;
1985 else
1986 return const_cast<VariableNode *>(this);
1987 }
1988
1989 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const1990 VariableNode::detrend(int symb_id, bool log_trend, expr_t trend) const
1991 {
1992 if (get_type() == SymbolType::modelLocalVariable)
1993 return datatree.getLocalVariable(symb_id)->detrend(symb_id, log_trend, trend);
1994
1995 if (this->symb_id != symb_id)
1996 return const_cast<VariableNode *>(this);
1997
1998 if (log_trend)
1999 {
2000 if (lag == 0)
2001 return datatree.AddPlus(const_cast<VariableNode *>(this), trend);
2002 else
2003 return datatree.AddPlus(const_cast<VariableNode *>(this), trend->decreaseLeadsLags(-lag));
2004 }
2005 else
2006 {
2007 if (lag == 0)
2008 return datatree.AddTimes(const_cast<VariableNode *>(this), trend);
2009 else
2010 return datatree.AddTimes(const_cast<VariableNode *>(this), trend->decreaseLeadsLags(-lag));
2011 }
2012 }
2013
2014 int
countDiffs() const2015 VariableNode::countDiffs() const
2016 {
2017 if (get_type() == SymbolType::modelLocalVariable)
2018 return datatree.getLocalVariable(symb_id)->countDiffs();
2019
2020 return 0;
2021 }
2022
2023 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const2024 VariableNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
2025 {
2026 if (get_type() == SymbolType::modelLocalVariable)
2027 return datatree.getLocalVariable(symb_id)->removeTrendLeadLag(trend_symbols_map);
2028
2029 if ((get_type() != SymbolType::trend && get_type() != SymbolType::logTrend) || lag == 0)
2030 return const_cast<VariableNode *>(this);
2031
2032 auto it = trend_symbols_map.find(symb_id);
2033 expr_t noTrendLeadLagNode = datatree.AddVariable(it->first);
2034 bool log_trend = get_type() == SymbolType::logTrend;
2035 expr_t trend = it->second;
2036
2037 if (lag > 0)
2038 {
2039 expr_t growthFactorSequence = trend->decreaseLeadsLags(-1);
2040 if (log_trend)
2041 {
2042 for (int i = 1; i < lag; i++)
2043 growthFactorSequence = datatree.AddPlus(growthFactorSequence, trend->decreaseLeadsLags(-1*(i+1)));
2044 return datatree.AddPlus(noTrendLeadLagNode, growthFactorSequence);
2045 }
2046 else
2047 {
2048 for (int i = 1; i < lag; i++)
2049 growthFactorSequence = datatree.AddTimes(growthFactorSequence, trend->decreaseLeadsLags(-1*(i+1)));
2050 return datatree.AddTimes(noTrendLeadLagNode, growthFactorSequence);
2051 }
2052 }
2053 else //get_lag < 0
2054 {
2055 expr_t growthFactorSequence = trend;
2056 if (log_trend)
2057 {
2058 for (int i = 1; i < abs(lag); i++)
2059 growthFactorSequence = datatree.AddPlus(growthFactorSequence, trend->decreaseLeadsLags(i));
2060 return datatree.AddMinus(noTrendLeadLagNode, growthFactorSequence);
2061 }
2062 else
2063 {
2064 for (int i = 1; i < abs(lag); i++)
2065 growthFactorSequence = datatree.AddTimes(growthFactorSequence, trend->decreaseLeadsLags(i));
2066 return datatree.AddDivide(noTrendLeadLagNode, growthFactorSequence);
2067 }
2068 }
2069 }
2070
2071 bool
isInStaticForm() const2072 VariableNode::isInStaticForm() const
2073 {
2074 if (get_type() == SymbolType::modelLocalVariable)
2075 return datatree.getLocalVariable(symb_id)->isInStaticForm();
2076
2077 return lag == 0;
2078 }
2079
2080 bool
isParamTimesEndogExpr() const2081 VariableNode::isParamTimesEndogExpr() const
2082 {
2083 if (get_type() == SymbolType::modelLocalVariable)
2084 return datatree.getLocalVariable(symb_id)->isParamTimesEndogExpr();
2085
2086 return false;
2087 }
2088
2089 bool
isVarModelReferenced(const string & model_info_name) const2090 VariableNode::isVarModelReferenced(const string &model_info_name) const
2091 {
2092 if (get_type() == SymbolType::modelLocalVariable)
2093 return datatree.getLocalVariable(symb_id)->isVarModelReferenced(model_info_name);
2094
2095 return false;
2096 }
2097
2098 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const2099 VariableNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
2100 {
2101 if (get_type() == SymbolType::modelLocalVariable)
2102 return datatree.getLocalVariable(symb_id)->getEndosAndMaxLags(model_endos_and_lags);
2103
2104 if (get_type() == SymbolType::endogenous)
2105 if (string varname = datatree.symbol_table.getName(symb_id);
2106 model_endos_and_lags.find(varname) == model_endos_and_lags.end())
2107 model_endos_and_lags[varname] = min(model_endos_and_lags[varname], lag);
2108 else
2109 model_endos_and_lags[varname] = lag;
2110 }
2111
2112 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const2113 VariableNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
2114 {
2115 return;
2116 }
2117
2118 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const2119 VariableNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
2120 {
2121 /* Do not recurse into model-local variables definitions, since MLVs are
2122 already handled by ModelTree::simplifyEquations().
2123 This is also necessary because of #65. */
2124 for (auto &it : table)
2125 if (it.first->symb_id == symb_id)
2126 return it.second;
2127 return const_cast<VariableNode *>(this);
2128 }
2129
UnaryOpNode(DataTree & datatree_arg,int idx_arg,UnaryOpcode op_code_arg,const expr_t arg_arg,int expectation_information_set_arg,int param1_symb_id_arg,int param2_symb_id_arg,string adl_param_name_arg,vector<int> adl_lags_arg)2130 UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg) :
2131 ExprNode{datatree_arg, idx_arg},
2132 arg{arg_arg},
2133 expectation_information_set{expectation_information_set_arg},
2134 param1_symb_id{param1_symb_id_arg},
2135 param2_symb_id{param2_symb_id_arg},
2136 op_code{op_code_arg},
2137 adl_param_name{move(adl_param_name_arg)},
2138 adl_lags{move(adl_lags_arg)}
2139 {
2140 }
2141
2142 void
prepareForDerivation()2143 UnaryOpNode::prepareForDerivation()
2144 {
2145 if (preparedForDerivation)
2146 return;
2147
2148 preparedForDerivation = true;
2149
2150 arg->prepareForDerivation();
2151
2152 // Non-null derivatives are those of the argument (except for STEADY_STATE)
2153 non_null_derivatives = arg->non_null_derivatives;
2154 if (op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv
2155 || op_code == UnaryOpcode::steadyStateParam2ndDeriv)
2156 datatree.addAllParamDerivId(non_null_derivatives);
2157 }
2158
2159 expr_t
composeDerivatives(expr_t darg,int deriv_id)2160 UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
2161 {
2162 expr_t t11, t12, t13, t14;
2163
2164 switch (op_code)
2165 {
2166 case UnaryOpcode::uminus:
2167 return datatree.AddUMinus(darg);
2168 case UnaryOpcode::exp:
2169 return datatree.AddTimes(darg, this);
2170 case UnaryOpcode::log:
2171 return datatree.AddDivide(darg, arg);
2172 case UnaryOpcode::log10:
2173 t11 = datatree.AddExp(datatree.One);
2174 t12 = datatree.AddLog10(t11);
2175 t13 = datatree.AddDivide(darg, arg);
2176 return datatree.AddTimes(t12, t13);
2177 case UnaryOpcode::cos:
2178 t11 = datatree.AddSin(arg);
2179 t12 = datatree.AddUMinus(t11);
2180 return datatree.AddTimes(darg, t12);
2181 case UnaryOpcode::sin:
2182 t11 = datatree.AddCos(arg);
2183 return datatree.AddTimes(darg, t11);
2184 case UnaryOpcode::tan:
2185 t11 = datatree.AddTimes(this, this);
2186 t12 = datatree.AddPlus(t11, datatree.One);
2187 return datatree.AddTimes(darg, t12);
2188 case UnaryOpcode::acos:
2189 t11 = datatree.AddSin(this);
2190 t12 = datatree.AddDivide(darg, t11);
2191 return datatree.AddUMinus(t12);
2192 case UnaryOpcode::asin:
2193 t11 = datatree.AddCos(this);
2194 return datatree.AddDivide(darg, t11);
2195 case UnaryOpcode::atan:
2196 t11 = datatree.AddTimes(arg, arg);
2197 t12 = datatree.AddPlus(datatree.One, t11);
2198 return datatree.AddDivide(darg, t12);
2199 case UnaryOpcode::cosh:
2200 t11 = datatree.AddSinh(arg);
2201 return datatree.AddTimes(darg, t11);
2202 case UnaryOpcode::sinh:
2203 t11 = datatree.AddCosh(arg);
2204 return datatree.AddTimes(darg, t11);
2205 case UnaryOpcode::tanh:
2206 t11 = datatree.AddTimes(this, this);
2207 t12 = datatree.AddMinus(datatree.One, t11);
2208 return datatree.AddTimes(darg, t12);
2209 case UnaryOpcode::acosh:
2210 t11 = datatree.AddSinh(this);
2211 return datatree.AddDivide(darg, t11);
2212 case UnaryOpcode::asinh:
2213 t11 = datatree.AddCosh(this);
2214 return datatree.AddDivide(darg, t11);
2215 case UnaryOpcode::atanh:
2216 t11 = datatree.AddTimes(arg, arg);
2217 t12 = datatree.AddMinus(datatree.One, t11);
2218 return datatree.AddTimes(darg, t12);
2219 case UnaryOpcode::sqrt:
2220 t11 = datatree.AddPlus(this, this);
2221 return datatree.AddDivide(darg, t11);
2222 case UnaryOpcode::cbrt:
2223 t11 = datatree.AddPower(arg, datatree.AddDivide(datatree.Two, datatree.Three));
2224 t12 = datatree.AddTimes(datatree.Three, t11);
2225 return datatree.AddDivide(darg, t12);
2226 case UnaryOpcode::abs:
2227 t11 = datatree.AddSign(arg);
2228 return datatree.AddTimes(t11, darg);
2229 case UnaryOpcode::sign:
2230 return datatree.Zero;
2231 case UnaryOpcode::steadyState:
2232 if (datatree.isDynamic())
2233 {
2234 if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
2235 {
2236 auto varg = dynamic_cast<VariableNode *>(arg);
2237 if (!varg)
2238 {
2239 cerr << "UnaryOpNode::composeDerivatives: STEADY_STATE() should only be used on "
2240 << "standalone variables (like STEADY_STATE(y)) to be derivable w.r.t. parameters" << endl;
2241 exit(EXIT_FAILURE);
2242 }
2243 if (datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous)
2244 return datatree.AddSteadyStateParamDeriv(arg, datatree.getSymbIDByDerivID(deriv_id));
2245 else
2246 return datatree.Zero;
2247 }
2248 else
2249 return datatree.Zero;
2250 }
2251 else
2252 return darg;
2253 case UnaryOpcode::steadyStateParamDeriv:
2254 assert(datatree.isDynamic());
2255 if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
2256 {
2257 auto varg = dynamic_cast<VariableNode *>(arg);
2258 assert(varg);
2259 assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2260 return datatree.AddSteadyStateParam2ndDeriv(arg, param1_symb_id, datatree.getSymbIDByDerivID(deriv_id));
2261 }
2262 else
2263 return datatree.Zero;
2264 case UnaryOpcode::steadyStateParam2ndDeriv:
2265 assert(datatree.isDynamic());
2266 if (datatree.getTypeByDerivID(deriv_id) == SymbolType::parameter)
2267 {
2268 cerr << "3rd derivative of STEADY_STATE node w.r.t. three parameters not implemented" << endl;
2269 exit(EXIT_FAILURE);
2270 }
2271 else
2272 return datatree.Zero;
2273 case UnaryOpcode::expectation:
2274 cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::expectation" << endl;
2275 exit(EXIT_FAILURE);
2276 case UnaryOpcode::erf:
2277 // x^2
2278 t11 = datatree.AddPower(arg, datatree.Two);
2279 // exp(x^2)
2280 t12 = datatree.AddExp(t11);
2281 // sqrt(pi)
2282 t11 = datatree.AddSqrt(datatree.Pi);
2283 // sqrt(pi)*exp(x^2)
2284 t13 = datatree.AddTimes(t11, t12);
2285 // 2/(sqrt(pi)*exp(x^2));
2286 t14 = datatree.AddDivide(datatree.Two, t13);
2287 // (2/(sqrt(pi)*exp(x^2)))*dx;
2288 return datatree.AddTimes(t14, darg);
2289 case UnaryOpcode::diff:
2290 cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::diff" << endl;
2291 exit(EXIT_FAILURE);
2292 case UnaryOpcode::adl:
2293 cerr << "UnaryOpNode::composeDerivatives: not implemented on UnaryOpcode::adl" << endl;
2294 exit(EXIT_FAILURE);
2295 }
2296 // Suppress GCC warning
2297 exit(EXIT_FAILURE);
2298 }
2299
2300 expr_t
computeDerivative(int deriv_id)2301 UnaryOpNode::computeDerivative(int deriv_id)
2302 {
2303 expr_t darg = arg->getDerivative(deriv_id);
2304 return composeDerivatives(darg, deriv_id);
2305 }
2306
2307 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const2308 UnaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
2309 {
2310 // For a temporary term, the cost is null
2311 for (const auto &it : temp_terms_map)
2312 if (it.second.find(const_cast<UnaryOpNode *>(this)) != it.second.end())
2313 return 0;
2314
2315 return cost(arg->cost(temp_terms_map, is_matlab), is_matlab);
2316 }
2317
2318 int
cost(const temporary_terms_t & temporary_terms,bool is_matlab) const2319 UnaryOpNode::cost(const temporary_terms_t &temporary_terms, bool is_matlab) const
2320 {
2321 // For a temporary term, the cost is null
2322 if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
2323 return 0;
2324
2325 return cost(arg->cost(temporary_terms, is_matlab), is_matlab);
2326 }
2327
2328 int
cost(int cost,bool is_matlab) const2329 UnaryOpNode::cost(int cost, bool is_matlab) const
2330 {
2331 if (is_matlab)
2332 // Cost for Matlab files
2333 switch (op_code)
2334 {
2335 case UnaryOpcode::uminus:
2336 case UnaryOpcode::sign:
2337 return cost + 70;
2338 case UnaryOpcode::exp:
2339 return cost + 160;
2340 case UnaryOpcode::log:
2341 return cost + 300;
2342 case UnaryOpcode::log10:
2343 case UnaryOpcode::erf:
2344 return cost + 16000;
2345 case UnaryOpcode::cos:
2346 case UnaryOpcode::sin:
2347 case UnaryOpcode::cosh:
2348 return cost + 210;
2349 case UnaryOpcode::tan:
2350 return cost + 230;
2351 case UnaryOpcode::acos:
2352 return cost + 300;
2353 case UnaryOpcode::asin:
2354 return cost + 310;
2355 case UnaryOpcode::atan:
2356 return cost + 140;
2357 case UnaryOpcode::sinh:
2358 return cost + 240;
2359 case UnaryOpcode::tanh:
2360 return cost + 190;
2361 case UnaryOpcode::acosh:
2362 return cost + 770;
2363 case UnaryOpcode::asinh:
2364 return cost + 460;
2365 case UnaryOpcode::atanh:
2366 return cost + 350;
2367 case UnaryOpcode::sqrt:
2368 case UnaryOpcode::cbrt:
2369 case UnaryOpcode::abs:
2370 return cost + 570;
2371 case UnaryOpcode::steadyState:
2372 case UnaryOpcode::steadyStateParamDeriv:
2373 case UnaryOpcode::steadyStateParam2ndDeriv:
2374 case UnaryOpcode::expectation:
2375 return cost;
2376 case UnaryOpcode::diff:
2377 cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::diff" << endl;
2378 exit(EXIT_FAILURE);
2379 case UnaryOpcode::adl:
2380 cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::adl" << endl;
2381 exit(EXIT_FAILURE);
2382 }
2383 else
2384 // Cost for C files
2385 switch (op_code)
2386 {
2387 case UnaryOpcode::uminus:
2388 case UnaryOpcode::sign:
2389 return cost + 3;
2390 case UnaryOpcode::exp:
2391 case UnaryOpcode::acosh:
2392 return cost + 210;
2393 case UnaryOpcode::log:
2394 return cost + 137;
2395 case UnaryOpcode::log10:
2396 return cost + 139;
2397 case UnaryOpcode::cos:
2398 case UnaryOpcode::sin:
2399 return cost + 160;
2400 case UnaryOpcode::tan:
2401 return cost + 170;
2402 case UnaryOpcode::acos:
2403 case UnaryOpcode::atan:
2404 return cost + 190;
2405 case UnaryOpcode::asin:
2406 return cost + 180;
2407 case UnaryOpcode::cosh:
2408 case UnaryOpcode::sinh:
2409 case UnaryOpcode::tanh:
2410 case UnaryOpcode::erf:
2411 return cost + 240;
2412 case UnaryOpcode::asinh:
2413 return cost + 220;
2414 case UnaryOpcode::atanh:
2415 return cost + 150;
2416 case UnaryOpcode::sqrt:
2417 case UnaryOpcode::cbrt:
2418 case UnaryOpcode::abs:
2419 return cost + 90;
2420 case UnaryOpcode::steadyState:
2421 case UnaryOpcode::steadyStateParamDeriv:
2422 case UnaryOpcode::steadyStateParam2ndDeriv:
2423 case UnaryOpcode::expectation:
2424 return cost;
2425 case UnaryOpcode::diff:
2426 cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::diff" << endl;
2427 exit(EXIT_FAILURE);
2428 case UnaryOpcode::adl:
2429 cerr << "UnaryOpNode::cost: not implemented on UnaryOpcode::adl" << endl;
2430 exit(EXIT_FAILURE);
2431 }
2432 exit(EXIT_FAILURE);
2433 }
2434
2435 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const2436 UnaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
2437 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
2438 map<expr_t, pair<int, pair<int, int>>> &reference_count,
2439 bool is_matlab) const
2440 {
2441 expr_t this2 = const_cast<UnaryOpNode *>(this);
2442 if (auto it = reference_count.find(this2);
2443 it == reference_count.end())
2444 {
2445 reference_count[this2] = { 1, derivOrder };
2446 arg->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
2447 }
2448 else
2449 {
2450 reference_count[this2] = { it->second.first + 1, it->second.second };
2451 if (reference_count[this2].first * cost(temp_terms_map, is_matlab) > min_cost(is_matlab))
2452 temp_terms_map[reference_count[this2].second].insert(this2);
2453 }
2454 }
2455
2456 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const2457 UnaryOpNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
2458 temporary_terms_t &temporary_terms,
2459 map<expr_t, pair<int, int>> &first_occurence,
2460 int Curr_block,
2461 vector< vector<temporary_terms_t>> &v_temporary_terms,
2462 int equation) const
2463 {
2464 expr_t this2 = const_cast<UnaryOpNode *>(this);
2465 if (auto it = reference_count.find(this2);
2466 it == reference_count.end())
2467 {
2468 reference_count[this2] = 1;
2469 first_occurence[this2] = { Curr_block, equation };
2470 arg->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
2471 }
2472 else
2473 {
2474 reference_count[this2]++;
2475 if (reference_count[this2] * cost(temporary_terms, false) > min_cost_c)
2476 {
2477 temporary_terms.insert(this2);
2478 v_temporary_terms[first_occurence[this2].first][first_occurence[this2].second].insert(this2);
2479 }
2480 }
2481 }
2482
2483 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const2484 UnaryOpNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
2485 {
2486 if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
2487 temporary_terms_inuse.insert(idx);
2488 else
2489 arg->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
2490 }
2491
2492 bool
containsExternalFunction() const2493 UnaryOpNode::containsExternalFunction() const
2494 {
2495 return arg->containsExternalFunction();
2496 }
2497
2498 void
writeJsonAST(ostream & output) const2499 UnaryOpNode::writeJsonAST(ostream &output) const
2500 {
2501 output << R"({"node_type" : "UnaryOpNode", "op" : ")";
2502 switch (op_code)
2503 {
2504 case UnaryOpcode::uminus:
2505 output << "uminus";
2506 break;
2507 case UnaryOpcode::exp:
2508 output << "exp";
2509 break;
2510 case UnaryOpcode::log:
2511 output << "log";
2512 break;
2513 case UnaryOpcode::log10:
2514 output << "log10";
2515 break;
2516 case UnaryOpcode::cos:
2517 output << "cos";
2518 break;
2519 case UnaryOpcode::sin:
2520 output << "sin";
2521 break;
2522 case UnaryOpcode::tan:
2523 output << "tan";
2524 break;
2525 case UnaryOpcode::acos:
2526 output << "acos";
2527 break;
2528 case UnaryOpcode::asin:
2529 output << "asin";
2530 break;
2531 case UnaryOpcode::atan:
2532 output << "atan";
2533 break;
2534 case UnaryOpcode::cosh:
2535 output << "cosh";
2536 break;
2537 case UnaryOpcode::sinh:
2538 output << "sinh";
2539 break;
2540 case UnaryOpcode::tanh:
2541 output << "tanh";
2542 break;
2543 case UnaryOpcode::acosh:
2544 output << "acosh";
2545 break;
2546 case UnaryOpcode::asinh:
2547 output << "asinh";
2548 break;
2549 case UnaryOpcode::atanh:
2550 output << "atanh";
2551 break;
2552 case UnaryOpcode::sqrt:
2553 output << "sqrt";
2554 break;
2555 case UnaryOpcode::cbrt:
2556 output << "cbrt";
2557 break;
2558 case UnaryOpcode::abs:
2559 output << "abs";
2560 break;
2561 case UnaryOpcode::sign:
2562 output << "sign";
2563 break;
2564 case UnaryOpcode::diff:
2565 output << "diff";
2566 break;
2567 case UnaryOpcode::adl:
2568 output << "adl";
2569 break;
2570 case UnaryOpcode::steadyState:
2571 output << "steady_state";
2572 case UnaryOpcode::steadyStateParamDeriv:
2573 output << "steady_state_param_deriv";
2574 break;
2575 case UnaryOpcode::steadyStateParam2ndDeriv:
2576 output << "steady_state_param_second_deriv";
2577 break;
2578 case UnaryOpcode::expectation:
2579 output << "expectation";
2580 break;
2581 case UnaryOpcode::erf:
2582 output << "erf";
2583 break;
2584 }
2585 output << R"(", "arg" : )";
2586 arg->writeJsonAST(output);
2587 switch (op_code)
2588 {
2589 case UnaryOpcode::adl:
2590 output << R"(, "adl_param_name" : ")" << adl_param_name << R"(")"
2591 << R"(, "lags" : [)";
2592 for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
2593 {
2594 if (it != adl_lags.begin())
2595 output << ", ";
2596 output << *it;
2597 }
2598 output << "]";
2599 break;
2600 default:
2601 break;
2602 }
2603 output << "}";
2604 }
2605
2606 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const2607 UnaryOpNode::writeJsonOutput(ostream &output,
2608 const temporary_terms_t &temporary_terms,
2609 const deriv_node_temp_terms_t &tef_terms,
2610 bool isdynamic) const
2611 {
2612 if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
2613 {
2614 output << "T" << idx;
2615 return;
2616 }
2617
2618 // Always put parenthesis around uminus nodes
2619 if (op_code == UnaryOpcode::uminus)
2620 output << "(";
2621
2622 switch (op_code)
2623 {
2624 case UnaryOpcode::uminus:
2625 output << "-";
2626 break;
2627 case UnaryOpcode::exp:
2628 output << "exp";
2629 break;
2630 case UnaryOpcode::log:
2631 output << "log";
2632 break;
2633 case UnaryOpcode::log10:
2634 output << "log10";
2635 break;
2636 case UnaryOpcode::cos:
2637 output << "cos";
2638 break;
2639 case UnaryOpcode::sin:
2640 output << "sin";
2641 break;
2642 case UnaryOpcode::tan:
2643 output << "tan";
2644 break;
2645 case UnaryOpcode::acos:
2646 output << "acos";
2647 break;
2648 case UnaryOpcode::asin:
2649 output << "asin";
2650 break;
2651 case UnaryOpcode::atan:
2652 output << "atan";
2653 break;
2654 case UnaryOpcode::cosh:
2655 output << "cosh";
2656 break;
2657 case UnaryOpcode::sinh:
2658 output << "sinh";
2659 break;
2660 case UnaryOpcode::tanh:
2661 output << "tanh";
2662 break;
2663 case UnaryOpcode::acosh:
2664 output << "acosh";
2665 break;
2666 case UnaryOpcode::asinh:
2667 output << "asinh";
2668 break;
2669 case UnaryOpcode::atanh:
2670 output << "atanh";
2671 break;
2672 case UnaryOpcode::sqrt:
2673 output << "sqrt";
2674 break;
2675 case UnaryOpcode::cbrt:
2676 output << "cbrt";
2677 break;
2678 case UnaryOpcode::abs:
2679 output << "abs";
2680 break;
2681 case UnaryOpcode::sign:
2682 output << "sign";
2683 break;
2684 case UnaryOpcode::diff:
2685 output << "diff";
2686 break;
2687 case UnaryOpcode::adl:
2688 output << "adl(";
2689 arg->writeJsonOutput(output, temporary_terms, tef_terms);
2690 output << ", '" << adl_param_name << "', [";
2691 for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
2692 {
2693 if (it != adl_lags.begin())
2694 output << ", ";
2695 output << *it;
2696 }
2697 output << "])";
2698 return;
2699 case UnaryOpcode::steadyState:
2700 output << "(";
2701 arg->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
2702 output << ")";
2703 return;
2704 case UnaryOpcode::steadyStateParamDeriv:
2705 {
2706 auto varg = dynamic_cast<VariableNode *>(arg);
2707 assert(varg);
2708 assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2709 assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
2710 int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
2711 int tsid_param = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
2712 output << "ss_param_deriv(" << tsid_endo+1 << "," << tsid_param+1 << ")";
2713 }
2714 return;
2715 case UnaryOpcode::steadyStateParam2ndDeriv:
2716 {
2717 auto varg = dynamic_cast<VariableNode *>(arg);
2718 assert(varg);
2719 assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2720 assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
2721 assert(datatree.symbol_table.getType(param2_symb_id) == SymbolType::parameter);
2722 int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
2723 int tsid_param1 = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
2724 int tsid_param2 = datatree.symbol_table.getTypeSpecificID(param2_symb_id);
2725 output << "ss_param_2nd_deriv(" << tsid_endo+1 << "," << tsid_param1+1
2726 << "," << tsid_param2+1 << ")";
2727 }
2728 return;
2729 case UnaryOpcode::expectation:
2730 output << "EXPECTATION(" << expectation_information_set << ")";
2731 break;
2732 case UnaryOpcode::erf:
2733 output << "erf";
2734 break;
2735 }
2736
2737 bool close_parenthesis = false;
2738
2739 /* Enclose argument with parentheses if:
2740 - current opcode is not uminus, or
2741 - current opcode is uminus and argument has lowest precedence
2742 */
2743 if (op_code != UnaryOpcode::uminus
2744 || (op_code == UnaryOpcode::uminus
2745 && arg->precedenceJson(temporary_terms) < precedenceJson(temporary_terms)))
2746 {
2747 output << "(";
2748 close_parenthesis = true;
2749 }
2750
2751 // Write argument
2752 arg->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
2753
2754 if (close_parenthesis)
2755 output << ")";
2756
2757 // Close parenthesis for uminus
2758 if (op_code == UnaryOpcode::uminus)
2759 output << ")";
2760 }
2761
2762 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const2763 UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
2764 const temporary_terms_t &temporary_terms,
2765 const temporary_terms_idxs_t &temporary_terms_idxs,
2766 const deriv_node_temp_terms_t &tef_terms) const
2767 {
2768 if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
2769 return;
2770
2771 // Always put parenthesis around uminus nodes
2772 if (op_code == UnaryOpcode::uminus)
2773 output << LEFT_PAR(output_type);
2774
2775 switch (op_code)
2776 {
2777 case UnaryOpcode::uminus:
2778 output << "-";
2779 break;
2780 case UnaryOpcode::exp:
2781 if (isLatexOutput(output_type))
2782 output << R"(\exp)";
2783 else
2784 output << "exp";
2785 break;
2786 case UnaryOpcode::log:
2787 if (isLatexOutput(output_type))
2788 output << R"(\log)";
2789 else
2790 output << "log";
2791 break;
2792 case UnaryOpcode::log10:
2793 if (isLatexOutput(output_type))
2794 output << R"(\log_{10})";
2795 else
2796 output << "log10";
2797 break;
2798 case UnaryOpcode::cos:
2799 if (isLatexOutput(output_type))
2800 output << R"(\cos)";
2801 else
2802 output << "cos";
2803 break;
2804 case UnaryOpcode::sin:
2805 if (isLatexOutput(output_type))
2806 output << R"(\sin)";
2807 else
2808 output << "sin";
2809 break;
2810 case UnaryOpcode::tan:
2811 if (isLatexOutput(output_type))
2812 output << R"(\tan)";
2813 else
2814 output << "tan";
2815 break;
2816 case UnaryOpcode::acos:
2817 output << "acos";
2818 break;
2819 case UnaryOpcode::asin:
2820 output << "asin";
2821 break;
2822 case UnaryOpcode::atan:
2823 output << "atan";
2824 break;
2825 case UnaryOpcode::cosh:
2826 output << "cosh";
2827 break;
2828 case UnaryOpcode::sinh:
2829 output << "sinh";
2830 break;
2831 case UnaryOpcode::tanh:
2832 output << "tanh";
2833 break;
2834 case UnaryOpcode::acosh:
2835 output << "acosh";
2836 break;
2837 case UnaryOpcode::asinh:
2838 output << "asinh";
2839 break;
2840 case UnaryOpcode::atanh:
2841 output << "atanh";
2842 break;
2843 case UnaryOpcode::sqrt:
2844 if (isLatexOutput(output_type))
2845 {
2846 output << R"(\sqrt{)";
2847 arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
2848 output << "}";
2849 return;
2850 }
2851 output << "sqrt";
2852 break;
2853 case UnaryOpcode::cbrt:
2854 if (isMatlabOutput(output_type))
2855 {
2856 output << "nthroot(";
2857 arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
2858 output << ", 3)";
2859 return;
2860 }
2861 else if (isLatexOutput(output_type))
2862 {
2863 output << R"(\sqrt[3]{)";
2864 arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
2865 output << "}";
2866 return;
2867 }
2868 else
2869 output << "cbrt";
2870 break;
2871 case UnaryOpcode::abs:
2872 output << "abs";
2873 break;
2874 case UnaryOpcode::sign:
2875 if (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel)
2876 output << "copysign";
2877 else
2878 output << "sign";
2879 break;
2880 case UnaryOpcode::steadyState:
2881 ExprNodeOutputType new_output_type;
2882 switch (output_type)
2883 {
2884 case ExprNodeOutputType::matlabDynamicModel:
2885 new_output_type = ExprNodeOutputType::matlabDynamicSteadyStateOperator;
2886 break;
2887 case ExprNodeOutputType::latexDynamicModel:
2888 new_output_type = ExprNodeOutputType::latexDynamicSteadyStateOperator;
2889 break;
2890 case ExprNodeOutputType::CDynamicModel:
2891 new_output_type = ExprNodeOutputType::CDynamicSteadyStateOperator;
2892 break;
2893 case ExprNodeOutputType::juliaDynamicModel:
2894 new_output_type = ExprNodeOutputType::juliaDynamicSteadyStateOperator;
2895 break;
2896 case ExprNodeOutputType::matlabDynamicModelSparse:
2897 new_output_type = ExprNodeOutputType::matlabDynamicSparseSteadyStateOperator;
2898 break;
2899 default:
2900 new_output_type = output_type;
2901 break;
2902 }
2903 output << "(";
2904 arg->writeOutput(output, new_output_type, temporary_terms, temporary_terms_idxs, tef_terms);
2905 output << ")";
2906 return;
2907 case UnaryOpcode::steadyStateParamDeriv:
2908 {
2909 auto varg = dynamic_cast<VariableNode *>(arg);
2910 assert(varg);
2911 assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2912 assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
2913 int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
2914 int tsid_param = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
2915 assert(isMatlabOutput(output_type));
2916 output << "ss_param_deriv(" << tsid_endo+1 << "," << tsid_param+1 << ")";
2917 }
2918 return;
2919 case UnaryOpcode::steadyStateParam2ndDeriv:
2920 {
2921 auto varg = dynamic_cast<VariableNode *>(arg);
2922 assert(varg);
2923 assert(datatree.symbol_table.getType(varg->symb_id) == SymbolType::endogenous);
2924 assert(datatree.symbol_table.getType(param1_symb_id) == SymbolType::parameter);
2925 assert(datatree.symbol_table.getType(param2_symb_id) == SymbolType::parameter);
2926 int tsid_endo = datatree.symbol_table.getTypeSpecificID(varg->symb_id);
2927 int tsid_param1 = datatree.symbol_table.getTypeSpecificID(param1_symb_id);
2928 int tsid_param2 = datatree.symbol_table.getTypeSpecificID(param2_symb_id);
2929 assert(isMatlabOutput(output_type));
2930 output << "ss_param_2nd_deriv(" << tsid_endo+1 << "," << tsid_param1+1
2931 << "," << tsid_param2+1 << ")";
2932 }
2933 return;
2934 case UnaryOpcode::expectation:
2935 if (!isLatexOutput(output_type))
2936 {
2937 cerr << "UnaryOpNode::writeOutput: not implemented on UnaryOpcode::expectation" << endl;
2938 exit(EXIT_FAILURE);
2939 }
2940 output << R"(\mathbb{E}_{t)";
2941 if (expectation_information_set != 0)
2942 {
2943 if (expectation_information_set > 0)
2944 output << "+";
2945 output << expectation_information_set;
2946 }
2947 output << "}";
2948 break;
2949 case UnaryOpcode::erf:
2950 output << "erf";
2951 break;
2952 case UnaryOpcode::diff:
2953 output << "diff";
2954 break;
2955 case UnaryOpcode::adl:
2956 output << "adl";
2957 break;
2958 }
2959
2960 bool close_parenthesis = false;
2961
2962 /* Enclose argument with parentheses if:
2963 - current opcode is not uminus, or
2964 - current opcode is uminus and argument has lowest precedence
2965 */
2966 if (op_code != UnaryOpcode::uminus
2967 || (op_code == UnaryOpcode::uminus
2968 && arg->precedence(output_type, temporary_terms) < precedence(output_type, temporary_terms)))
2969 {
2970 output << LEFT_PAR(output_type);
2971 if (op_code == UnaryOpcode::sign && (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel))
2972 output << "1.0,";
2973 close_parenthesis = true;
2974 }
2975
2976 // Write argument
2977 arg->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
2978
2979 if (close_parenthesis)
2980 output << RIGHT_PAR(output_type);
2981
2982 // Close parenthesis for uminus
2983 if (op_code == UnaryOpcode::uminus)
2984 output << RIGHT_PAR(output_type);
2985 }
2986
2987 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const2988 UnaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
2989 const temporary_terms_t &temporary_terms,
2990 const temporary_terms_idxs_t &temporary_terms_idxs,
2991 deriv_node_temp_terms_t &tef_terms) const
2992 {
2993 arg->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
2994 }
2995
2996 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const2997 UnaryOpNode::writeJsonExternalFunctionOutput(vector<string> &efout,
2998 const temporary_terms_t &temporary_terms,
2999 deriv_node_temp_terms_t &tef_terms,
3000 bool isdynamic) const
3001 {
3002 arg->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
3003 }
3004
3005 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const3006 UnaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
3007 bool lhs_rhs, const temporary_terms_t &temporary_terms,
3008 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
3009 deriv_node_temp_terms_t &tef_terms) const
3010 {
3011 arg->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
3012 dynamic, steady_dynamic, tef_terms);
3013 }
3014
3015 double
eval_opcode(UnaryOpcode op_code,double v)3016 UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) noexcept(false)
3017 {
3018 switch (op_code)
3019 {
3020 case UnaryOpcode::uminus:
3021 return -v;
3022 case UnaryOpcode::exp:
3023 return exp(v);
3024 case UnaryOpcode::log:
3025 return log(v);
3026 case UnaryOpcode::log10:
3027 return log10(v);
3028 case UnaryOpcode::cos:
3029 return cos(v);
3030 case UnaryOpcode::sin:
3031 return sin(v);
3032 case UnaryOpcode::tan:
3033 return tan(v);
3034 case UnaryOpcode::acos:
3035 return acos(v);
3036 case UnaryOpcode::asin:
3037 return asin(v);
3038 case UnaryOpcode::atan:
3039 return atan(v);
3040 case UnaryOpcode::cosh:
3041 return cosh(v);
3042 case UnaryOpcode::sinh:
3043 return sinh(v);
3044 case UnaryOpcode::tanh:
3045 return tanh(v);
3046 case UnaryOpcode::acosh:
3047 return acosh(v);
3048 case UnaryOpcode::asinh:
3049 return asinh(v);
3050 case UnaryOpcode::atanh:
3051 return atanh(v);
3052 case UnaryOpcode::sqrt:
3053 return sqrt(v);
3054 case UnaryOpcode::cbrt:
3055 return cbrt(v);
3056 case UnaryOpcode::abs:
3057 return abs(v);
3058 case UnaryOpcode::sign:
3059 return (v > 0) ? 1 : ((v < 0) ? -1 : 0);
3060 case UnaryOpcode::steadyState:
3061 return v;
3062 case UnaryOpcode::steadyStateParamDeriv:
3063 cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::steadyStateParamDeriv" << endl;
3064 exit(EXIT_FAILURE);
3065 case UnaryOpcode::steadyStateParam2ndDeriv:
3066 cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::steadyStateParam2ndDeriv" << endl;
3067 exit(EXIT_FAILURE);
3068 case UnaryOpcode::expectation:
3069 cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::expectation" << endl;
3070 exit(EXIT_FAILURE);
3071 case UnaryOpcode::erf:
3072 return erf(v);
3073 case UnaryOpcode::diff:
3074 cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::diff" << endl;
3075 exit(EXIT_FAILURE);
3076 case UnaryOpcode::adl:
3077 cerr << "UnaryOpNode::eval_opcode: not implemented on UnaryOpcode::adl" << endl;
3078 exit(EXIT_FAILURE);
3079 }
3080 // Suppress GCC warning
3081 exit(EXIT_FAILURE);
3082 }
3083
3084 double
eval(const eval_context_t & eval_context) const3085 UnaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false)
3086 {
3087 double v = arg->eval(eval_context);
3088 return eval_opcode(op_code, v);
3089 }
3090
3091 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const3092 UnaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
3093 bool lhs_rhs, const temporary_terms_t &temporary_terms,
3094 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
3095 const deriv_node_temp_terms_t &tef_terms) const
3096 {
3097 if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
3098 {
3099 if (dynamic)
3100 {
3101 auto ii = map_idx.find(idx);
3102 FLDT_ fldt(ii->second);
3103 fldt.write(CompileCode, instruction_number);
3104 }
3105 else
3106 {
3107 auto ii = map_idx.find(idx);
3108 FLDST_ fldst(ii->second);
3109 fldst.write(CompileCode, instruction_number);
3110 }
3111 return;
3112 }
3113 if (op_code == UnaryOpcode::steadyState)
3114 arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, true, tef_terms);
3115 else
3116 {
3117 arg->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
3118 FUNARY_ funary{static_cast<uint8_t>(op_code)};
3119 funary.write(CompileCode, instruction_number);
3120 }
3121 }
3122
3123 void
collectVARLHSVariable(set<expr_t> & result) const3124 UnaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
3125 {
3126 if (op_code == UnaryOpcode::diff)
3127 result.insert(const_cast<UnaryOpNode *>(this));
3128 else
3129 arg->collectVARLHSVariable(result);
3130 }
3131
3132 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const3133 UnaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
3134 {
3135 arg->collectDynamicVariables(type_arg, result);
3136 }
3137
3138 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const3139 UnaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
3140 {
3141 pair<int, expr_t> res = arg->normalizeEquation(var_endo, List_of_Op_RHS);
3142 int is_endogenous_present = res.first;
3143 expr_t New_expr_t = res.second;
3144
3145 if (is_endogenous_present == 2) /* The equation could not be normalized and the process is given-up*/
3146 return { 2, nullptr };
3147 else if (is_endogenous_present) /* The argument of the function contains the current values of
3148 the endogenous variable associated to the equation.
3149 In order to normalized, we have to apply the invert function to the RHS.*/
3150 {
3151 switch (op_code)
3152 {
3153 case UnaryOpcode::uminus:
3154 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::uminus), nullptr, nullptr);
3155 return { 1, nullptr };
3156 case UnaryOpcode::exp:
3157 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::log), nullptr, nullptr);
3158 return { 1, nullptr };
3159 case UnaryOpcode::log:
3160 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::exp), nullptr, nullptr);
3161 return { 1, nullptr };
3162 case UnaryOpcode::log10:
3163 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.AddNonNegativeConstant("10"));
3164 return { 1, nullptr };
3165 case UnaryOpcode::cos:
3166 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::acos), nullptr, nullptr);
3167 return { 1, nullptr };
3168 case UnaryOpcode::sin:
3169 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::asin), nullptr, nullptr);
3170 return { 1, nullptr };
3171 case UnaryOpcode::tan:
3172 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::atan), nullptr, nullptr);
3173 return { 1, nullptr };
3174 case UnaryOpcode::acos:
3175 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::cos), nullptr, nullptr);
3176 return { 1, nullptr };
3177 case UnaryOpcode::asin:
3178 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::sin), nullptr, nullptr);
3179 return { 1, nullptr };
3180 case UnaryOpcode::atan:
3181 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::tan), nullptr, nullptr);
3182 return { 1, nullptr };
3183 case UnaryOpcode::cosh:
3184 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::acosh), nullptr, nullptr);
3185 return { 1, nullptr };
3186 case UnaryOpcode::sinh:
3187 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::asinh), nullptr, nullptr);
3188 return { 1, nullptr };
3189 case UnaryOpcode::tanh:
3190 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::atanh), nullptr, nullptr);
3191 return { 1, nullptr };
3192 case UnaryOpcode::acosh:
3193 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::cosh), nullptr, nullptr);
3194 return { 1, nullptr };
3195 case UnaryOpcode::asinh:
3196 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::sinh), nullptr, nullptr);
3197 return { 1, nullptr };
3198 case UnaryOpcode::atanh:
3199 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::tanh), nullptr, nullptr);
3200 return { 1, nullptr };
3201 case UnaryOpcode::sqrt:
3202 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Two);
3203 return { 1, nullptr };
3204 case UnaryOpcode::cbrt:
3205 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Three);
3206 return { 1, nullptr };
3207 case UnaryOpcode::abs:
3208 return { 2, nullptr };
3209 case UnaryOpcode::sign:
3210 return { 2, nullptr };
3211 case UnaryOpcode::steadyState:
3212 return { 2, nullptr };
3213 case UnaryOpcode::erf:
3214 return { 2, nullptr };
3215 default:
3216 cerr << "Unary operator not handled during the normalization process" << endl;
3217 return { 2, nullptr }; // Could not be normalized
3218 }
3219 }
3220 else
3221 { /* If the argument of the function do not contain the current values of the endogenous variable
3222 related to the equation, the function with its argument is stored in the RHS*/
3223 switch (op_code)
3224 {
3225 case UnaryOpcode::uminus:
3226 return { 0, datatree.AddUMinus(New_expr_t) };
3227 case UnaryOpcode::exp:
3228 return { 0, datatree.AddExp(New_expr_t) };
3229 case UnaryOpcode::log:
3230 return { 0, datatree.AddLog(New_expr_t) };
3231 case UnaryOpcode::log10:
3232 return { 0, datatree.AddLog10(New_expr_t) };
3233 case UnaryOpcode::cos:
3234 return { 0, datatree.AddCos(New_expr_t) };
3235 case UnaryOpcode::sin:
3236 return { 0, datatree.AddSin(New_expr_t) };
3237 case UnaryOpcode::tan:
3238 return { 0, datatree.AddTan(New_expr_t) };
3239 case UnaryOpcode::acos:
3240 return { 0, datatree.AddAcos(New_expr_t) };
3241 case UnaryOpcode::asin:
3242 return { 0, datatree.AddAsin(New_expr_t) };
3243 case UnaryOpcode::atan:
3244 return { 0, datatree.AddAtan(New_expr_t) };
3245 case UnaryOpcode::cosh:
3246 return { 0, datatree.AddCosh(New_expr_t) };
3247 case UnaryOpcode::sinh:
3248 return { 0, datatree.AddSinh(New_expr_t) };
3249 case UnaryOpcode::tanh:
3250 return { 0, datatree.AddTanh(New_expr_t) };
3251 case UnaryOpcode::acosh:
3252 return { 0, datatree.AddAcosh(New_expr_t) };
3253 case UnaryOpcode::asinh:
3254 return { 0, datatree.AddAsinh(New_expr_t) };
3255 case UnaryOpcode::atanh:
3256 return { 0, datatree.AddAtanh(New_expr_t) };
3257 case UnaryOpcode::sqrt:
3258 return { 0, datatree.AddSqrt(New_expr_t) };
3259 case UnaryOpcode::cbrt:
3260 return { 0, datatree.AddCbrt(New_expr_t) };
3261 case UnaryOpcode::abs:
3262 return { 0, datatree.AddAbs(New_expr_t) };
3263 case UnaryOpcode::sign:
3264 return { 0, datatree.AddSign(New_expr_t) };
3265 case UnaryOpcode::steadyState:
3266 return { 0, datatree.AddSteadyState(New_expr_t) };
3267 case UnaryOpcode::erf:
3268 return { 0, datatree.AddErf(New_expr_t) };
3269 default:
3270 cerr << "Unary operator not handled during the normalization process" << endl;
3271 return { 2, nullptr }; // Could not be normalized
3272 }
3273 }
3274 cerr << "UnaryOpNode::normalizeEquation: impossible case" << endl;
3275 exit(EXIT_FAILURE);
3276 }
3277
3278 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)3279 UnaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
3280 {
3281 expr_t darg = arg->getChainRuleDerivative(deriv_id, recursive_variables);
3282 return composeDerivatives(darg, deriv_id);
3283 }
3284
3285 expr_t
buildSimilarUnaryOpNode(expr_t alt_arg,DataTree & alt_datatree) const3286 UnaryOpNode::buildSimilarUnaryOpNode(expr_t alt_arg, DataTree &alt_datatree) const
3287 {
3288 switch (op_code)
3289 {
3290 case UnaryOpcode::uminus:
3291 return alt_datatree.AddUMinus(alt_arg);
3292 case UnaryOpcode::exp:
3293 return alt_datatree.AddExp(alt_arg);
3294 case UnaryOpcode::log:
3295 return alt_datatree.AddLog(alt_arg);
3296 case UnaryOpcode::log10:
3297 return alt_datatree.AddLog10(alt_arg);
3298 case UnaryOpcode::cos:
3299 return alt_datatree.AddCos(alt_arg);
3300 case UnaryOpcode::sin:
3301 return alt_datatree.AddSin(alt_arg);
3302 case UnaryOpcode::tan:
3303 return alt_datatree.AddTan(alt_arg);
3304 case UnaryOpcode::acos:
3305 return alt_datatree.AddAcos(alt_arg);
3306 case UnaryOpcode::asin:
3307 return alt_datatree.AddAsin(alt_arg);
3308 case UnaryOpcode::atan:
3309 return alt_datatree.AddAtan(alt_arg);
3310 case UnaryOpcode::cosh:
3311 return alt_datatree.AddCosh(alt_arg);
3312 case UnaryOpcode::sinh:
3313 return alt_datatree.AddSinh(alt_arg);
3314 case UnaryOpcode::tanh:
3315 return alt_datatree.AddTanh(alt_arg);
3316 case UnaryOpcode::acosh:
3317 return alt_datatree.AddAcosh(alt_arg);
3318 case UnaryOpcode::asinh:
3319 return alt_datatree.AddAsinh(alt_arg);
3320 case UnaryOpcode::atanh:
3321 return alt_datatree.AddAtanh(alt_arg);
3322 case UnaryOpcode::sqrt:
3323 return alt_datatree.AddSqrt(alt_arg);
3324 case UnaryOpcode::cbrt:
3325 return alt_datatree.AddCbrt(alt_arg);
3326 case UnaryOpcode::abs:
3327 return alt_datatree.AddAbs(alt_arg);
3328 case UnaryOpcode::sign:
3329 return alt_datatree.AddSign(alt_arg);
3330 case UnaryOpcode::steadyState:
3331 return alt_datatree.AddSteadyState(alt_arg);
3332 case UnaryOpcode::steadyStateParamDeriv:
3333 cerr << "UnaryOpNode::buildSimilarUnaryOpNode: UnaryOpcode::steadyStateParamDeriv can't be translated" << endl;
3334 exit(EXIT_FAILURE);
3335 case UnaryOpcode::steadyStateParam2ndDeriv:
3336 cerr << "UnaryOpNode::buildSimilarUnaryOpNode: UnaryOpcode::steadyStateParam2ndDeriv can't be translated" << endl;
3337 exit(EXIT_FAILURE);
3338 case UnaryOpcode::expectation:
3339 return alt_datatree.AddExpectation(expectation_information_set, alt_arg);
3340 case UnaryOpcode::erf:
3341 return alt_datatree.AddErf(alt_arg);
3342 case UnaryOpcode::diff:
3343 return alt_datatree.AddDiff(alt_arg);
3344 case UnaryOpcode::adl:
3345 return alt_datatree.AddAdl(alt_arg, adl_param_name, adl_lags);
3346 }
3347 // Suppress GCC warning
3348 exit(EXIT_FAILURE);
3349 }
3350
3351 expr_t
toStatic(DataTree & static_datatree) const3352 UnaryOpNode::toStatic(DataTree &static_datatree) const
3353 {
3354 expr_t sarg = arg->toStatic(static_datatree);
3355 return buildSimilarUnaryOpNode(sarg, static_datatree);
3356 }
3357
3358 void
computeXrefs(EquationInfo & ei) const3359 UnaryOpNode::computeXrefs(EquationInfo &ei) const
3360 {
3361 arg->computeXrefs(ei);
3362 }
3363
3364 expr_t
clone(DataTree & datatree) const3365 UnaryOpNode::clone(DataTree &datatree) const
3366 {
3367 expr_t substarg = arg->clone(datatree);
3368 return buildSimilarUnaryOpNode(substarg, datatree);
3369 }
3370
3371 int
maxEndoLead() const3372 UnaryOpNode::maxEndoLead() const
3373 {
3374 return arg->maxEndoLead();
3375 }
3376
3377 int
maxExoLead() const3378 UnaryOpNode::maxExoLead() const
3379 {
3380 return arg->maxExoLead();
3381 }
3382
3383 int
maxEndoLag() const3384 UnaryOpNode::maxEndoLag() const
3385 {
3386 return arg->maxEndoLag();
3387 }
3388
3389 int
maxExoLag() const3390 UnaryOpNode::maxExoLag() const
3391 {
3392 return arg->maxExoLag();
3393 }
3394
3395 int
maxLead() const3396 UnaryOpNode::maxLead() const
3397 {
3398 return arg->maxLead();
3399 }
3400
3401 int
maxLag() const3402 UnaryOpNode::maxLag() const
3403 {
3404 return arg->maxLag();
3405 }
3406
3407 int
maxLagWithDiffsExpanded() const3408 UnaryOpNode::maxLagWithDiffsExpanded() const
3409 {
3410 if (op_code == UnaryOpcode::diff)
3411 return arg->maxLagWithDiffsExpanded() + 1;
3412 return arg->maxLagWithDiffsExpanded();
3413 }
3414
3415 expr_t
undiff() const3416 UnaryOpNode::undiff() const
3417 {
3418 if (op_code == UnaryOpcode::diff)
3419 return arg;
3420 return arg->undiff();
3421 }
3422
3423 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const3424 UnaryOpNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
3425 {
3426 auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3427 if (lhs_lag_equiv.find(lag_equiv_repr) == lhs_lag_equiv.end())
3428 return 0;
3429 return arg->maxLag();
3430 }
3431
3432 int
VarMinLag() const3433 UnaryOpNode::VarMinLag() const
3434 {
3435 return arg->VarMinLag();
3436 }
3437
3438 int
PacMaxLag(int lhs_symb_id) const3439 UnaryOpNode::PacMaxLag(int lhs_symb_id) const
3440 {
3441 //This will never be an UnaryOpcode::diff node
3442 return arg->PacMaxLag(lhs_symb_id);
3443 }
3444
3445 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const3446 UnaryOpNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
3447 {
3448 return arg->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
3449 }
3450
3451 expr_t
substituteAdl() const3452 UnaryOpNode::substituteAdl() const
3453 {
3454 if (op_code != UnaryOpcode::adl)
3455 {
3456 expr_t argsubst = arg->substituteAdl();
3457 return buildSimilarUnaryOpNode(argsubst, datatree);
3458 }
3459
3460 expr_t arg1subst = arg->substituteAdl();
3461 expr_t retval = nullptr;
3462 ostringstream inttostr;
3463
3464 for (auto it = adl_lags.begin(); it != adl_lags.end(); ++it)
3465 if (it == adl_lags.begin())
3466 {
3467 inttostr << *it;
3468 retval = datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_" + inttostr.str()), 0),
3469 arg1subst->decreaseLeadsLags(*it));
3470 }
3471 else
3472 {
3473 inttostr.clear();
3474 inttostr.str("");
3475 inttostr << *it;
3476 retval = datatree.AddPlus(retval,
3477 datatree.AddTimes(datatree.AddVariable(datatree.symbol_table.getID(adl_param_name + "_lag_"
3478 + inttostr.str()), 0),
3479 arg1subst->decreaseLeadsLags(*it)));
3480 }
3481 return retval;
3482 }
3483
3484 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const3485 UnaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
3486 {
3487 expr_t argsubst = arg->substituteVarExpectation(subst_table);
3488 return buildSimilarUnaryOpNode(argsubst, datatree);
3489 }
3490
3491 int
countDiffs() const3492 UnaryOpNode::countDiffs() const
3493 {
3494 if (op_code == UnaryOpcode::diff)
3495 return arg->countDiffs() + 1;
3496 return arg->countDiffs();
3497 }
3498
3499 bool
createAuxVarForUnaryOpNode() const3500 UnaryOpNode::createAuxVarForUnaryOpNode() const
3501 {
3502 switch (op_code)
3503 {
3504 case UnaryOpcode::exp:
3505 case UnaryOpcode::log:
3506 case UnaryOpcode::log10:
3507 case UnaryOpcode::cos:
3508 case UnaryOpcode::sin:
3509 case UnaryOpcode::tan:
3510 case UnaryOpcode::acos:
3511 case UnaryOpcode::asin:
3512 case UnaryOpcode::atan:
3513 case UnaryOpcode::cosh:
3514 case UnaryOpcode::sinh:
3515 case UnaryOpcode::tanh:
3516 case UnaryOpcode::acosh:
3517 case UnaryOpcode::asinh:
3518 case UnaryOpcode::atanh:
3519 case UnaryOpcode::sqrt:
3520 case UnaryOpcode::cbrt:
3521 case UnaryOpcode::abs:
3522 case UnaryOpcode::sign:
3523 case UnaryOpcode::erf:
3524 return true;
3525 default:
3526 return false;
3527 }
3528 }
3529
3530 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const3531 UnaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
3532 {
3533 arg->findUnaryOpNodesForAuxVarCreation(nodes);
3534
3535 if (!this->createAuxVarForUnaryOpNode())
3536 return;
3537
3538 auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3539 nodes[lag_equiv_repr][index] = const_cast<UnaryOpNode *>(this);
3540 }
3541
3542 void
findDiffNodes(lag_equivalence_table_t & nodes) const3543 UnaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
3544 {
3545 arg->findDiffNodes(nodes);
3546
3547 if (op_code != UnaryOpcode::diff)
3548 return;
3549
3550 auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3551 nodes[lag_equiv_repr][index] = const_cast<UnaryOpNode *>(this);
3552 }
3553
3554 int
findTargetVariable(int lhs_symb_id) const3555 UnaryOpNode::findTargetVariable(int lhs_symb_id) const
3556 {
3557 return arg->findTargetVariable(lhs_symb_id);
3558 }
3559
3560 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3561 UnaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
3562 vector<BinaryOpNode *> &neweqs) const
3563 {
3564 // If this is not a diff node, then substitute recursively and return
3565 expr_t argsubst = arg->substituteDiff(nodes, subst_table, neweqs);
3566 if (op_code != UnaryOpcode::diff)
3567 return buildSimilarUnaryOpNode(argsubst, datatree);
3568
3569 if (auto sit = subst_table.find(this);
3570 sit != subst_table.end())
3571 return const_cast<VariableNode *>(sit->second);
3572
3573 auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3574 auto it = nodes.find(lag_equiv_repr);
3575 if (it == nodes.end() || it->second.find(index) == it->second.end()
3576 || it->second.at(index) != this)
3577 {
3578 /* diff does not appear in VAR equations, so simply create aux var and return.
3579 Once the comparison of expression nodes works, come back and remove
3580 this part, folding into the next loop. */
3581 int symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, const_cast<UnaryOpNode *>(this));
3582 VariableNode *aux_var = datatree.AddVariable(symb_id, 0);
3583 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
3584 datatree.AddMinus(argsubst,
3585 argsubst->decreaseLeadsLags(1)))));
3586 subst_table[this] = dynamic_cast<VariableNode *>(aux_var);
3587 return const_cast<VariableNode *>(subst_table[this]);
3588 }
3589
3590 /* At this point, we know that this node (and its lagged/leaded brothers)
3591 must be substituted. We create the auxiliary variable and fill the
3592 substitution table for all those similar nodes, in an iteration going from
3593 leads to lags. */
3594 int last_index = 0;
3595 VariableNode *last_aux_var = nullptr;
3596 for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
3597 {
3598 expr_t argsubst = dynamic_cast<UnaryOpNode *>(rit->second)->
3599 arg->substituteDiff(nodes, subst_table, neweqs);
3600 auto vn = dynamic_cast<VariableNode *>(argsubst);
3601 int symb_id;
3602 if (rit == it->second.rbegin())
3603 {
3604 if (vn)
3605 symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, rit->second, vn->symb_id, vn->lag);
3606 else
3607 symb_id = datatree.symbol_table.addDiffAuxiliaryVar(argsubst->idx, rit->second);
3608
3609 // make originating aux var & equation
3610 last_index = rit->first;
3611 last_aux_var = datatree.AddVariable(symb_id, 0);
3612 //ORIG_AUX_DIFF = argsubst - argsubst(-1)
3613 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(last_aux_var,
3614 datatree.AddMinus(argsubst,
3615 argsubst->decreaseLeadsLags(1)))));
3616 subst_table[rit->second] = dynamic_cast<VariableNode *>(last_aux_var);
3617 }
3618 else
3619 {
3620 // just add equation of form: AUX_DIFF = LAST_AUX_VAR(-1)
3621 VariableNode *new_aux_var = nullptr;
3622 for (int i = last_index; i > rit->first; i--)
3623 {
3624 if (i == last_index)
3625 symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(argsubst->idx, rit->second,
3626 last_aux_var->symb_id, last_aux_var->lag);
3627 else
3628 symb_id = datatree.symbol_table.addDiffLagAuxiliaryVar(new_aux_var->idx, rit->second,
3629 last_aux_var->symb_id, last_aux_var->lag);
3630
3631 new_aux_var = datatree.AddVariable(symb_id, 0);
3632 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(new_aux_var,
3633 last_aux_var->decreaseLeadsLags(1))));
3634 last_aux_var = new_aux_var;
3635 }
3636 subst_table[rit->second] = dynamic_cast<VariableNode *>(new_aux_var);
3637 last_index = rit->first;
3638 }
3639 }
3640 return const_cast<VariableNode *>(subst_table[this]);
3641 }
3642
3643 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3644 UnaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3645 {
3646 if (auto sit = subst_table.find(this);
3647 sit != subst_table.end())
3648 return const_cast<VariableNode *>(sit->second);
3649
3650 /* If the equivalence class of this node is not marked for substitution,
3651 then substitute recursively and return. */
3652 auto [lag_equiv_repr, index] = getLagEquivalenceClass();
3653 auto it = nodes.find(lag_equiv_repr);
3654 expr_t argsubst = arg->substituteUnaryOpNodes(nodes, subst_table, neweqs);
3655 if (it == nodes.end())
3656 return buildSimilarUnaryOpNode(argsubst, datatree);
3657
3658 string unary_op;
3659 switch (op_code)
3660 {
3661 case UnaryOpcode::exp:
3662 unary_op = "exp";
3663 break;
3664 case UnaryOpcode::log:
3665 unary_op = "log";
3666 break;
3667 case UnaryOpcode::log10:
3668 unary_op = "log10";
3669 break;
3670 case UnaryOpcode::cos:
3671 unary_op = "cos";
3672 break;
3673 case UnaryOpcode::sin:
3674 unary_op = "sin";
3675 break;
3676 case UnaryOpcode::tan:
3677 unary_op = "tan";
3678 break;
3679 case UnaryOpcode::acos:
3680 unary_op = "acos";
3681 break;
3682 case UnaryOpcode::asin:
3683 unary_op = "asin";
3684 break;
3685 case UnaryOpcode::atan:
3686 unary_op = "atan";
3687 break;
3688 case UnaryOpcode::cosh:
3689 unary_op = "cosh";
3690 break;
3691 case UnaryOpcode::sinh:
3692 unary_op = "sinh";
3693 break;
3694 case UnaryOpcode::tanh:
3695 unary_op = "tanh";
3696 break;
3697 case UnaryOpcode::acosh:
3698 unary_op = "acosh";
3699 break;
3700 case UnaryOpcode::asinh:
3701 unary_op = "asinh";
3702 break;
3703 case UnaryOpcode::atanh:
3704 unary_op = "atanh";
3705 break;
3706 case UnaryOpcode::sqrt:
3707 unary_op = "sqrt";
3708 break;
3709 case UnaryOpcode::cbrt:
3710 unary_op = "cbrt";
3711 break;
3712 case UnaryOpcode::abs:
3713 unary_op = "abs";
3714 break;
3715 case UnaryOpcode::sign:
3716 unary_op = "sign";
3717 break;
3718 case UnaryOpcode::erf:
3719 unary_op = "erf";
3720 break;
3721 default:
3722 cerr << "UnaryOpNode::substituteUnaryOpNodes: Shouldn't arrive here" << endl;
3723 exit(EXIT_FAILURE);
3724 }
3725
3726 /* At this point, we know that this node (and its lagged/leaded brothers)
3727 must be substituted. We create the auxiliary variable and fill the
3728 substitution table for all those similar nodes, in an iteration going from
3729 leads to lags. */
3730 int base_index = 0;
3731 VariableNode *aux_var = nullptr;
3732 for (auto rit = it->second.rbegin(); rit != it->second.rend(); ++rit)
3733 if (rit == it->second.rbegin())
3734 {
3735 /* Verify that we’re not operating on a node with leads, since the
3736 transformation does take into account the expectation operator. We only
3737 need to do this for the first iteration of the loop, because we’re
3738 going from leads to lags. */
3739 if (rit->second->maxLead() > 0)
3740 {
3741 cerr << "Cannot substitute unary operations that contain leads" << endl;
3742 exit(EXIT_FAILURE);
3743 }
3744
3745 int symb_id;
3746 auto vn = dynamic_cast<VariableNode *>(argsubst);
3747 if (!vn)
3748 symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op);
3749 else
3750 symb_id = datatree.symbol_table.addUnaryOpAuxiliaryVar(this->idx, dynamic_cast<UnaryOpNode *>(rit->second), unary_op,
3751 vn->symb_id, vn->lag);
3752 aux_var = datatree.AddVariable(symb_id, 0);
3753 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(aux_var,
3754 dynamic_cast<UnaryOpNode *>(rit->second))));
3755 subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var);
3756 base_index = rit->first;
3757 }
3758 else
3759 subst_table[rit->second] = dynamic_cast<VariableNode *>(aux_var->decreaseLeadsLags(base_index - rit->first));
3760
3761 return const_cast<VariableNode *>(subst_table.find(this)->second);
3762 }
3763
3764 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)3765 UnaryOpNode::substitutePacExpectation(const string &name, expr_t subexpr)
3766 {
3767 expr_t argsubst = arg->substitutePacExpectation(name, subexpr);
3768 return buildSimilarUnaryOpNode(argsubst, datatree);
3769 }
3770
3771 expr_t
decreaseLeadsLags(int n) const3772 UnaryOpNode::decreaseLeadsLags(int n) const
3773 {
3774 expr_t argsubst = arg->decreaseLeadsLags(n);
3775 return buildSimilarUnaryOpNode(argsubst, datatree);
3776 }
3777
3778 expr_t
decreaseLeadsLagsPredeterminedVariables() const3779 UnaryOpNode::decreaseLeadsLagsPredeterminedVariables() const
3780 {
3781 expr_t argsubst = arg->decreaseLeadsLagsPredeterminedVariables();
3782 return buildSimilarUnaryOpNode(argsubst, datatree);
3783 }
3784
3785 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const3786 UnaryOpNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
3787 {
3788 if (op_code == UnaryOpcode::uminus || deterministic_model)
3789 {
3790 expr_t argsubst = arg->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
3791 return buildSimilarUnaryOpNode(argsubst, datatree);
3792 }
3793 else
3794 {
3795 if (maxEndoLead() >= 2)
3796 return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
3797 else
3798 return const_cast<UnaryOpNode *>(this);
3799 }
3800 }
3801
3802 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3803 UnaryOpNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3804 {
3805 expr_t argsubst = arg->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
3806 return buildSimilarUnaryOpNode(argsubst, datatree);
3807 }
3808
3809 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const3810 UnaryOpNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
3811 {
3812 if (op_code == UnaryOpcode::uminus || deterministic_model)
3813 {
3814 expr_t argsubst = arg->substituteExoLead(subst_table, neweqs, deterministic_model);
3815 return buildSimilarUnaryOpNode(argsubst, datatree);
3816 }
3817 else
3818 {
3819 if (maxExoLead() >= 1)
3820 return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
3821 else
3822 return const_cast<UnaryOpNode *>(this);
3823 }
3824 }
3825
3826 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3827 UnaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3828 {
3829 expr_t argsubst = arg->substituteExoLag(subst_table, neweqs);
3830 return buildSimilarUnaryOpNode(argsubst, datatree);
3831 }
3832
3833 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const3834 UnaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
3835 {
3836 if (op_code == UnaryOpcode::expectation)
3837 {
3838 if (auto it = subst_table.find(const_cast<UnaryOpNode *>(this)); it != subst_table.end())
3839 return const_cast<VariableNode *>(it->second);
3840
3841 //Arriving here, we need to create an auxiliary variable for this Expectation Operator:
3842 //AUX_EXPECT_(LEAD/LAG)_(period)_(arg.idx) OR
3843 //AUX_EXPECT_(info_set_name)_(arg.idx)
3844 int symb_id = datatree.symbol_table.addExpectationAuxiliaryVar(expectation_information_set, arg->idx, const_cast<UnaryOpNode *>(this));
3845 expr_t newAuxE = datatree.AddVariable(symb_id, 0);
3846
3847 if (partial_information_model && expectation_information_set == 0)
3848 if (!dynamic_cast<VariableNode *>(arg))
3849 {
3850 cerr << "ERROR: In Partial Information models, EXPECTATION(0)(X) "
3851 << "can only be used when X is a single variable." << endl;
3852 exit(EXIT_FAILURE);
3853 }
3854
3855 //take care of any nested expectation operators by calling arg->substituteExpectation(.), then decreaseLeadsLags for this UnaryOpcode::expectation operator
3856 //arg(lag-period) (holds entire subtree of arg(lag-period)
3857 expr_t substexpr = (arg->substituteExpectation(subst_table, neweqs, partial_information_model))->decreaseLeadsLags(expectation_information_set);
3858 assert(substexpr);
3859 neweqs.push_back(dynamic_cast<BinaryOpNode *>(datatree.AddEqual(newAuxE, substexpr))); //AUXE_period_arg.idx = arg(lag-period)
3860 newAuxE = datatree.AddVariable(symb_id, expectation_information_set);
3861
3862 assert(dynamic_cast<VariableNode *>(newAuxE));
3863 subst_table[this] = dynamic_cast<VariableNode *>(newAuxE);
3864 return newAuxE;
3865 }
3866 else
3867 {
3868 expr_t argsubst = arg->substituteExpectation(subst_table, neweqs, partial_information_model);
3869 return buildSimilarUnaryOpNode(argsubst, datatree);
3870 }
3871 }
3872
3873 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const3874 UnaryOpNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
3875 {
3876 expr_t argsubst = arg->differentiateForwardVars(subset, subst_table, neweqs);
3877 return buildSimilarUnaryOpNode(argsubst, datatree);
3878 }
3879
3880 bool
isNumConstNodeEqualTo(double value) const3881 UnaryOpNode::isNumConstNodeEqualTo(double value) const
3882 {
3883 return false;
3884 }
3885
3886 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const3887 UnaryOpNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
3888 {
3889 return false;
3890 }
3891
3892 bool
containsPacExpectation(const string & pac_model_name) const3893 UnaryOpNode::containsPacExpectation(const string &pac_model_name) const
3894 {
3895 return arg->containsPacExpectation(pac_model_name);
3896 }
3897
3898 bool
containsEndogenous() const3899 UnaryOpNode::containsEndogenous() const
3900 {
3901 return arg->containsEndogenous();
3902 }
3903
3904 bool
containsExogenous() const3905 UnaryOpNode::containsExogenous() const
3906 {
3907 return arg->containsExogenous();
3908 }
3909
3910 expr_t
replaceTrendVar() const3911 UnaryOpNode::replaceTrendVar() const
3912 {
3913 expr_t argsubst = arg->replaceTrendVar();
3914 return buildSimilarUnaryOpNode(argsubst, datatree);
3915 }
3916
3917 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const3918 UnaryOpNode::detrend(int symb_id, bool log_trend, expr_t trend) const
3919 {
3920 expr_t argsubst = arg->detrend(symb_id, log_trend, trend);
3921 return buildSimilarUnaryOpNode(argsubst, datatree);
3922 }
3923
3924 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const3925 UnaryOpNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
3926 {
3927 expr_t argsubst = arg->removeTrendLeadLag(trend_symbols_map);
3928 return buildSimilarUnaryOpNode(argsubst, datatree);
3929 }
3930
3931 bool
isInStaticForm() const3932 UnaryOpNode::isInStaticForm() const
3933 {
3934 if (op_code == UnaryOpcode::steadyState || op_code == UnaryOpcode::steadyStateParamDeriv
3935 || op_code == UnaryOpcode::steadyStateParam2ndDeriv
3936 || op_code == UnaryOpcode::expectation)
3937 return false;
3938 else
3939 return arg->isInStaticForm();
3940 }
3941
3942 bool
isParamTimesEndogExpr() const3943 UnaryOpNode::isParamTimesEndogExpr() const
3944 {
3945 return arg->isParamTimesEndogExpr();
3946 }
3947
3948 bool
isVarModelReferenced(const string & model_info_name) const3949 UnaryOpNode::isVarModelReferenced(const string &model_info_name) const
3950 {
3951 return arg->isVarModelReferenced(model_info_name);
3952 }
3953
3954 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const3955 UnaryOpNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
3956 {
3957 arg->getEndosAndMaxLags(model_endos_and_lags);
3958 }
3959
3960 expr_t
substituteStaticAuxiliaryVariable() const3961 UnaryOpNode::substituteStaticAuxiliaryVariable() const
3962 {
3963 if (op_code == UnaryOpcode::diff)
3964 return datatree.Zero;
3965
3966 expr_t argsubst = arg->substituteStaticAuxiliaryVariable();
3967 if (op_code == UnaryOpcode::expectation)
3968 return argsubst;
3969 else
3970 return buildSimilarUnaryOpNode(argsubst, datatree);
3971 }
3972
3973 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const3974 UnaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
3975 {
3976 arg->findConstantEquations(table);
3977 }
3978
3979 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const3980 UnaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
3981 {
3982 expr_t argsubst = arg->replaceVarsInEquation(table);
3983 return buildSimilarUnaryOpNode(argsubst, datatree);
3984 }
3985
BinaryOpNode(DataTree & datatree_arg,int idx_arg,const expr_t arg1_arg,BinaryOpcode op_code_arg,const expr_t arg2_arg,int powerDerivOrder_arg)3986 BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
3987 BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder_arg) :
3988 ExprNode{datatree_arg, idx_arg},
3989 arg1{arg1_arg},
3990 arg2{arg2_arg},
3991 op_code{op_code_arg},
3992 powerDerivOrder{powerDerivOrder_arg}
3993 {
3994 assert(powerDerivOrder >= 0);
3995 }
3996
3997 void
prepareForDerivation()3998 BinaryOpNode::prepareForDerivation()
3999 {
4000 if (preparedForDerivation)
4001 return;
4002
4003 preparedForDerivation = true;
4004
4005 arg1->prepareForDerivation();
4006 arg2->prepareForDerivation();
4007
4008 // Non-null derivatives are the union of those of the arguments
4009 // Compute set union of arg1->non_null_derivatives and arg2->non_null_derivatives
4010 set_union(arg1->non_null_derivatives.begin(),
4011 arg1->non_null_derivatives.end(),
4012 arg2->non_null_derivatives.begin(),
4013 arg2->non_null_derivatives.end(),
4014 inserter(non_null_derivatives, non_null_derivatives.begin()));
4015 }
4016
4017 expr_t
getNonZeroPartofEquation() const4018 BinaryOpNode::getNonZeroPartofEquation() const
4019 {
4020 assert(arg1 == datatree.Zero || arg2 == datatree.Zero);
4021 if (arg1 == datatree.Zero)
4022 return arg2;
4023 return arg1;
4024 }
4025
4026 expr_t
composeDerivatives(expr_t darg1,expr_t darg2)4027 BinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2)
4028 {
4029 expr_t t11, t12, t13, t14, t15;
4030
4031 switch (op_code)
4032 {
4033 case BinaryOpcode::plus:
4034 return datatree.AddPlus(darg1, darg2);
4035 case BinaryOpcode::minus:
4036 return datatree.AddMinus(darg1, darg2);
4037 case BinaryOpcode::times:
4038 t11 = datatree.AddTimes(darg1, arg2);
4039 t12 = datatree.AddTimes(darg2, arg1);
4040 return datatree.AddPlus(t11, t12);
4041 case BinaryOpcode::divide:
4042 if (darg2 != datatree.Zero)
4043 {
4044 t11 = datatree.AddTimes(darg1, arg2);
4045 t12 = datatree.AddTimes(darg2, arg1);
4046 t13 = datatree.AddMinus(t11, t12);
4047 t14 = datatree.AddTimes(arg2, arg2);
4048 return datatree.AddDivide(t13, t14);
4049 }
4050 else
4051 return datatree.AddDivide(darg1, arg2);
4052 case BinaryOpcode::less:
4053 case BinaryOpcode::greater:
4054 case BinaryOpcode::lessEqual:
4055 case BinaryOpcode::greaterEqual:
4056 case BinaryOpcode::equalEqual:
4057 case BinaryOpcode::different:
4058 return datatree.Zero;
4059 case BinaryOpcode::power:
4060 if (darg2 == datatree.Zero)
4061 if (darg1 == datatree.Zero)
4062 return datatree.Zero;
4063 else
4064 if (dynamic_cast<NumConstNode *>(arg2))
4065 {
4066 t11 = datatree.AddMinus(arg2, datatree.One);
4067 t12 = datatree.AddPower(arg1, t11);
4068 t13 = datatree.AddTimes(arg2, t12);
4069 return datatree.AddTimes(darg1, t13);
4070 }
4071 else
4072 return datatree.AddTimes(darg1, datatree.AddPowerDeriv(arg1, arg2, powerDerivOrder + 1));
4073 else
4074 {
4075 t11 = datatree.AddLog(arg1);
4076 t12 = datatree.AddTimes(darg2, t11);
4077 t13 = datatree.AddTimes(darg1, arg2);
4078 t14 = datatree.AddDivide(t13, arg1);
4079 t15 = datatree.AddPlus(t12, t14);
4080 return datatree.AddTimes(t15, this);
4081 }
4082 case BinaryOpcode::powerDeriv:
4083 if (darg2 == datatree.Zero)
4084 return datatree.AddTimes(darg1, datatree.AddPowerDeriv(arg1, arg2, powerDerivOrder + 1));
4085 else
4086 {
4087 t11 = datatree.AddTimes(darg2, datatree.AddLog(arg1));
4088 t12 = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(powerDerivOrder));
4089 t13 = datatree.AddTimes(darg1, t12);
4090 t14 = datatree.AddDivide(t13, arg1);
4091 t15 = datatree.AddPlus(t11, t14);
4092 expr_t f = datatree.AddPower(arg1, t12);
4093 expr_t first_part = datatree.AddTimes(f, t15);
4094
4095 for (int i = 0; i < powerDerivOrder; i++)
4096 first_part = datatree.AddTimes(first_part, datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(i)));
4097
4098 t13 = datatree.Zero;
4099 for (int i = 0; i < powerDerivOrder; i++)
4100 {
4101 t11 = datatree.One;
4102 for (int j = 0; j < powerDerivOrder; j++)
4103 if (i != j)
4104 {
4105 t12 = datatree.AddMinus(arg2, datatree.AddPossiblyNegativeConstant(j));
4106 t11 = datatree.AddTimes(t11, t12);
4107 }
4108 t13 = datatree.AddPlus(t13, t11);
4109 }
4110 t13 = datatree.AddTimes(darg2, t13);
4111 t14 = datatree.AddTimes(f, t13);
4112 return datatree.AddPlus(first_part, t14);
4113 }
4114 case BinaryOpcode::max:
4115 t11 = datatree.AddGreater(arg1, arg2);
4116 t12 = datatree.AddTimes(t11, darg1);
4117 t13 = datatree.AddMinus(datatree.One, t11);
4118 t14 = datatree.AddTimes(t13, darg2);
4119 return datatree.AddPlus(t14, t12);
4120 case BinaryOpcode::min:
4121 t11 = datatree.AddGreater(arg2, arg1);
4122 t12 = datatree.AddTimes(t11, darg1);
4123 t13 = datatree.AddMinus(datatree.One, t11);
4124 t14 = datatree.AddTimes(t13, darg2);
4125 return datatree.AddPlus(t14, t12);
4126 case BinaryOpcode::equal:
4127 return datatree.AddMinus(darg1, darg2);
4128 }
4129 // Suppress GCC warning
4130 exit(EXIT_FAILURE);
4131 }
4132
4133 expr_t
unpackPowerDeriv() const4134 BinaryOpNode::unpackPowerDeriv() const
4135 {
4136 if (op_code != BinaryOpcode::powerDeriv)
4137 return const_cast<BinaryOpNode *>(this);
4138
4139 expr_t front = datatree.One;
4140 for (int i = 0; i < powerDerivOrder; i++)
4141 front = datatree.AddTimes(front,
4142 datatree.AddMinus(arg2,
4143 datatree.AddPossiblyNegativeConstant(i)));
4144 expr_t tmp = datatree.AddPower(arg1,
4145 datatree.AddMinus(arg2,
4146 datatree.AddPossiblyNegativeConstant(powerDerivOrder)));
4147 return datatree.AddTimes(front, tmp);
4148 }
4149
4150 expr_t
computeDerivative(int deriv_id)4151 BinaryOpNode::computeDerivative(int deriv_id)
4152 {
4153 expr_t darg1 = arg1->getDerivative(deriv_id);
4154 expr_t darg2 = arg2->getDerivative(deriv_id);
4155 return composeDerivatives(darg1, darg2);
4156 }
4157
4158 int
precedence(ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms) const4159 BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
4160 {
4161 // A temporary term behaves as a variable
4162 if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4163 return 100;
4164
4165 switch (op_code)
4166 {
4167 case BinaryOpcode::equal:
4168 return 0;
4169 case BinaryOpcode::equalEqual:
4170 case BinaryOpcode::different:
4171 return 1;
4172 case BinaryOpcode::lessEqual:
4173 case BinaryOpcode::greaterEqual:
4174 case BinaryOpcode::less:
4175 case BinaryOpcode::greater:
4176 return 2;
4177 case BinaryOpcode::plus:
4178 case BinaryOpcode::minus:
4179 return 3;
4180 case BinaryOpcode::times:
4181 case BinaryOpcode::divide:
4182 return 4;
4183 case BinaryOpcode::power:
4184 case BinaryOpcode::powerDeriv:
4185 if (isCOutput(output_type))
4186 // In C, power operator is of the form pow(a, b)
4187 return 100;
4188 else
4189 return 5;
4190 case BinaryOpcode::min:
4191 case BinaryOpcode::max:
4192 return 100;
4193 }
4194 // Suppress GCC warning
4195 exit(EXIT_FAILURE);
4196 }
4197
4198 int
precedenceJson(const temporary_terms_t & temporary_terms) const4199 BinaryOpNode::precedenceJson(const temporary_terms_t &temporary_terms) const
4200 {
4201 // A temporary term behaves as a variable
4202 if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4203 return 100;
4204
4205 switch (op_code)
4206 {
4207 case BinaryOpcode::equal:
4208 return 0;
4209 case BinaryOpcode::equalEqual:
4210 case BinaryOpcode::different:
4211 return 1;
4212 case BinaryOpcode::lessEqual:
4213 case BinaryOpcode::greaterEqual:
4214 case BinaryOpcode::less:
4215 case BinaryOpcode::greater:
4216 return 2;
4217 case BinaryOpcode::plus:
4218 case BinaryOpcode::minus:
4219 return 3;
4220 case BinaryOpcode::times:
4221 case BinaryOpcode::divide:
4222 return 4;
4223 case BinaryOpcode::power:
4224 case BinaryOpcode::powerDeriv:
4225 return 5;
4226 case BinaryOpcode::min:
4227 case BinaryOpcode::max:
4228 return 100;
4229 }
4230 // Suppress GCC warning
4231 exit(EXIT_FAILURE);
4232 }
4233
4234 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const4235 BinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
4236 {
4237 // For a temporary term, the cost is null
4238 for (const auto &it : temp_terms_map)
4239 if (it.second.find(const_cast<BinaryOpNode *>(this)) != it.second.end())
4240 return 0;
4241
4242 int arg_cost = arg1->cost(temp_terms_map, is_matlab) + arg2->cost(temp_terms_map, is_matlab);
4243
4244 return cost(arg_cost, is_matlab);
4245 }
4246
4247 int
cost(const temporary_terms_t & temporary_terms,bool is_matlab) const4248 BinaryOpNode::cost(const temporary_terms_t &temporary_terms, bool is_matlab) const
4249 {
4250 // For a temporary term, the cost is null
4251 if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4252 return 0;
4253
4254 int arg_cost = arg1->cost(temporary_terms, is_matlab) + arg2->cost(temporary_terms, is_matlab);
4255
4256 return cost(arg_cost, is_matlab);
4257 }
4258
4259 int
cost(int cost,bool is_matlab) const4260 BinaryOpNode::cost(int cost, bool is_matlab) const
4261 {
4262 if (is_matlab)
4263 // Cost for Matlab files
4264 switch (op_code)
4265 {
4266 case BinaryOpcode::less:
4267 case BinaryOpcode::greater:
4268 case BinaryOpcode::lessEqual:
4269 case BinaryOpcode::greaterEqual:
4270 case BinaryOpcode::equalEqual:
4271 case BinaryOpcode::different:
4272 return cost + 60;
4273 case BinaryOpcode::plus:
4274 case BinaryOpcode::minus:
4275 case BinaryOpcode::times:
4276 return cost + 90;
4277 case BinaryOpcode::max:
4278 case BinaryOpcode::min:
4279 return cost + 110;
4280 case BinaryOpcode::divide:
4281 return cost + 990;
4282 case BinaryOpcode::power:
4283 case BinaryOpcode::powerDeriv:
4284 return cost + (min_cost_matlab/2+1);
4285 case BinaryOpcode::equal:
4286 return cost;
4287 }
4288 else
4289 // Cost for C files
4290 switch (op_code)
4291 {
4292 case BinaryOpcode::less:
4293 case BinaryOpcode::greater:
4294 case BinaryOpcode::lessEqual:
4295 case BinaryOpcode::greaterEqual:
4296 case BinaryOpcode::equalEqual:
4297 case BinaryOpcode::different:
4298 return cost + 2;
4299 case BinaryOpcode::plus:
4300 case BinaryOpcode::minus:
4301 case BinaryOpcode::times:
4302 return cost + 4;
4303 case BinaryOpcode::max:
4304 case BinaryOpcode::min:
4305 return cost + 5;
4306 case BinaryOpcode::divide:
4307 return cost + 15;
4308 case BinaryOpcode::power:
4309 return cost + 520;
4310 case BinaryOpcode::powerDeriv:
4311 return cost + (min_cost_c/2+1);;
4312 case BinaryOpcode::equal:
4313 return cost;
4314 }
4315 // Suppress GCC warning
4316 exit(EXIT_FAILURE);
4317 }
4318
4319 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const4320 BinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
4321 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
4322 map<expr_t, pair<int, pair<int, int>>> &reference_count,
4323 bool is_matlab) const
4324 {
4325 expr_t this2 = const_cast<BinaryOpNode *>(this);
4326 if (auto it = reference_count.find(this2);
4327 it == reference_count.end())
4328 {
4329 // If this node has never been encountered, set its ref count to one,
4330 // and travel through its children
4331 reference_count[this2] = { 1, derivOrder };
4332 arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
4333 arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
4334 }
4335 else
4336 {
4337 /* If the node has already been encountered, increment its ref count
4338 and declare it as a temporary term if it is too costly (except if it is
4339 an equal node: we don't want them as temporary terms) */
4340 reference_count[this2] = { it->second.first + 1, it->second.second };;
4341 if (reference_count[this2].first * cost(temp_terms_map, is_matlab) > min_cost(is_matlab)
4342 && op_code != BinaryOpcode::equal)
4343 temp_terms_map[reference_count[this2].second].insert(this2);
4344 }
4345 }
4346
4347 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const4348 BinaryOpNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
4349 temporary_terms_t &temporary_terms,
4350 map<expr_t, pair<int, int>> &first_occurence,
4351 int Curr_block,
4352 vector<vector<temporary_terms_t>> &v_temporary_terms,
4353 int equation) const
4354 {
4355 expr_t this2 = const_cast<BinaryOpNode *>(this);
4356 if (auto it = reference_count.find(this2);
4357 it == reference_count.end())
4358 {
4359 reference_count[this2] = 1;
4360 first_occurence[this2] = { Curr_block, equation };
4361 arg1->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
4362 arg2->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
4363 }
4364 else
4365 {
4366 reference_count[this2]++;
4367 if (reference_count[this2] * cost(temporary_terms, false) > min_cost_c
4368 && op_code != BinaryOpcode::equal)
4369 {
4370 temporary_terms.insert(this2);
4371 v_temporary_terms[first_occurence[this2].first][first_occurence[this2].second].insert(this2);
4372 }
4373 }
4374 }
4375
4376 double
eval_opcode(double v1,BinaryOpcode op_code,double v2,int derivOrder)4377 BinaryOpNode::eval_opcode(double v1, BinaryOpcode op_code, double v2, int derivOrder) noexcept(false)
4378 {
4379 switch (op_code)
4380 {
4381 case BinaryOpcode::plus:
4382 return v1 + v2;
4383 case BinaryOpcode::minus:
4384 return v1 - v2;
4385 case BinaryOpcode::times:
4386 return v1 * v2;
4387 case BinaryOpcode::divide:
4388 return v1 / v2;
4389 case BinaryOpcode::power:
4390 return pow(v1, v2);
4391 case BinaryOpcode::powerDeriv:
4392 if (fabs(v1) < near_zero && v2 > 0
4393 && derivOrder > v2
4394 && fabs(v2-nearbyint(v2)) < near_zero)
4395 return 0.0;
4396 else
4397 {
4398 double dxp = pow(v1, v2-derivOrder);
4399 for (int i = 0; i < derivOrder; i++)
4400 dxp *= v2--;
4401 return dxp;
4402 }
4403 case BinaryOpcode::max:
4404 if (v1 < v2)
4405 return v2;
4406 else
4407 return v1;
4408 case BinaryOpcode::min:
4409 if (v1 > v2)
4410 return v2;
4411 else
4412 return v1;
4413 case BinaryOpcode::less:
4414 return v1 < v2;
4415 case BinaryOpcode::greater:
4416 return v1 > v2;
4417 case BinaryOpcode::lessEqual:
4418 return v1 <= v2;
4419 case BinaryOpcode::greaterEqual:
4420 return v1 >= v2;
4421 case BinaryOpcode::equalEqual:
4422 return v1 == v2;
4423 case BinaryOpcode::different:
4424 return v1 != v2;
4425 case BinaryOpcode::equal:
4426 throw EvalException();
4427 }
4428 // Suppress GCC warning
4429 exit(EXIT_FAILURE);
4430 }
4431
4432 double
eval(const eval_context_t & eval_context) const4433 BinaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false)
4434 {
4435 double v1 = arg1->eval(eval_context);
4436 double v2 = arg2->eval(eval_context);
4437 return eval_opcode(v1, op_code, v2, powerDerivOrder);
4438 }
4439
4440 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const4441 BinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
4442 bool lhs_rhs, const temporary_terms_t &temporary_terms,
4443 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
4444 const deriv_node_temp_terms_t &tef_terms) const
4445 {
4446 // If current node is a temporary term
4447 if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4448 {
4449 if (dynamic)
4450 {
4451 auto ii = map_idx.find(idx);
4452 FLDT_ fldt(ii->second);
4453 fldt.write(CompileCode, instruction_number);
4454 }
4455 else
4456 {
4457 auto ii = map_idx.find(idx);
4458 FLDST_ fldst(ii->second);
4459 fldst.write(CompileCode, instruction_number);
4460 }
4461 return;
4462 }
4463 if (op_code == BinaryOpcode::powerDeriv)
4464 {
4465 FLDC_ fldc(powerDerivOrder);
4466 fldc.write(CompileCode, instruction_number);
4467 }
4468 arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
4469 arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
4470 FBINARY_ fbinary{static_cast<int>(op_code)};
4471 fbinary.write(CompileCode, instruction_number);
4472 }
4473
4474 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const4475 BinaryOpNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
4476 {
4477 if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4478 temporary_terms_inuse.insert(idx);
4479 else
4480 {
4481 arg1->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
4482 arg2->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
4483 }
4484 }
4485
4486 bool
containsExternalFunction() const4487 BinaryOpNode::containsExternalFunction() const
4488 {
4489 return arg1->containsExternalFunction()
4490 || arg2->containsExternalFunction();
4491 }
4492
4493 void
writeJsonAST(ostream & output) const4494 BinaryOpNode::writeJsonAST(ostream &output) const
4495 {
4496 output << R"({"node_type" : "BinaryOpNode",)"
4497 << R"( "op" : ")";
4498 switch (op_code)
4499 {
4500 case BinaryOpcode::plus:
4501 output << "+";
4502 break;
4503 case BinaryOpcode::minus:
4504 output << "-";
4505 break;
4506 case BinaryOpcode::times:
4507 output << "*";
4508 break;
4509 case BinaryOpcode::divide:
4510 output << "/";
4511 break;
4512 case BinaryOpcode::power:
4513 output << "^";
4514 break;
4515 case BinaryOpcode::less:
4516 output << "<";
4517 break;
4518 case BinaryOpcode::greater:
4519 output << ">";
4520 break;
4521 case BinaryOpcode::lessEqual:
4522 output << "<=";
4523 break;
4524 case BinaryOpcode::greaterEqual:
4525 output << ">=";
4526 break;
4527 case BinaryOpcode::equalEqual:
4528 output << "==";
4529 break;
4530 case BinaryOpcode::different:
4531 output << "!=";
4532 break;
4533 case BinaryOpcode::equal:
4534 output << "=";
4535 break;
4536 case BinaryOpcode::max:
4537 output << "max";
4538 break;
4539 case BinaryOpcode::min:
4540 output << "min";
4541 break;
4542 case BinaryOpcode::powerDeriv:
4543 output << "power_deriv";
4544 break;
4545 }
4546 output << R"(", "arg1" : )";
4547 arg1->writeJsonAST(output);
4548 output << R"(, "arg2" : )";
4549 arg2->writeJsonAST(output);
4550 output << "}";
4551 }
4552
4553 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const4554 BinaryOpNode::writeJsonOutput(ostream &output,
4555 const temporary_terms_t &temporary_terms,
4556 const deriv_node_temp_terms_t &tef_terms,
4557 bool isdynamic) const
4558 {
4559 // If current node is a temporary term
4560 if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
4561 {
4562 output << "T" << idx;
4563 return;
4564 }
4565
4566 if (op_code == BinaryOpcode::max || op_code == BinaryOpcode::min)
4567 {
4568 switch (op_code)
4569 {
4570 case BinaryOpcode::max:
4571 output << "max(";
4572 break;
4573 case BinaryOpcode::min:
4574 output << "min(";
4575 break;
4576 default:
4577 ;
4578 }
4579 arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4580 output << ",";
4581 arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4582 output << ")";
4583 return;
4584 }
4585
4586 if (op_code == BinaryOpcode::powerDeriv)
4587 {
4588 output << "get_power_deriv(";
4589 arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4590 output << ",";
4591 arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4592 output << "," << powerDerivOrder << ")";
4593 return;
4594 }
4595
4596 int prec = precedenceJson(temporary_terms);
4597
4598 bool close_parenthesis = false;
4599
4600 // If left argument has a lower precedence, or if current and left argument are both power operators,
4601 // add parenthesis around left argument
4602 if (auto barg1 = dynamic_cast<BinaryOpNode *>(arg1);
4603 arg1->precedenceJson(temporary_terms) < prec
4604 || (op_code == BinaryOpcode::power && barg1 && barg1->op_code == BinaryOpcode::power))
4605 {
4606 output << "(";
4607 close_parenthesis = true;
4608 }
4609
4610 // Write left argument
4611 arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4612
4613 if (close_parenthesis)
4614 output << ")";
4615
4616 // Write current operator symbol
4617 switch (op_code)
4618 {
4619 case BinaryOpcode::plus:
4620 output << "+";
4621 break;
4622 case BinaryOpcode::minus:
4623 output << "-";
4624 break;
4625 case BinaryOpcode::times:
4626 output << "*";
4627 break;
4628 case BinaryOpcode::divide:
4629 output << "/";
4630 break;
4631 case BinaryOpcode::power:
4632 output << "^";
4633 break;
4634 case BinaryOpcode::less:
4635 output << "<";
4636 break;
4637 case BinaryOpcode::greater:
4638 output << ">";
4639 break;
4640 case BinaryOpcode::lessEqual:
4641 output << "<=";
4642 break;
4643 case BinaryOpcode::greaterEqual:
4644 output << ">=";
4645 break;
4646 case BinaryOpcode::equalEqual:
4647 output << "==";
4648 break;
4649 case BinaryOpcode::different:
4650 output << "!=";
4651 break;
4652 case BinaryOpcode::equal:
4653 output << "=";
4654 break;
4655 default:
4656 ;
4657 }
4658
4659 close_parenthesis = false;
4660
4661 /* Add parenthesis around right argument if:
4662 - its precedence is lower than those of the current node
4663 - it is a power operator and current operator is also a power operator
4664 - it is a minus operator with same precedence than current operator
4665 - it is a divide operator with same precedence than current operator */
4666 auto barg2 = dynamic_cast<BinaryOpNode *>(arg2);
4667 if (int arg2_prec = arg2->precedenceJson(temporary_terms); arg2_prec < prec
4668 || (op_code == BinaryOpcode::power && barg2 && barg2->op_code == BinaryOpcode::power)
4669 || (op_code == BinaryOpcode::minus && arg2_prec == prec)
4670 || (op_code == BinaryOpcode::divide && arg2_prec == prec))
4671 {
4672 output << "(";
4673 close_parenthesis = true;
4674 }
4675
4676 // Write right argument
4677 arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
4678
4679 if (close_parenthesis)
4680 output << ")";
4681 }
4682
/* Writes this binary node as an expression in the language selected by
   output_type (MATLAB, C, Julia, LaTeX, ...).
   - Temporary terms are delegated to checkIfTemporaryTermThenWrite.
   - powerDeriv: under LaTeX output, the unpacked expression from
     unpackPowerDeriv() is written; otherwise a call to get_power_deriv
     (Julia) or getPowerDeriv (other outputs) is written.
   - power in C output, max and min: written in function-call syntax.
   - Everything else: written infix, with parentheses (or LaTeX braces /
     \frac) around operands as required by precedence. */
void
BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_t &temporary_terms,
                          const temporary_terms_idxs_t &temporary_terms_idxs,
                          const deriv_node_temp_terms_t &tef_terms) const
{
  if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
    return;

  // Treat derivative of Power
  if (op_code == BinaryOpcode::powerDeriv)
    {
      if (isLatexOutput(output_type))
        unpackPowerDeriv()->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      else
        {
          // Helper function name differs between Julia and the other targets
          if (output_type == ExprNodeOutputType::juliaStaticModel || output_type == ExprNodeOutputType::juliaDynamicModel)
            output << "get_power_deriv(";
          else
            output << "getPowerDeriv(";
          arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << ",";
          arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
          output << "," << powerDerivOrder << ")";
        }
      return;
    }

  // Treat special case of power operator in C, and case of max and min operators
  if ((op_code == BinaryOpcode::power && isCOutput(output_type)) || op_code == BinaryOpcode::max || op_code == BinaryOpcode::min)
    {
      switch (op_code)
        {
        case BinaryOpcode::power:
          output << "pow(";
          break;
        case BinaryOpcode::max:
          output << "max(";
          break;
        case BinaryOpcode::min:
          output << "min(";
          break;
        default:
          ;
        }
      arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      output << ",";
      arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
      output << ")";
      return;
    }

  int prec = precedence(output_type, temporary_terms);

  bool close_parenthesis = false;

  // In LaTeX, a division is written as \frac{left}{right}: the braces make parentheses unnecessary
  if (isLatexOutput(output_type) && op_code == BinaryOpcode::divide)
    output << R"(\frac{)";
  else
    {
      // If left argument has a lower precedence, or if current and left argument are both power operators, add parenthesis around left argument
      auto barg1 = dynamic_cast<BinaryOpNode *>(arg1);
      if (arg1->precedence(output_type, temporary_terms) < prec
          || (op_code == BinaryOpcode::power && barg1 != nullptr && barg1->op_code == BinaryOpcode::power))
        {
          output << LEFT_PAR(output_type);
          close_parenthesis = true;
        }
    }

  // Write left argument
  arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);

  if (close_parenthesis)
    output << RIGHT_PAR(output_type);

  // Close the numerator of \frac
  if (isLatexOutput(output_type) && op_code == BinaryOpcode::divide)
    output << "}";

  // Write current operator symbol
  switch (op_code)
    {
    case BinaryOpcode::plus:
      output << "+";
      break;
    case BinaryOpcode::minus:
      output << "-";
      break;
    case BinaryOpcode::times:
      if (isLatexOutput(output_type))
        output << R"(\, )";
      else
        output << "*";
      break;
    case BinaryOpcode::divide:
      // Nothing to write in LaTeX: the division is carried by \frac
      if (!isLatexOutput(output_type))
        output << "/";
      break;
    case BinaryOpcode::power:
      output << "^";
      break;
    case BinaryOpcode::less:
      output << "<";
      break;
    case BinaryOpcode::greater:
      output << ">";
      break;
    case BinaryOpcode::lessEqual:
      if (isLatexOutput(output_type))
        output << R"(\leq )";
      else
        output << "<=";
      break;
    case BinaryOpcode::greaterEqual:
      if (isLatexOutput(output_type))
        output << R"(\geq )";
      else
        output << ">=";
      break;
    case BinaryOpcode::equalEqual:
      output << "==";
      break;
    case BinaryOpcode::different:
      // MATLAB uses ~=, C and Julia use !=, any other output falls back to LaTeX \neq
      if (isMatlabOutput(output_type))
        output << "~=";
      else
        {
          if (isCOutput(output_type) || isJuliaOutput(output_type))
            output << "!=";
          else
            output << R"(\neq )";
        }
      break;
    case BinaryOpcode::equal:
      output << "=";
      break;
    default:
      ;
    }

  close_parenthesis = false;

  // In LaTeX, the exponent and the \frac denominator are enclosed in braces instead of parentheses
  if (isLatexOutput(output_type) && (op_code == BinaryOpcode::power || op_code == BinaryOpcode::divide))
    output << "{";
  else
    {
      /* Add parenthesis around right argument if:
         - its precedence is lower than those of the current node
         - it is a power operator and current operator is also a power operator
         - it is a minus operator with same precedence than current operator
         - it is a divide operator with same precedence than current operator */
      auto barg2 = dynamic_cast<BinaryOpNode *>(arg2);
      if (int arg2_prec = arg2->precedence(output_type, temporary_terms); arg2_prec < prec
          || (op_code == BinaryOpcode::power && barg2 && barg2->op_code == BinaryOpcode::power && !isLatexOutput(output_type))
          || (op_code == BinaryOpcode::minus && arg2_prec == prec)
          || (op_code == BinaryOpcode::divide && arg2_prec == prec && !isLatexOutput(output_type)))
        {
          output << LEFT_PAR(output_type);
          close_parenthesis = true;
        }
    }

  // Write right argument
  arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);

  if (isLatexOutput(output_type) && (op_code == BinaryOpcode::power || op_code == BinaryOpcode::divide))
    output << "}";

  if (close_parenthesis)
    output << RIGHT_PAR(output_type);
}
4854
4855 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const4856 BinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
4857 const temporary_terms_t &temporary_terms,
4858 const temporary_terms_idxs_t &temporary_terms_idxs,
4859 deriv_node_temp_terms_t &tef_terms) const
4860 {
4861 arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
4862 arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
4863 }
4864
4865 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const4866 BinaryOpNode::writeJsonExternalFunctionOutput(vector<string> &efout,
4867 const temporary_terms_t &temporary_terms,
4868 deriv_node_temp_terms_t &tef_terms,
4869 bool isdynamic) const
4870 {
4871 arg1->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
4872 arg2->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
4873 }
4874
4875 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const4876 BinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
4877 bool lhs_rhs, const temporary_terms_t &temporary_terms,
4878 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
4879 deriv_node_temp_terms_t &tef_terms) const
4880 {
4881 arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
4882 dynamic, steady_dynamic, tef_terms);
4883 arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
4884 dynamic, steady_dynamic, tef_terms);
4885 }
4886
4887 int
VarMinLag() const4888 BinaryOpNode::VarMinLag() const
4889 {
4890 return min(arg1->VarMinLag(), arg2->VarMinLag());
4891 }
4892
4893 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const4894 BinaryOpNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
4895 {
4896 return max(arg1->VarMaxLag(lhs_lag_equiv),
4897 arg2->VarMaxLag(lhs_lag_equiv));
4898 }
4899
4900 void
collectVARLHSVariable(set<expr_t> & result) const4901 BinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
4902 {
4903 cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
4904 exit(EXIT_FAILURE);
4905 }
4906
4907 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const4908 BinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
4909 {
4910 arg1->collectDynamicVariables(type_arg, result);
4911 arg2->collectDynamicVariables(type_arg, result);
4912 }
4913
4914 expr_t
Compute_RHS(expr_t arg1,expr_t arg2,int op,int op_type) const4915 BinaryOpNode::Compute_RHS(expr_t arg1, expr_t arg2, int op, int op_type) const
4916 {
4917 temporary_terms_t temp;
4918 switch (op_type)
4919 {
4920 case 0: /*Unary Operator*/
4921 switch (static_cast<UnaryOpcode>(op))
4922 {
4923 case UnaryOpcode::uminus:
4924 return (datatree.AddUMinus(arg1));
4925 break;
4926 case UnaryOpcode::exp:
4927 return (datatree.AddExp(arg1));
4928 break;
4929 case UnaryOpcode::log:
4930 return (datatree.AddLog(arg1));
4931 break;
4932 case UnaryOpcode::log10:
4933 return (datatree.AddLog10(arg1));
4934 break;
4935 default:
4936 cerr << "BinaryOpNode::Compute_RHS: case not handled";
4937 exit(EXIT_FAILURE);
4938 }
4939 break;
4940 case 1: /*Binary Operator*/
4941 switch (static_cast<BinaryOpcode>(op))
4942 {
4943 case BinaryOpcode::plus:
4944 return (datatree.AddPlus(arg1, arg2));
4945 break;
4946 case BinaryOpcode::minus:
4947 return (datatree.AddMinus(arg1, arg2));
4948 break;
4949 case BinaryOpcode::times:
4950 return (datatree.AddTimes(arg1, arg2));
4951 break;
4952 case BinaryOpcode::divide:
4953 return (datatree.AddDivide(arg1, arg2));
4954 break;
4955 case BinaryOpcode::power:
4956 return (datatree.AddPower(arg1, arg2));
4957 break;
4958 default:
4959 cerr << "BinaryOpNode::Compute_RHS: case not handled";
4960 exit(EXIT_FAILURE);
4961 }
4962 break;
4963 }
4964 return nullptr;
4965 }
4966
4967 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const4968 BinaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
4969 {
4970 /* Checks if the current value of the endogenous variable related to the equation
4971 is present in the arguments of the binary operator. */
4972 vector<tuple<int, expr_t, expr_t>> List_of_Op_RHS1, List_of_Op_RHS2;
4973 pair<int, expr_t> res = arg1->normalizeEquation(var_endo, List_of_Op_RHS1);
4974 int is_endogenous_present_1 = res.first;
4975 expr_t expr_t_1 = res.second;
4976
4977 res = arg2->normalizeEquation(var_endo, List_of_Op_RHS2);
4978 int is_endogenous_present_2 = res.first;
4979 expr_t expr_t_2 = res.second;
4980
4981 /* If the two expressions contains the current value of the endogenous variable associated to the equation
4982 the equation could not be normalized and the process is given-up.*/
4983 if (is_endogenous_present_1 == 2 || is_endogenous_present_2 == 2)
4984 return { 2, nullptr };
4985 else if (is_endogenous_present_1 && is_endogenous_present_2)
4986 return { 2, nullptr };
4987 else if (is_endogenous_present_1) /*If the current values of the endogenous variable associated to the equation
4988 is present only in the first operand of the expression, we try to normalize the equation*/
4989 {
4990 if (op_code == BinaryOpcode::equal) /* The end of the normalization process :
4991 All the operations needed to normalize the equation are applied. */
4992 while (!List_of_Op_RHS1.empty())
4993 {
4994 tuple<int, expr_t, expr_t> it = List_of_Op_RHS1.back();
4995 List_of_Op_RHS1.pop_back();
4996 if (get<1>(it) && !get<2>(it)) /*Binary operator*/
4997 expr_t_2 = Compute_RHS(expr_t_2, static_cast<BinaryOpNode *>(get<1>(it)), get<0>(it), 1);
4998 else if (get<2>(it) && !get<1>(it)) /*Binary operator*/
4999 expr_t_2 = Compute_RHS(get<2>(it), expr_t_2, get<0>(it), 1);
5000 else if (get<2>(it) && get<1>(it)) /*Binary operator*/
5001 expr_t_2 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1);
5002 else /*Unary operator*/
5003 expr_t_2 = Compute_RHS(static_cast<UnaryOpNode *>(expr_t_2), static_cast<UnaryOpNode *>(get<1>(it)), get<0>(it), 0);
5004 }
5005 else
5006 List_of_Op_RHS = List_of_Op_RHS1;
5007 }
5008 else if (is_endogenous_present_2)
5009 {
5010 if (op_code == BinaryOpcode::equal)
5011 while (!List_of_Op_RHS2.empty())
5012 {
5013 tuple<int, expr_t, expr_t> it = List_of_Op_RHS2.back();
5014 List_of_Op_RHS2.pop_back();
5015 if (get<1>(it) && !get<2>(it)) /*Binary operator*/
5016 expr_t_1 = Compute_RHS(static_cast<BinaryOpNode *>(expr_t_1), static_cast<BinaryOpNode *>(get<1>(it)), get<0>(it), 1);
5017 else if (get<2>(it) && !get<1>(it)) /*Binary operator*/
5018 expr_t_1 = Compute_RHS(static_cast<BinaryOpNode *>(get<2>(it)), static_cast<BinaryOpNode *>(expr_t_1), get<0>(it), 1);
5019 else if (get<2>(it) && get<1>(it)) /*Binary operator*/
5020 expr_t_1 = Compute_RHS(get<1>(it), get<2>(it), get<0>(it), 1);
5021 else
5022 expr_t_1 = Compute_RHS(static_cast<UnaryOpNode *>(expr_t_1), static_cast<UnaryOpNode *>(get<1>(it)), get<0>(it), 0);
5023 }
5024 else
5025 List_of_Op_RHS = List_of_Op_RHS2;
5026 }
5027 switch (op_code)
5028 {
5029 case BinaryOpcode::plus:
5030 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5031 return { 0, datatree.AddPlus(expr_t_1, expr_t_2) };
5032 else if (is_endogenous_present_1 && is_endogenous_present_2)
5033 return { 2, nullptr };
5034 else if (!is_endogenous_present_1 && is_endogenous_present_2)
5035 {
5036 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_1, nullptr);
5037 return { 1, expr_t_1 };
5038 }
5039 else if (is_endogenous_present_1 && !is_endogenous_present_2)
5040 {
5041 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_2, nullptr);
5042 return { 1, expr_t_2 };
5043 }
5044 break;
5045 case BinaryOpcode::minus:
5046 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5047 return { 0, datatree.AddMinus(expr_t_1, expr_t_2) };
5048 else if (is_endogenous_present_1 && is_endogenous_present_2)
5049 return { 2, nullptr };
5050 else if (!is_endogenous_present_1 && is_endogenous_present_2)
5051 {
5052 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::uminus), nullptr, nullptr);
5053 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::minus), expr_t_1, nullptr);
5054 return { 1, expr_t_1 };
5055 }
5056 else if (is_endogenous_present_1 && !is_endogenous_present_2)
5057 {
5058 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::plus), expr_t_2, nullptr);
5059 return { 1, datatree.AddUMinus(expr_t_2) };
5060 }
5061 break;
5062 case BinaryOpcode::times:
5063 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5064 return { 0, datatree.AddTimes(expr_t_1, expr_t_2) };
5065 else if (!is_endogenous_present_1 && is_endogenous_present_2)
5066 {
5067 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), expr_t_1, nullptr);
5068 return { 1, expr_t_1 };
5069 }
5070 else if (is_endogenous_present_1 && !is_endogenous_present_2)
5071 {
5072 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), expr_t_2, nullptr);
5073 return { 1, expr_t_2 };
5074 }
5075 else
5076 return { 2, nullptr };
5077 break;
5078 case BinaryOpcode::divide:
5079 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5080 return { 0, datatree.AddDivide(expr_t_1, expr_t_2) };
5081 else if (!is_endogenous_present_1 && is_endogenous_present_2)
5082 {
5083 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), nullptr, expr_t_1);
5084 return { 1, expr_t_1 };
5085 }
5086 else if (is_endogenous_present_1 && !is_endogenous_present_2)
5087 {
5088 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::times), expr_t_2, nullptr);
5089 return { 1, expr_t_2 };
5090 }
5091 else
5092 return { 2, nullptr };
5093 break;
5094 case BinaryOpcode::power:
5095 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5096 return { 0, datatree.AddPower(expr_t_1, expr_t_2) };
5097 else if (is_endogenous_present_1 && !is_endogenous_present_2)
5098 {
5099 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), datatree.AddDivide(datatree.One, expr_t_2), nullptr);
5100 return { 1, nullptr };
5101 }
5102 else if (!is_endogenous_present_1 && is_endogenous_present_2)
5103 {
5104 /* we have to nomalize a^f(X) = RHS */
5105 /* First computes the ln(RHS)*/
5106 List_of_Op_RHS.emplace_back(static_cast<int>(UnaryOpcode::log), nullptr, nullptr);
5107 /* Second computes f(X) = ln(RHS) / ln(a)*/
5108 List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::divide), nullptr, datatree.AddLog(expr_t_1));
5109 return { 1, nullptr };
5110 }
5111 break;
5112 case BinaryOpcode::equal:
5113 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5114 {
5115 return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.AddMinus(expr_t_2, expr_t_1)) };
5116 }
5117 else if (is_endogenous_present_1 && is_endogenous_present_2)
5118 {
5119 return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), datatree.Zero) };
5120 }
5121 else if (!is_endogenous_present_1 && is_endogenous_present_2)
5122 {
5123 return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), /*datatree.AddUMinus(expr_t_1)*/ expr_t_1) };
5124 }
5125 else if (is_endogenous_present_1 && !is_endogenous_present_2)
5126 {
5127 return { 0, datatree.AddEqual(datatree.AddVariable(datatree.symbol_table.getID(SymbolType::endogenous, var_endo), 0), expr_t_2) };
5128 }
5129 break;
5130 case BinaryOpcode::max:
5131 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5132 return { 0, datatree.AddMax(expr_t_1, expr_t_2) };
5133 else
5134 return { 2, nullptr };
5135 break;
5136 case BinaryOpcode::min:
5137 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5138 return { 0, datatree.AddMin(expr_t_1, expr_t_2) };
5139 else
5140 return { 2, nullptr };
5141 break;
5142 case BinaryOpcode::less:
5143 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5144 return { 0, datatree.AddLess(expr_t_1, expr_t_2) };
5145 else
5146 return { 2, nullptr };
5147 break;
5148 case BinaryOpcode::greater:
5149 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5150 return { 0, datatree.AddGreater(expr_t_1, expr_t_2) };
5151 else
5152 return { 2, nullptr };
5153 break;
5154 case BinaryOpcode::lessEqual:
5155 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5156 return { 0, datatree.AddLessEqual(expr_t_1, expr_t_2) };
5157 else
5158 return { 2, nullptr };
5159 break;
5160 case BinaryOpcode::greaterEqual:
5161 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5162 return { 0, datatree.AddGreaterEqual(expr_t_1, expr_t_2) };
5163 else
5164 return { 2, nullptr };
5165 break;
5166 case BinaryOpcode::equalEqual:
5167 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5168 return { 0, datatree.AddEqualEqual(expr_t_1, expr_t_2) };
5169 else
5170 return { 2, nullptr };
5171 break;
5172 case BinaryOpcode::different:
5173 if (!is_endogenous_present_1 && !is_endogenous_present_2)
5174 return { 0, datatree.AddDifferent(expr_t_1, expr_t_2) };
5175 else
5176 return { 2, nullptr };
5177 break;
5178 default:
5179 cerr << "Binary operator not handled during the normalization process" << endl;
5180 return { 2, nullptr }; // Could not be normalized
5181 }
5182 // Suppress GCC warning
5183 cerr << "BinaryOpNode::normalizeEquation: impossible case" << endl;
5184 exit(EXIT_FAILURE);
5185 }
5186
5187 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)5188 BinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
5189 {
5190 expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
5191 expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
5192 return composeDerivatives(darg1, darg2);
5193 }
5194
5195 expr_t
buildSimilarBinaryOpNode(expr_t alt_arg1,expr_t alt_arg2,DataTree & alt_datatree) const5196 BinaryOpNode::buildSimilarBinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, DataTree &alt_datatree) const
5197 {
5198 switch (op_code)
5199 {
5200 case BinaryOpcode::plus:
5201 return alt_datatree.AddPlus(alt_arg1, alt_arg2);
5202 case BinaryOpcode::minus:
5203 return alt_datatree.AddMinus(alt_arg1, alt_arg2);
5204 case BinaryOpcode::times:
5205 return alt_datatree.AddTimes(alt_arg1, alt_arg2);
5206 case BinaryOpcode::divide:
5207 return alt_datatree.AddDivide(alt_arg1, alt_arg2);
5208 case BinaryOpcode::power:
5209 return alt_datatree.AddPower(alt_arg1, alt_arg2);
5210 case BinaryOpcode::equal:
5211 return alt_datatree.AddEqual(alt_arg1, alt_arg2);
5212 case BinaryOpcode::max:
5213 return alt_datatree.AddMax(alt_arg1, alt_arg2);
5214 case BinaryOpcode::min:
5215 return alt_datatree.AddMin(alt_arg1, alt_arg2);
5216 case BinaryOpcode::less:
5217 return alt_datatree.AddLess(alt_arg1, alt_arg2);
5218 case BinaryOpcode::greater:
5219 return alt_datatree.AddGreater(alt_arg1, alt_arg2);
5220 case BinaryOpcode::lessEqual:
5221 return alt_datatree.AddLessEqual(alt_arg1, alt_arg2);
5222 case BinaryOpcode::greaterEqual:
5223 return alt_datatree.AddGreaterEqual(alt_arg1, alt_arg2);
5224 case BinaryOpcode::equalEqual:
5225 return alt_datatree.AddEqualEqual(alt_arg1, alt_arg2);
5226 case BinaryOpcode::different:
5227 return alt_datatree.AddDifferent(alt_arg1, alt_arg2);
5228 case BinaryOpcode::powerDeriv:
5229 return alt_datatree.AddPowerDeriv(alt_arg1, alt_arg2, powerDerivOrder);
5230 }
5231 // Suppress GCC warning
5232 exit(EXIT_FAILURE);
5233 }
5234
5235 expr_t
toStatic(DataTree & static_datatree) const5236 BinaryOpNode::toStatic(DataTree &static_datatree) const
5237 {
5238 expr_t sarg1 = arg1->toStatic(static_datatree);
5239 expr_t sarg2 = arg2->toStatic(static_datatree);
5240 return buildSimilarBinaryOpNode(sarg1, sarg2, static_datatree);
5241 }
5242
5243 void
computeXrefs(EquationInfo & ei) const5244 BinaryOpNode::computeXrefs(EquationInfo &ei) const
5245 {
5246 arg1->computeXrefs(ei);
5247 arg2->computeXrefs(ei);
5248 }
5249
5250 expr_t
clone(DataTree & datatree) const5251 BinaryOpNode::clone(DataTree &datatree) const
5252 {
5253 expr_t substarg1 = arg1->clone(datatree);
5254 expr_t substarg2 = arg2->clone(datatree);
5255 return buildSimilarBinaryOpNode(substarg1, substarg2, datatree);
5256 }
5257
5258 int
maxEndoLead() const5259 BinaryOpNode::maxEndoLead() const
5260 {
5261 return max(arg1->maxEndoLead(), arg2->maxEndoLead());
5262 }
5263
5264 int
maxExoLead() const5265 BinaryOpNode::maxExoLead() const
5266 {
5267 return max(arg1->maxExoLead(), arg2->maxExoLead());
5268 }
5269
5270 int
maxEndoLag() const5271 BinaryOpNode::maxEndoLag() const
5272 {
5273 return max(arg1->maxEndoLag(), arg2->maxEndoLag());
5274 }
5275
5276 int
maxExoLag() const5277 BinaryOpNode::maxExoLag() const
5278 {
5279 return max(arg1->maxExoLag(), arg2->maxExoLag());
5280 }
5281
5282 int
maxLead() const5283 BinaryOpNode::maxLead() const
5284 {
5285 return max(arg1->maxLead(), arg2->maxLead());
5286 }
5287
5288 int
maxLag() const5289 BinaryOpNode::maxLag() const
5290 {
5291 return max(arg1->maxLag(), arg2->maxLag());
5292 }
5293
5294 int
maxLagWithDiffsExpanded() const5295 BinaryOpNode::maxLagWithDiffsExpanded() const
5296 {
5297 return max(arg1->maxLagWithDiffsExpanded(), arg2->maxLagWithDiffsExpanded());
5298 }
5299
5300 expr_t
undiff() const5301 BinaryOpNode::undiff() const
5302 {
5303 expr_t arg1subst = arg1->undiff();
5304 expr_t arg2subst = arg2->undiff();
5305 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5306 }
5307
5308 int
PacMaxLag(int lhs_symb_id) const5309 BinaryOpNode::PacMaxLag(int lhs_symb_id) const
5310 {
5311 return max(arg1->PacMaxLag(lhs_symb_id), arg2->PacMaxLag(lhs_symb_id));
5312 }
5313
5314 int
getPacTargetSymbIdHelper(int lhs_symb_id,int undiff_lhs_symb_id,const set<pair<int,int>> & endogs) const5315 BinaryOpNode::getPacTargetSymbIdHelper(int lhs_symb_id, int undiff_lhs_symb_id, const set<pair<int, int>> &endogs) const
5316 {
5317 int target_symb_id = -1;
5318 bool found_lagged_lhs = false;
5319 for (auto &it : endogs)
5320 {
5321 int id = datatree.symbol_table.getUltimateOrigSymbID(it.first);
5322 if (id == lhs_symb_id || id == undiff_lhs_symb_id)
5323 found_lagged_lhs = true;
5324 if (id != lhs_symb_id && id != undiff_lhs_symb_id)
5325 if (target_symb_id < 0)
5326 target_symb_id = it.first;
5327 }
5328 if (!found_lagged_lhs)
5329 target_symb_id = -1;
5330 return target_symb_id;
5331 }
5332
5333 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const5334 BinaryOpNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
5335 {
5336 set<pair<int, int>> endogs;
5337 arg1->collectDynamicVariables(SymbolType::endogenous, endogs);
5338 int target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, undiff_lhs_symb_id, endogs);
5339 if (target_symb_id >= 0)
5340 return target_symb_id;
5341
5342 endogs.clear();
5343 arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
5344 target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, undiff_lhs_symb_id, endogs);
5345
5346 if (target_symb_id < 0)
5347 {
5348 cerr << "Error finding target variable in PAC equation" << endl;
5349 exit(EXIT_FAILURE);
5350 }
5351
5352 return target_symb_id;
5353 }
5354
5355 expr_t
decreaseLeadsLags(int n) const5356 BinaryOpNode::decreaseLeadsLags(int n) const
5357 {
5358 expr_t arg1subst = arg1->decreaseLeadsLags(n);
5359 expr_t arg2subst = arg2->decreaseLeadsLags(n);
5360 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5361 }
5362
5363 expr_t
decreaseLeadsLagsPredeterminedVariables() const5364 BinaryOpNode::decreaseLeadsLagsPredeterminedVariables() const
5365 {
5366 expr_t arg1subst = arg1->decreaseLeadsLagsPredeterminedVariables();
5367 expr_t arg2subst = arg2->decreaseLeadsLagsPredeterminedVariables();
5368 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5369 }
5370
5371 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const5372 BinaryOpNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
5373 {
5374 expr_t arg1subst, arg2subst;
5375 int maxendolead1 = arg1->maxEndoLead(), maxendolead2 = arg2->maxEndoLead();
5376
5377 if (maxendolead1 < 2 && maxendolead2 < 2)
5378 return const_cast<BinaryOpNode *>(this);
5379 if (deterministic_model)
5380 {
5381 arg1subst = maxendolead1 >= 2 ? arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg1;
5382 arg2subst = maxendolead2 >= 2 ? arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg2;
5383 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5384 }
5385 else
5386 switch (op_code)
5387 {
5388 case BinaryOpcode::plus:
5389 case BinaryOpcode::minus:
5390 case BinaryOpcode::equal:
5391 arg1subst = maxendolead1 >= 2 ? arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg1;
5392 arg2subst = maxendolead2 >= 2 ? arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model) : arg2;
5393 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5394 case BinaryOpcode::times:
5395 case BinaryOpcode::divide:
5396 if (maxendolead1 >= 2 && maxendolead2 == 0 && arg2->maxExoLead() == 0)
5397 {
5398 arg1subst = arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
5399 return buildSimilarBinaryOpNode(arg1subst, arg2, datatree);
5400 }
5401 if (maxendolead1 == 0 && arg1->maxExoLead() == 0
5402 && maxendolead2 >= 2 && op_code == BinaryOpcode::times)
5403 {
5404 arg2subst = arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
5405 return buildSimilarBinaryOpNode(arg1, arg2subst, datatree);
5406 }
5407 return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
5408 default:
5409 return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
5410 }
5411 }
5412
5413 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5414 BinaryOpNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5415 {
5416 expr_t arg1subst = arg1->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
5417 expr_t arg2subst = arg2->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
5418 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5419 }
5420
5421 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const5422 BinaryOpNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
5423 {
5424 expr_t arg1subst, arg2subst;
5425 int maxexolead1 = arg1->maxExoLead(), maxexolead2 = arg2->maxExoLead();
5426
5427 if (maxexolead1 < 1 && maxexolead2 < 1)
5428 return const_cast<BinaryOpNode *>(this);
5429 if (deterministic_model)
5430 {
5431 arg1subst = maxexolead1 >= 1 ? arg1->substituteExoLead(subst_table, neweqs, deterministic_model) : arg1;
5432 arg2subst = maxexolead2 >= 1 ? arg2->substituteExoLead(subst_table, neweqs, deterministic_model) : arg2;
5433 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5434 }
5435 else
5436 switch (op_code)
5437 {
5438 case BinaryOpcode::plus:
5439 case BinaryOpcode::minus:
5440 case BinaryOpcode::equal:
5441 arg1subst = maxexolead1 >= 1 ? arg1->substituteExoLead(subst_table, neweqs, deterministic_model) : arg1;
5442 arg2subst = maxexolead2 >= 1 ? arg2->substituteExoLead(subst_table, neweqs, deterministic_model) : arg2;
5443 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5444 case BinaryOpcode::times:
5445 case BinaryOpcode::divide:
5446 if (maxexolead1 >= 1 && maxexolead2 == 0 && arg2->maxEndoLead() == 0)
5447 {
5448 arg1subst = arg1->substituteExoLead(subst_table, neweqs, deterministic_model);
5449 return buildSimilarBinaryOpNode(arg1subst, arg2, datatree);
5450 }
5451 if (maxexolead1 == 0 && arg1->maxEndoLead() == 0
5452 && maxexolead2 >= 1 && op_code == BinaryOpcode::times)
5453 {
5454 arg2subst = arg2->substituteExoLead(subst_table, neweqs, deterministic_model);
5455 return buildSimilarBinaryOpNode(arg1, arg2subst, datatree);
5456 }
5457 return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
5458 default:
5459 return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
5460 }
5461 }
5462
5463 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5464 BinaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5465 {
5466 expr_t arg1subst = arg1->substituteExoLag(subst_table, neweqs);
5467 expr_t arg2subst = arg2->substituteExoLag(subst_table, neweqs);
5468 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5469 }
5470
5471 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const5472 BinaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
5473 {
5474 expr_t arg1subst = arg1->substituteExpectation(subst_table, neweqs, partial_information_model);
5475 expr_t arg2subst = arg2->substituteExpectation(subst_table, neweqs, partial_information_model);
5476 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5477 }
5478
5479 expr_t
substituteAdl() const5480 BinaryOpNode::substituteAdl() const
5481 {
5482 expr_t arg1subst = arg1->substituteAdl();
5483 expr_t arg2subst = arg2->substituteAdl();
5484 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5485 }
5486
5487 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const5488 BinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
5489 {
5490 expr_t arg1subst = arg1->substituteVarExpectation(subst_table);
5491 expr_t arg2subst = arg2->substituteVarExpectation(subst_table);
5492 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5493 }
5494
5495 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const5496 BinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
5497 {
5498 arg1->findUnaryOpNodesForAuxVarCreation(nodes);
5499 arg2->findUnaryOpNodesForAuxVarCreation(nodes);
5500 }
5501
5502 void
findDiffNodes(lag_equivalence_table_t & nodes) const5503 BinaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
5504 {
5505 arg1->findDiffNodes(nodes);
5506 arg2->findDiffNodes(nodes);
5507 }
5508
5509 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5510 BinaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
5511 vector<BinaryOpNode *> &neweqs) const
5512 {
5513 expr_t arg1subst = arg1->substituteDiff(nodes, subst_table, neweqs);
5514 expr_t arg2subst = arg2->substituteDiff(nodes, subst_table, neweqs);
5515 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5516 }
5517
5518 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5519 BinaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5520 {
5521 expr_t arg1subst = arg1->substituteUnaryOpNodes(nodes, subst_table, neweqs);
5522 expr_t arg2subst = arg2->substituteUnaryOpNodes(nodes, subst_table, neweqs);
5523 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5524 }
5525
5526 int
countDiffs() const5527 BinaryOpNode::countDiffs() const
5528 {
5529 return max(arg1->countDiffs(), arg2->countDiffs());
5530 }
5531
5532 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)5533 BinaryOpNode::substitutePacExpectation(const string &name, expr_t subexpr)
5534 {
5535 expr_t arg1subst = arg1->substitutePacExpectation(name, subexpr);
5536 expr_t arg2subst = arg2->substitutePacExpectation(name, subexpr);
5537 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5538 }
5539
5540 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const5541 BinaryOpNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
5542 {
5543 expr_t arg1subst = arg1->differentiateForwardVars(subset, subst_table, neweqs);
5544 expr_t arg2subst = arg2->differentiateForwardVars(subset, subst_table, neweqs);
5545 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5546 }
5547
5548 expr_t
addMultipliersToConstraints(int i)5549 BinaryOpNode::addMultipliersToConstraints(int i)
5550 {
5551 int symb_id = datatree.symbol_table.addMultiplierAuxiliaryVar(i);
5552 expr_t newAuxLM = datatree.AddVariable(symb_id, 0);
5553 return datatree.AddEqual(datatree.AddTimes(newAuxLM, datatree.AddMinus(arg1, arg2)), datatree.Zero);
5554 }
5555
5556 bool
isNumConstNodeEqualTo(double value) const5557 BinaryOpNode::isNumConstNodeEqualTo(double value) const
5558 {
5559 return false;
5560 }
5561
5562 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const5563 BinaryOpNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
5564 {
5565 return false;
5566 }
5567
5568 bool
containsPacExpectation(const string & pac_model_name) const5569 BinaryOpNode::containsPacExpectation(const string &pac_model_name) const
5570 {
5571 return arg1->containsPacExpectation(pac_model_name) || arg2->containsPacExpectation(pac_model_name);
5572 }
5573
5574 bool
containsEndogenous() const5575 BinaryOpNode::containsEndogenous() const
5576 {
5577 return arg1->containsEndogenous() || arg2->containsEndogenous();
5578 }
5579
5580 bool
containsExogenous() const5581 BinaryOpNode::containsExogenous() const
5582 {
5583 return arg1->containsExogenous() || arg2->containsExogenous();
5584 }
5585
5586 expr_t
replaceTrendVar() const5587 BinaryOpNode::replaceTrendVar() const
5588 {
5589 expr_t arg1subst = arg1->replaceTrendVar();
5590 expr_t arg2subst = arg2->replaceTrendVar();
5591 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5592 }
5593
5594 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const5595 BinaryOpNode::detrend(int symb_id, bool log_trend, expr_t trend) const
5596 {
5597 expr_t arg1subst = arg1->detrend(symb_id, log_trend, trend);
5598 expr_t arg2subst = arg2->detrend(symb_id, log_trend, trend);
5599 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5600 }
5601
5602 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const5603 BinaryOpNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
5604 {
5605 expr_t arg1subst = arg1->removeTrendLeadLag(trend_symbols_map);
5606 expr_t arg2subst = arg2->removeTrendLeadLag(trend_symbols_map);
5607 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
5608 }
5609
5610 bool
isInStaticForm() const5611 BinaryOpNode::isInStaticForm() const
5612 {
5613 return arg1->isInStaticForm() && arg2->isInStaticForm();
5614 }
5615
5616 bool
findTargetVariableHelper1(int lhs_symb_id,int rhs_symb_id) const5617 BinaryOpNode::findTargetVariableHelper1(int lhs_symb_id, int rhs_symb_id) const
5618 {
5619 if (lhs_symb_id == rhs_symb_id)
5620 return true;
5621
5622 try
5623 {
5624 if (datatree.symbol_table.isAuxiliaryVariable(rhs_symb_id)
5625 && lhs_symb_id == datatree.symbol_table.getOrigSymbIdForAuxVar(rhs_symb_id))
5626 return true;
5627 }
5628 catch (...)
5629 {
5630 }
5631 return false;
5632 }
5633
5634 int
findTargetVariableHelper(const expr_t arg1,const expr_t arg2,int lhs_symb_id) const5635 BinaryOpNode::findTargetVariableHelper(const expr_t arg1, const expr_t arg2,
5636 int lhs_symb_id) const
5637 {
5638 set<int> params;
5639 arg1->collectVariables(SymbolType::parameter, params);
5640 if (params.size() != 1)
5641 return -1;
5642
5643 set<pair<int, int>> endogs;
5644 arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
5645 if (auto testarg2 = dynamic_cast<BinaryOpNode *>(arg2);
5646 endogs.size() == 2 && testarg2 && testarg2->op_code == BinaryOpcode::minus
5647 && dynamic_cast<VariableNode *>(testarg2->arg1)
5648 && dynamic_cast<VariableNode *>(testarg2->arg2))
5649 {
5650 if (findTargetVariableHelper1(lhs_symb_id, endogs.begin()->first))
5651 return endogs.rbegin()->first;
5652 else if (findTargetVariableHelper1(lhs_symb_id, endogs.rbegin()->first))
5653 return endogs.begin()->first;
5654 }
5655 return -1;
5656 }
5657
5658 int
findTargetVariable(int lhs_symb_id) const5659 BinaryOpNode::findTargetVariable(int lhs_symb_id) const
5660 {
5661 int retval = findTargetVariableHelper(arg1, arg2, lhs_symb_id);
5662 if (retval < 0)
5663 retval = findTargetVariableHelper(arg2, arg1, lhs_symb_id);
5664 if (retval < 0)
5665 retval = arg1->findTargetVariable(lhs_symb_id);
5666 if (retval < 0)
5667 retval = arg2->findTargetVariable(lhs_symb_id);
5668 return retval;
5669 }
5670
5671 pair<int, vector<tuple<int, bool, int>>>
getPacEC(BinaryOpNode * bopn,int lhs_symb_id,int lhs_orig_symb_id) const5672 BinaryOpNode::getPacEC(BinaryOpNode *bopn, int lhs_symb_id, int lhs_orig_symb_id) const
5673 {
5674 pair<int, vector<tuple<int, bool, int>>> ec_params_and_vars = {-1, vector<tuple<int, bool, int>>()};
5675 int optim_param_symb_id = -1;
5676 expr_t optim_part = nullptr;
5677 set<pair<int, int>> endogs;
5678 bopn->collectDynamicVariables(SymbolType::endogenous, endogs);
5679 int target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, lhs_orig_symb_id, endogs);
5680 if (target_symb_id >= 0 && bopn->isParamTimesEndogExpr())
5681 {
5682 optim_part = bopn->arg2;
5683 auto vn = dynamic_cast<VariableNode *>(bopn->arg1);
5684 if (!vn || datatree.symbol_table.getType(vn->symb_id) != SymbolType::parameter)
5685 {
5686 optim_part = bopn->arg1;
5687 vn = dynamic_cast<VariableNode *>(bopn->arg2);
5688 }
5689 if (!vn || datatree.symbol_table.getType(vn->symb_id) != SymbolType::parameter)
5690 return ec_params_and_vars;
5691 optim_param_symb_id = vn->symb_id;
5692 }
5693 if (optim_param_symb_id >= 0)
5694 {
5695 endogs.clear();
5696 optim_part->collectDynamicVariables(SymbolType::endogenous, endogs);
5697 optim_part->collectDynamicVariables(SymbolType::exogenous, endogs);
5698 if (endogs.size() != 2)
5699 {
5700 cerr << "Error getting EC part of PAC equation" << endl;
5701 exit(EXIT_FAILURE);
5702 }
5703 vector<pair<expr_t, int>> terms;
5704 vector<tuple<int, bool, int>> ordered_symb_ids;
5705 optim_part->decomposeAdditiveTerms(terms, 1);
5706 for (const auto &it : terms)
5707 {
5708 int scale = it.second;
5709 auto vn = dynamic_cast<VariableNode *>(it.first);
5710 if (!vn
5711 || !(datatree.symbol_table.getType(vn->symb_id) == SymbolType::endogenous
5712 || datatree.symbol_table.getType(vn->symb_id) == SymbolType::exogenous))
5713 {
5714 cerr << "Problem with error component portion of PAC equation" << endl;
5715 exit(EXIT_FAILURE);
5716 }
5717 int id = vn->symb_id;
5718 int orig_id = datatree.symbol_table.getUltimateOrigSymbID(id);
5719 bool istarget = true;
5720 if (orig_id == lhs_symb_id || orig_id == lhs_orig_symb_id)
5721 istarget = false;
5722 ordered_symb_ids.emplace_back(id, istarget, scale);
5723 }
5724 ec_params_and_vars = { optim_param_symb_id, ordered_symb_ids };
5725 }
5726 return ec_params_and_vars;
5727 }
5728
5729 void
getPacAREC(int lhs_symb_id,int lhs_orig_symb_id,pair<int,vector<tuple<int,bool,int>>> & ec_params_and_vars,set<pair<int,pair<int,int>>> & ar_params_and_vars,vector<tuple<int,int,int,double>> & additive_vars_params_and_constants) const5730 BinaryOpNode::getPacAREC(int lhs_symb_id, int lhs_orig_symb_id,
5731 pair<int, vector<tuple<int, bool, int>>> &ec_params_and_vars,
5732 set<pair<int, pair<int, int>>> &ar_params_and_vars,
5733 vector<tuple<int, int, int, double>> &additive_vars_params_and_constants) const
5734 {
5735 vector<pair<expr_t, int>> terms;
5736 decomposeAdditiveTerms(terms, 1);
5737 for (auto it = terms.begin(); it != terms.end(); ++it)
5738 if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn)
5739 {
5740 ec_params_and_vars = getPacEC(bopn, lhs_symb_id, lhs_orig_symb_id);
5741 if (ec_params_and_vars.first >= 0)
5742 {
5743 terms.erase(it);
5744 break;
5745 }
5746 }
5747
5748 if (ec_params_and_vars.first < 0)
5749 {
5750 cerr << "Error finding EC part of PAC equation" << endl;
5751 exit(EXIT_FAILURE);
5752 }
5753
5754 for (const auto &it : terms)
5755 {
5756 if (dynamic_cast<PacExpectationNode *>(it.first))
5757 continue;
5758
5759 pair<int, vector<tuple<int, int, int, double>>> m;
5760 try
5761 {
5762 m = {-1, {it.first->matchVariableTimesConstantTimesParam()}};
5763 }
5764 catch (MatchFailureException &e)
5765 {
5766 try
5767 {
5768 m = it.first->matchParamTimesLinearCombinationOfVariables();
5769 }
5770 catch (MatchFailureException &e)
5771 {
5772 cerr << "Unsupported expression in PAC equation" << endl;
5773 exit(EXIT_FAILURE);
5774 }
5775 }
5776
5777 for (auto &t : m.second)
5778 get<3>(t) *= it.second; // Update sign of constants
5779
5780 int pid = get<0>(m);
5781 for (auto [vid, lag, pidtmp, constant] : m.second)
5782 {
5783 if (pid == -1)
5784 pid = pidtmp;
5785 else if (pidtmp >= 0)
5786 {
5787 cerr << "unexpected parameter found in PAC equation" << endl;
5788 exit(EXIT_FAILURE);
5789 }
5790
5791 if (int vidorig = datatree.symbol_table.getUltimateOrigSymbID(vid);
5792 vidorig == lhs_symb_id || vidorig == lhs_orig_symb_id)
5793 {
5794 // This is an autoregressive term
5795 if (constant != 1 || pid == -1)
5796 {
5797 cerr << "BinaryOpNode::getPacAREC: autoregressive terms must be of the form 'parameter*lagged_variable" << endl;
5798 exit(EXIT_FAILURE);
5799 }
5800 ar_params_and_vars.insert({pid, { vid, lag }});
5801 }
5802 else
5803 // This is a residual additive term
5804 additive_vars_params_and_constants.emplace_back(vid, lag, pid, constant);
5805 }
5806 }
5807 }
5808
5809 bool
isParamTimesEndogExpr() const5810 BinaryOpNode::isParamTimesEndogExpr() const
5811 {
5812 if (op_code == BinaryOpcode::times)
5813 {
5814 set<int> params;
5815 auto test_arg1 = dynamic_cast<VariableNode *>(arg1);
5816 auto test_arg2 = dynamic_cast<VariableNode *>(arg2);
5817 if (test_arg1)
5818 arg1->collectVariables(SymbolType::parameter, params);
5819 else if (test_arg2)
5820 arg2->collectVariables(SymbolType::parameter, params);
5821 else
5822 return false;
5823
5824 if (params.size() != 1)
5825 return false;
5826
5827 params.clear();
5828 set<pair<int, int>> endogs, exogs;
5829 if (test_arg1)
5830 {
5831 arg2->collectDynamicVariables(SymbolType::endogenous, endogs);
5832 arg2->collectDynamicVariables(SymbolType::exogenous, exogs);
5833 arg2->collectVariables(SymbolType::parameter, params);
5834 if (params.size() == 0 && exogs.size() == 0 && endogs.size() >= 1)
5835 return true;
5836 }
5837 else
5838 {
5839 arg1->collectDynamicVariables(SymbolType::endogenous, endogs);
5840 arg1->collectDynamicVariables(SymbolType::exogenous, exogs);
5841 arg1->collectVariables(SymbolType::parameter, params);
5842 if (params.size() == 0 && exogs.size() == 0 && endogs.size() >= 1)
5843 return true;
5844 }
5845 }
5846 else if (op_code == BinaryOpcode::plus)
5847 return arg1->isParamTimesEndogExpr() || arg2->isParamTimesEndogExpr();
5848 return false;
5849 }
5850
5851 bool
getPacNonOptimizingPartHelper(BinaryOpNode * bopn,int optim_share) const5852 BinaryOpNode::getPacNonOptimizingPartHelper(BinaryOpNode *bopn, int optim_share) const
5853 {
5854 set<int> params;
5855 bopn->collectVariables(SymbolType::parameter, params);
5856 if (params.size() == 1 && *params.begin() == optim_share)
5857 return true;
5858 return false;
5859 }
5860
5861 expr_t
getPacNonOptimizingPart(BinaryOpNode * bopn,int optim_share) const5862 BinaryOpNode::getPacNonOptimizingPart(BinaryOpNode *bopn, int optim_share) const
5863 {
5864 auto a1 = dynamic_cast<BinaryOpNode *>(bopn->arg1);
5865 auto a2 = dynamic_cast<BinaryOpNode *>(bopn->arg2);
5866 if (!a1 && !a2)
5867 return nullptr;
5868
5869 if (a1 && getPacNonOptimizingPartHelper(a1, optim_share))
5870 return bopn->arg2;
5871
5872 if (a2 && getPacNonOptimizingPartHelper(a2, optim_share))
5873 return bopn->arg1;
5874
5875 return nullptr;
5876 }
5877
5878 pair<int, expr_t>
getPacOptimizingShareAndExprNodesHelper(BinaryOpNode * bopn,int lhs_symb_id,int lhs_orig_symb_id) const5879 BinaryOpNode::getPacOptimizingShareAndExprNodesHelper(BinaryOpNode *bopn, int lhs_symb_id, int lhs_orig_symb_id) const
5880 {
5881 int optim_param_symb_id = -1;
5882 expr_t optim_part = nullptr;
5883 set<pair<int, int>> endogs;
5884 bopn->collectDynamicVariables(SymbolType::endogenous, endogs);
5885 int target_symb_id = getPacTargetSymbIdHelper(lhs_symb_id, lhs_orig_symb_id, endogs);
5886 if (target_symb_id >= 0)
5887 {
5888 set<int> params;
5889 if (bopn->arg1->isParamTimesEndogExpr() && !bopn->arg2->isParamTimesEndogExpr())
5890 {
5891 optim_part = bopn->arg1;
5892 bopn->arg2->collectVariables(SymbolType::parameter, params);
5893 optim_param_symb_id = *(params.begin());
5894 }
5895 else if (bopn->arg2->isParamTimesEndogExpr() && !bopn->arg1->isParamTimesEndogExpr())
5896 {
5897 optim_part = bopn->arg2;
5898 bopn->arg1->collectVariables(SymbolType::parameter, params);
5899 optim_param_symb_id = *(params.begin());
5900 }
5901 }
5902 return {optim_param_symb_id, optim_part};
5903 }
5904
5905 tuple<int, expr_t, expr_t, expr_t>
getPacOptimizingShareAndExprNodes(int lhs_symb_id,int lhs_orig_symb_id) const5906 BinaryOpNode::getPacOptimizingShareAndExprNodes(int lhs_symb_id, int lhs_orig_symb_id) const
5907 {
5908 vector<pair<expr_t, int>> terms;
5909 decomposeAdditiveTerms(terms, 1);
5910 for (auto &it : terms)
5911 if (dynamic_cast<PacExpectationNode *>(it.first))
5912 // if the pac_expectation operator is additive in the expression
5913 // there are no optimizing shares
5914 return {-1, nullptr, nullptr, nullptr};
5915
5916 int optim_share;
5917 expr_t optim_part, non_optim_part, additive_part;
5918 optim_part = non_optim_part = additive_part = nullptr;
5919
5920 for (auto it = terms.begin(); it != terms.end(); ++it)
5921 if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn)
5922 {
5923 tie(optim_share, optim_part)
5924 = getPacOptimizingShareAndExprNodesHelper(bopn, lhs_symb_id, lhs_orig_symb_id);
5925 if (optim_share >= 0 && optim_part)
5926 {
5927 terms.erase(it);
5928 break;
5929 }
5930 }
5931
5932 if (!optim_part)
5933 return {-1, nullptr, nullptr, nullptr};
5934
5935 for (auto it = terms.begin(); it != terms.end(); ++it)
5936 if (auto bopn = dynamic_cast<BinaryOpNode *>(it->first); bopn)
5937 {
5938 non_optim_part = getPacNonOptimizingPart(bopn, optim_share);
5939 if (non_optim_part)
5940 {
5941 terms.erase(it);
5942 break;
5943 }
5944 }
5945
5946 if (!non_optim_part)
5947 return {-1, nullptr, nullptr, nullptr};
5948 else
5949 {
5950 additive_part = datatree.Zero;
5951 for (auto it : terms)
5952 additive_part = datatree.AddPlus(additive_part, it.first);
5953 if (additive_part == datatree.Zero)
5954 additive_part = nullptr;
5955 }
5956
5957 return {optim_share, optim_part, non_optim_part, additive_part};
5958 }
5959
5960 void
fillAutoregressiveRow(int eqn,const vector<int> & lhs,map<tuple<int,int,int>,expr_t> & AR) const5961 BinaryOpNode::fillAutoregressiveRow(int eqn, const vector<int> &lhs, map<tuple<int, int, int>, expr_t> &AR) const
5962 {
5963 vector<pair<expr_t, int>> terms;
5964 decomposeAdditiveTerms(terms, 1);
5965 for (const auto &it : terms)
5966 {
5967 int vid, lag, param_id;
5968 double constant;
5969 try
5970 {
5971 tie(vid, lag, param_id, constant) = it.first->matchVariableTimesConstantTimesParam();
5972 constant *= it.second;
5973 }
5974 catch (MatchFailureException &e)
5975 {
5976 continue;
5977 }
5978
5979 if (datatree.symbol_table.isDiffAuxiliaryVariable(vid))
5980 {
5981 lag = -datatree.symbol_table.getOrigLeadLagForDiffAuxVar(vid);
5982 vid = datatree.symbol_table.getOrigSymbIdForDiffAuxVar(vid);
5983 }
5984
5985 if (find(lhs.begin(), lhs.end(), vid) == lhs.end())
5986 continue;
5987
5988 if (AR.find({eqn, -lag, vid}) != AR.end())
5989 {
5990 cerr << "BinaryOpNode::fillAutoregressiveRow: Error filling AR matrix: "
5991 << "lag/symb_id encountered more than once in equation" << endl;
5992 exit(EXIT_FAILURE);
5993 }
5994 if (constant != 1 || param_id == -1)
5995 {
5996 cerr << "BinaryOpNode::fillAutoregressiveRow: autoregressive terms must be of the form 'parameter*lagged_variable" << endl;
5997 exit(EXIT_FAILURE);
5998 }
5999 AR[{eqn, -lag, vid}] = datatree.AddVariable(param_id);
6000 }
6001 }
6002
6003 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const6004 BinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
6005 {
6006 if (op_code == BinaryOpcode::equal)
6007 {
6008 if (dynamic_cast<VariableNode *>(arg1) && dynamic_cast<NumConstNode *>(arg2))
6009 table[dynamic_cast<VariableNode *>(arg1)] = dynamic_cast<NumConstNode *>(arg2);
6010 else if (dynamic_cast<VariableNode *>(arg2) && dynamic_cast<NumConstNode *>(arg1))
6011 table[dynamic_cast<VariableNode *>(arg2)] = dynamic_cast<NumConstNode *>(arg1);
6012 }
6013 else
6014 {
6015 arg1->findConstantEquations(table);
6016 arg2->findConstantEquations(table);
6017 }
6018 }
6019
6020 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const6021 BinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
6022 {
6023 if (op_code == BinaryOpcode::equal)
6024 for (auto &it : table)
6025 if ((it.first == arg1 && it.second == arg2) || (it.first == arg2 && it.second == arg1))
6026 return const_cast<BinaryOpNode *>(this);
6027 expr_t arg1subst = arg1->replaceVarsInEquation(table);
6028 expr_t arg2subst = arg2->replaceVarsInEquation(table);
6029 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
6030 }
6031
6032 bool
isVarModelReferenced(const string & model_info_name) const6033 BinaryOpNode::isVarModelReferenced(const string &model_info_name) const
6034 {
6035 return arg1->isVarModelReferenced(model_info_name)
6036 || arg2->isVarModelReferenced(model_info_name);
6037 }
6038
6039 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const6040 BinaryOpNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
6041 {
6042 arg1->getEndosAndMaxLags(model_endos_and_lags);
6043 arg2->getEndosAndMaxLags(model_endos_and_lags);
6044 }
6045
6046 expr_t
substituteStaticAuxiliaryVariable() const6047 BinaryOpNode::substituteStaticAuxiliaryVariable() const
6048 {
6049 expr_t arg1subst = arg1->substituteStaticAuxiliaryVariable();
6050 expr_t arg2subst = arg2->substituteStaticAuxiliaryVariable();
6051 return buildSimilarBinaryOpNode(arg1subst, arg2subst, datatree);
6052 }
6053
6054 expr_t
substituteStaticAuxiliaryDefinition() const6055 BinaryOpNode::substituteStaticAuxiliaryDefinition() const
6056 {
6057 expr_t arg2subst = arg2->substituteStaticAuxiliaryVariable();
6058 return buildSimilarBinaryOpNode(arg1, arg2subst, datatree);
6059 }
6060
TrinaryOpNode(DataTree & datatree_arg,int idx_arg,const expr_t arg1_arg,TrinaryOpcode op_code_arg,const expr_t arg2_arg,const expr_t arg3_arg)6061 TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg,
6062 TrinaryOpcode op_code_arg, const expr_t arg2_arg, const expr_t arg3_arg) :
6063 ExprNode{datatree_arg, idx_arg},
6064 arg1{arg1_arg},
6065 arg2{arg2_arg},
6066 arg3{arg3_arg},
6067 op_code{op_code_arg}
6068 {
6069 }
6070
6071 void
prepareForDerivation()6072 TrinaryOpNode::prepareForDerivation()
6073 {
6074 if (preparedForDerivation)
6075 return;
6076
6077 preparedForDerivation = true;
6078
6079 arg1->prepareForDerivation();
6080 arg2->prepareForDerivation();
6081 arg3->prepareForDerivation();
6082
6083 // Non-null derivatives are the union of those of the arguments
6084 // Compute set union of arg{1,2,3}->non_null_derivatives
6085 set<int> non_null_derivatives_tmp;
6086 set_union(arg1->non_null_derivatives.begin(),
6087 arg1->non_null_derivatives.end(),
6088 arg2->non_null_derivatives.begin(),
6089 arg2->non_null_derivatives.end(),
6090 inserter(non_null_derivatives_tmp, non_null_derivatives_tmp.begin()));
6091 set_union(non_null_derivatives_tmp.begin(),
6092 non_null_derivatives_tmp.end(),
6093 arg3->non_null_derivatives.begin(),
6094 arg3->non_null_derivatives.end(),
6095 inserter(non_null_derivatives, non_null_derivatives.begin()));
6096 }
6097
6098 expr_t
composeDerivatives(expr_t darg1,expr_t darg2,expr_t darg3)6099 TrinaryOpNode::composeDerivatives(expr_t darg1, expr_t darg2, expr_t darg3)
6100 {
6101
6102 expr_t t11, t12, t13, t14, t15;
6103
6104 switch (op_code)
6105 {
6106 case TrinaryOpcode::normcdf:
6107 // normal pdf is inlined in the tree
6108 expr_t y;
6109 // sqrt(2*pi)
6110 t14 = datatree.AddSqrt(datatree.AddTimes(datatree.Two, datatree.Pi));
6111 // x - mu
6112 t12 = datatree.AddMinus(arg1, arg2);
6113 // y = (x-mu)/sigma
6114 y = datatree.AddDivide(t12, arg3);
6115 // (x-mu)^2/sigma^2
6116 t12 = datatree.AddTimes(y, y);
6117 // -(x-mu)^2/sigma^2
6118 t13 = datatree.AddUMinus(t12);
6119 // -((x-mu)^2/sigma^2)/2
6120 t12 = datatree.AddDivide(t13, datatree.Two);
6121 // exp(-((x-mu)^2/sigma^2)/2)
6122 t13 = datatree.AddExp(t12);
6123 // derivative of a standardized normal
6124 // t15 = (1/sqrt(2*pi))*exp(-y^2/2)
6125 t15 = datatree.AddDivide(t13, t14);
6126 // derivatives thru x
6127 t11 = datatree.AddDivide(darg1, arg3);
6128 // derivatives thru mu
6129 t12 = datatree.AddDivide(darg2, arg3);
6130 // intermediary sum
6131 t14 = datatree.AddMinus(t11, t12);
6132 // derivatives thru sigma
6133 t11 = datatree.AddDivide(y, arg3);
6134 t12 = datatree.AddTimes(t11, darg3);
6135 //intermediary sum
6136 t11 = datatree.AddMinus(t14, t12);
6137 // total derivative:
6138 // (darg1/sigma - darg2/sigma - darg3*(x-mu)/sigma^2) * t15
6139 // where t15 is the derivative of a standardized normal
6140 return datatree.AddTimes(t11, t15);
6141 case TrinaryOpcode::normpdf:
6142 // (x - mu)
6143 t11 = datatree.AddMinus(arg1, arg2);
6144 // (x - mu)/sigma
6145 t12 = datatree.AddDivide(t11, arg3);
6146 // darg3 * (x - mu)/sigma
6147 t11 = datatree.AddTimes(darg3, t12);
6148 // darg2 - darg1
6149 t13 = datatree.AddMinus(darg2, darg1);
6150 // darg2 - darg1 + darg3 * (x - mu)/sigma
6151 t14 = datatree.AddPlus(t13, t11);
6152 // ((x - mu)/sigma) * (darg2 - darg1 + darg3 * (x - mu)/sigma)
6153 t11 = datatree.AddTimes(t12, t14);
6154 // ((x - mu)/sigma) * (darg2 - darg1 + darg3 * (x - mu)/sigma) - darg3
6155 t12 = datatree.AddMinus(t11, darg3);
6156 // this / sigma
6157 t11 = datatree.AddDivide(this, arg3);
6158 // total derivative:
6159 // (this / sigma) * (((x - mu)/sigma) * (darg2 - darg1 + darg3 * (x - mu)/sigma) - darg3)
6160 return datatree.AddTimes(t11, t12);
6161 }
6162 // Suppress GCC warning
6163 exit(EXIT_FAILURE);
6164 }
6165
6166 expr_t
computeDerivative(int deriv_id)6167 TrinaryOpNode::computeDerivative(int deriv_id)
6168 {
6169 expr_t darg1 = arg1->getDerivative(deriv_id);
6170 expr_t darg2 = arg2->getDerivative(deriv_id);
6171 expr_t darg3 = arg3->getDerivative(deriv_id);
6172 return composeDerivatives(darg1, darg2, darg3);
6173 }
6174
6175 int
precedence(ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms) const6176 TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const
6177 {
6178 // A temporary term behaves as a variable
6179 if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6180 return 100;
6181
6182 switch (op_code)
6183 {
6184 case TrinaryOpcode::normcdf:
6185 case TrinaryOpcode::normpdf:
6186 return 100;
6187 }
6188 // Suppress GCC warning
6189 exit(EXIT_FAILURE);
6190 }
6191
6192 int
cost(const map<pair<int,int>,temporary_terms_t> & temp_terms_map,bool is_matlab) const6193 TrinaryOpNode::cost(const map<pair<int, int>, temporary_terms_t> &temp_terms_map, bool is_matlab) const
6194 {
6195 // For a temporary term, the cost is null
6196 for (const auto &it : temp_terms_map)
6197 if (it.second.find(const_cast<TrinaryOpNode *>(this)) != it.second.end())
6198 return 0;
6199
6200 int arg_cost = arg1->cost(temp_terms_map, is_matlab)
6201 + arg2->cost(temp_terms_map, is_matlab)
6202 + arg3->cost(temp_terms_map, is_matlab);
6203
6204 return cost(arg_cost, is_matlab);
6205 }
6206
6207 int
cost(const temporary_terms_t & temporary_terms,bool is_matlab) const6208 TrinaryOpNode::cost(const temporary_terms_t &temporary_terms, bool is_matlab) const
6209 {
6210 // For a temporary term, the cost is null
6211 if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6212 return 0;
6213
6214 int arg_cost = arg1->cost(temporary_terms, is_matlab)
6215 + arg2->cost(temporary_terms, is_matlab)
6216 + arg3->cost(temporary_terms, is_matlab);
6217
6218 return cost(arg_cost, is_matlab);
6219 }
6220
6221 int
cost(int cost,bool is_matlab) const6222 TrinaryOpNode::cost(int cost, bool is_matlab) const
6223 {
6224 if (is_matlab)
6225 // Cost for Matlab files
6226 switch (op_code)
6227 {
6228 case TrinaryOpcode::normcdf:
6229 case TrinaryOpcode::normpdf:
6230 return cost+1000;
6231 }
6232 else
6233 // Cost for C files
6234 switch (op_code)
6235 {
6236 case TrinaryOpcode::normcdf:
6237 case TrinaryOpcode::normpdf:
6238 return cost+1000;
6239 }
6240 // Suppress GCC warning
6241 exit(EXIT_FAILURE);
6242 }
6243
6244 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const6245 TrinaryOpNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
6246 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
6247 map<expr_t, pair<int, pair<int, int>>> &reference_count,
6248 bool is_matlab) const
6249 {
6250 expr_t this2 = const_cast<TrinaryOpNode *>(this);
6251 if (auto it = reference_count.find(this2);
6252 it == reference_count.end())
6253 {
6254 // If this node has never been encountered, set its ref count to one,
6255 // and travel through its children
6256 reference_count[this2] = { 1, derivOrder };
6257 arg1->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
6258 arg2->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
6259 arg3->computeTemporaryTerms(derivOrder, temp_terms_map, reference_count, is_matlab);
6260 }
6261 else
6262 {
6263 // If the node has already been encountered, increment its ref count
6264 // and declare it as a temporary term if it is too costly
6265 reference_count[this2] = { it->second.first + 1, it->second.second };
6266 if (reference_count[this2].first * cost(temp_terms_map, is_matlab) > min_cost(is_matlab))
6267 temp_terms_map[reference_count[this2].second].insert(this2);
6268 }
6269 }
6270
6271 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const6272 TrinaryOpNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
6273 temporary_terms_t &temporary_terms,
6274 map<expr_t, pair<int, int>> &first_occurence,
6275 int Curr_block,
6276 vector<vector<temporary_terms_t>> &v_temporary_terms,
6277 int equation) const
6278 {
6279 expr_t this2 = const_cast<TrinaryOpNode *>(this);
6280 if (auto it = reference_count.find(this2);
6281 it == reference_count.end())
6282 {
6283 reference_count[this2] = 1;
6284 first_occurence[this2] = { Curr_block, equation };
6285 arg1->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
6286 arg2->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
6287 arg3->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, v_temporary_terms, equation);
6288 }
6289 else
6290 {
6291 reference_count[this2]++;
6292 if (reference_count[this2] * cost(temporary_terms, false) > min_cost_c)
6293 {
6294 temporary_terms.insert(this2);
6295 v_temporary_terms[first_occurence[this2].first][first_occurence[this2].second].insert(this2);
6296 }
6297 }
6298 }
6299
6300 double
eval_opcode(double v1,TrinaryOpcode op_code,double v2,double v3)6301 TrinaryOpNode::eval_opcode(double v1, TrinaryOpcode op_code, double v2, double v3) noexcept(false)
6302 {
6303 switch (op_code)
6304 {
6305 case TrinaryOpcode::normcdf:
6306 return (0.5*(1+erf((v1-v2)/v3/M_SQRT2)));
6307 case TrinaryOpcode::normpdf:
6308 return (1/(v3*sqrt(2*M_PI)*exp(pow((v1-v2)/v3, 2)/2)));
6309 }
6310 // Suppress GCC warning
6311 exit(EXIT_FAILURE);
6312 }
6313
6314 double
eval(const eval_context_t & eval_context) const6315 TrinaryOpNode::eval(const eval_context_t &eval_context) const noexcept(false)
6316 {
6317 double v1 = arg1->eval(eval_context);
6318 double v2 = arg2->eval(eval_context);
6319 double v3 = arg3->eval(eval_context);
6320
6321 return eval_opcode(v1, op_code, v2, v3);
6322 }
6323
6324 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const6325 TrinaryOpNode::compile(ostream &CompileCode, unsigned int &instruction_number,
6326 bool lhs_rhs, const temporary_terms_t &temporary_terms,
6327 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
6328 const deriv_node_temp_terms_t &tef_terms) const
6329 {
6330 // If current node is a temporary term
6331 if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6332 {
6333 if (dynamic)
6334 {
6335 auto ii = map_idx.find(idx);
6336 FLDT_ fldt(ii->second);
6337 fldt.write(CompileCode, instruction_number);
6338 }
6339 else
6340 {
6341 auto ii = map_idx.find(idx);
6342 FLDST_ fldst(ii->second);
6343 fldst.write(CompileCode, instruction_number);
6344 }
6345 return;
6346 }
6347 arg1->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
6348 arg2->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
6349 arg3->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx, dynamic, steady_dynamic, tef_terms);
6350 FTRINARY_ ftrinary{static_cast<int>(op_code)};
6351 ftrinary.write(CompileCode, instruction_number);
6352 }
6353
6354 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const6355 TrinaryOpNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
6356 {
6357 if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6358 temporary_terms_inuse.insert(idx);
6359 else
6360 {
6361 arg1->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
6362 arg2->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
6363 arg3->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
6364 }
6365 }
6366
6367 bool
containsExternalFunction() const6368 TrinaryOpNode::containsExternalFunction() const
6369 {
6370 return arg1->containsExternalFunction()
6371 || arg2->containsExternalFunction()
6372 || arg3->containsExternalFunction();
6373 }
6374
6375 void
writeJsonAST(ostream & output) const6376 TrinaryOpNode::writeJsonAST(ostream &output) const
6377 {
6378 output << R"({"node_type" : "TrinaryOpNode", )"
6379 << R"("op" : ")";
6380 switch (op_code)
6381 {
6382 case TrinaryOpcode::normcdf:
6383 output << "normcdf";
6384 break;
6385 case TrinaryOpcode::normpdf:
6386 output << "normpdf";
6387 break;
6388 }
6389 output << R"(", "arg1" : )";
6390 arg1->writeJsonAST(output);
6391 output << R"(, "arg2" : )";
6392 arg2->writeJsonAST(output);
6393 output << R"(, "arg2" : )";
6394 arg3->writeJsonAST(output);
6395 output << "}";
6396 }
6397
6398 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const6399 TrinaryOpNode::writeJsonOutput(ostream &output,
6400 const temporary_terms_t &temporary_terms,
6401 const deriv_node_temp_terms_t &tef_terms,
6402 bool isdynamic) const
6403 {
6404 // If current node is a temporary term
6405 if (temporary_terms.find(const_cast<TrinaryOpNode *>(this)) != temporary_terms.end())
6406 {
6407 output << "T" << idx;
6408 return;
6409 }
6410
6411 switch (op_code)
6412 {
6413 case TrinaryOpcode::normcdf:
6414 output << "normcdf(";
6415 break;
6416 case TrinaryOpcode::normpdf:
6417 output << "normpdf(";
6418 break;
6419 }
6420
6421 arg1->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
6422 output << ",";
6423 arg2->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
6424 output << ",";
6425 arg3->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
6426 output << ")";
6427 }
6428
6429 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const6430 TrinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
6431 const temporary_terms_t &temporary_terms,
6432 const temporary_terms_idxs_t &temporary_terms_idxs,
6433 const deriv_node_temp_terms_t &tef_terms) const
6434 {
6435 if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
6436 return;
6437
6438 switch (op_code)
6439 {
6440 case TrinaryOpcode::normcdf:
6441 if (isCOutput(output_type))
6442 {
6443 // In C, there is no normcdf() primitive, so use erf()
6444 output << "(0.5*(1+erf(((";
6445 arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6446 output << ")-(";
6447 arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6448 output << "))/(";
6449 arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6450 output << ")/M_SQRT2)))";
6451 }
6452 else
6453 {
6454 output << "normcdf(";
6455 arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6456 output << ",";
6457 arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6458 output << ",";
6459 arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6460 output << ")";
6461 }
6462 break;
6463 case TrinaryOpcode::normpdf:
6464 if (isCOutput(output_type))
6465 {
6466 //(1/(v3*sqrt(2*M_PI)*exp(pow((v1-v2)/v3,2)/2)))
6467 output << "(1/(";
6468 arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6469 output << "*sqrt(2*M_PI)*exp(pow((";
6470 arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6471 output << "-";
6472 arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6473 output << ")/";
6474 arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6475 output << ",2)/2)))";
6476 }
6477 else
6478 {
6479 output << "normpdf(";
6480 arg1->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6481 output << ",";
6482 arg2->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6483 output << ",";
6484 arg3->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6485 output << ")";
6486 }
6487 break;
6488 }
6489 }
6490
6491 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const6492 TrinaryOpNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
6493 const temporary_terms_t &temporary_terms,
6494 const temporary_terms_idxs_t &temporary_terms_idxs,
6495 deriv_node_temp_terms_t &tef_terms) const
6496 {
6497 arg1->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6498 arg2->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6499 arg3->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
6500 }
6501
6502 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const6503 TrinaryOpNode::writeJsonExternalFunctionOutput(vector<string> &efout,
6504 const temporary_terms_t &temporary_terms,
6505 deriv_node_temp_terms_t &tef_terms,
6506 bool isdynamic) const
6507 {
6508 arg1->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
6509 arg2->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
6510 arg3->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
6511 }
6512
6513 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const6514 TrinaryOpNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
6515 bool lhs_rhs, const temporary_terms_t &temporary_terms,
6516 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
6517 deriv_node_temp_terms_t &tef_terms) const
6518 {
6519 arg1->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
6520 dynamic, steady_dynamic, tef_terms);
6521 arg2->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
6522 dynamic, steady_dynamic, tef_terms);
6523 arg3->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
6524 dynamic, steady_dynamic, tef_terms);
6525 }
6526
6527 void
collectVARLHSVariable(set<expr_t> & result) const6528 TrinaryOpNode::collectVARLHSVariable(set<expr_t> &result) const
6529 {
6530 cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
6531 exit(EXIT_FAILURE);
6532 }
6533
6534 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const6535 TrinaryOpNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
6536 {
6537 arg1->collectDynamicVariables(type_arg, result);
6538 arg2->collectDynamicVariables(type_arg, result);
6539 arg3->collectDynamicVariables(type_arg, result);
6540 }
6541
6542 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const6543 TrinaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
6544 {
6545 pair<int, expr_t> res = arg1->normalizeEquation(var_endo, List_of_Op_RHS);
6546 bool is_endogenous_present_1 = res.first;
6547 expr_t expr_t_1 = res.second;
6548 res = arg2->normalizeEquation(var_endo, List_of_Op_RHS);
6549 bool is_endogenous_present_2 = res.first;
6550 expr_t expr_t_2 = res.second;
6551 res = arg3->normalizeEquation(var_endo, List_of_Op_RHS);
6552 bool is_endogenous_present_3 = res.first;
6553 expr_t expr_t_3 = res.second;
6554 if (!is_endogenous_present_1 && !is_endogenous_present_2 && !is_endogenous_present_3)
6555 return { 0, datatree.AddNormcdf(expr_t_1, expr_t_2, expr_t_3) };
6556 else
6557 return { 2, nullptr };
6558 }
6559
6560 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)6561 TrinaryOpNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
6562 {
6563 expr_t darg1 = arg1->getChainRuleDerivative(deriv_id, recursive_variables);
6564 expr_t darg2 = arg2->getChainRuleDerivative(deriv_id, recursive_variables);
6565 expr_t darg3 = arg3->getChainRuleDerivative(deriv_id, recursive_variables);
6566 return composeDerivatives(darg1, darg2, darg3);
6567 }
6568
6569 expr_t
buildSimilarTrinaryOpNode(expr_t alt_arg1,expr_t alt_arg2,expr_t alt_arg3,DataTree & alt_datatree) const6570 TrinaryOpNode::buildSimilarTrinaryOpNode(expr_t alt_arg1, expr_t alt_arg2, expr_t alt_arg3, DataTree &alt_datatree) const
6571 {
6572 switch (op_code)
6573 {
6574 case TrinaryOpcode::normcdf:
6575 return alt_datatree.AddNormcdf(alt_arg1, alt_arg2, alt_arg3);
6576 case TrinaryOpcode::normpdf:
6577 return alt_datatree.AddNormpdf(alt_arg1, alt_arg2, alt_arg3);
6578 }
6579 // Suppress GCC warning
6580 exit(EXIT_FAILURE);
6581 }
6582
6583 expr_t
toStatic(DataTree & static_datatree) const6584 TrinaryOpNode::toStatic(DataTree &static_datatree) const
6585 {
6586 expr_t sarg1 = arg1->toStatic(static_datatree);
6587 expr_t sarg2 = arg2->toStatic(static_datatree);
6588 expr_t sarg3 = arg3->toStatic(static_datatree);
6589 return buildSimilarTrinaryOpNode(sarg1, sarg2, sarg3, static_datatree);
6590 }
6591
6592 void
computeXrefs(EquationInfo & ei) const6593 TrinaryOpNode::computeXrefs(EquationInfo &ei) const
6594 {
6595 arg1->computeXrefs(ei);
6596 arg2->computeXrefs(ei);
6597 arg3->computeXrefs(ei);
6598 }
6599
6600 expr_t
clone(DataTree & datatree) const6601 TrinaryOpNode::clone(DataTree &datatree) const
6602 {
6603 expr_t substarg1 = arg1->clone(datatree);
6604 expr_t substarg2 = arg2->clone(datatree);
6605 expr_t substarg3 = arg3->clone(datatree);
6606 return buildSimilarTrinaryOpNode(substarg1, substarg2, substarg3, datatree);
6607 }
6608
6609 int
maxEndoLead() const6610 TrinaryOpNode::maxEndoLead() const
6611 {
6612 return max(arg1->maxEndoLead(), max(arg2->maxEndoLead(), arg3->maxEndoLead()));
6613 }
6614
6615 int
maxExoLead() const6616 TrinaryOpNode::maxExoLead() const
6617 {
6618 return max(arg1->maxExoLead(), max(arg2->maxExoLead(), arg3->maxExoLead()));
6619 }
6620
6621 int
maxEndoLag() const6622 TrinaryOpNode::maxEndoLag() const
6623 {
6624 return max(arg1->maxEndoLag(), max(arg2->maxEndoLag(), arg3->maxEndoLag()));
6625 }
6626
6627 int
maxExoLag() const6628 TrinaryOpNode::maxExoLag() const
6629 {
6630 return max(arg1->maxExoLag(), max(arg2->maxExoLag(), arg3->maxExoLag()));
6631 }
6632
6633 int
maxLead() const6634 TrinaryOpNode::maxLead() const
6635 {
6636 return max(arg1->maxLead(), max(arg2->maxLead(), arg3->maxLead()));
6637 }
6638
6639 int
maxLag() const6640 TrinaryOpNode::maxLag() const
6641 {
6642 return max(arg1->maxLag(), max(arg2->maxLag(), arg3->maxLag()));
6643 }
6644
6645 int
maxLagWithDiffsExpanded() const6646 TrinaryOpNode::maxLagWithDiffsExpanded() const
6647 {
6648 return max(arg1->maxLagWithDiffsExpanded(),
6649 max(arg2->maxLagWithDiffsExpanded(), arg3->maxLagWithDiffsExpanded()));
6650 }
6651
6652 expr_t
undiff() const6653 TrinaryOpNode::undiff() const
6654 {
6655 expr_t arg1subst = arg1->undiff();
6656 expr_t arg2subst = arg2->undiff();
6657 expr_t arg3subst = arg3->undiff();
6658 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6659 }
6660
6661 int
VarMinLag() const6662 TrinaryOpNode::VarMinLag() const
6663 {
6664 return min(min(arg1->VarMinLag(), arg2->VarMinLag()), arg3->VarMinLag());
6665 }
6666
6667 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const6668 TrinaryOpNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
6669 {
6670 return max(arg1->VarMaxLag(lhs_lag_equiv),
6671 max(arg2->VarMaxLag(lhs_lag_equiv),
6672 arg3->VarMaxLag(lhs_lag_equiv)));
6673 }
6674
6675 int
PacMaxLag(int lhs_symb_id) const6676 TrinaryOpNode::PacMaxLag(int lhs_symb_id) const
6677 {
6678 return max(arg1->PacMaxLag(lhs_symb_id), max(arg2->PacMaxLag(lhs_symb_id), arg3->PacMaxLag(lhs_symb_id)));
6679 }
6680
6681 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const6682 TrinaryOpNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
6683 {
6684 int symb_id = arg1->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
6685 if (symb_id >= 0)
6686 return symb_id;
6687
6688 symb_id = arg2->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
6689 if (symb_id >= 0)
6690 return symb_id;
6691
6692 return arg3->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
6693 }
6694
6695 expr_t
decreaseLeadsLags(int n) const6696 TrinaryOpNode::decreaseLeadsLags(int n) const
6697 {
6698 expr_t arg1subst = arg1->decreaseLeadsLags(n);
6699 expr_t arg2subst = arg2->decreaseLeadsLags(n);
6700 expr_t arg3subst = arg3->decreaseLeadsLags(n);
6701 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6702 }
6703
6704 expr_t
decreaseLeadsLagsPredeterminedVariables() const6705 TrinaryOpNode::decreaseLeadsLagsPredeterminedVariables() const
6706 {
6707 expr_t arg1subst = arg1->decreaseLeadsLagsPredeterminedVariables();
6708 expr_t arg2subst = arg2->decreaseLeadsLagsPredeterminedVariables();
6709 expr_t arg3subst = arg3->decreaseLeadsLagsPredeterminedVariables();
6710 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6711 }
6712
6713 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const6714 TrinaryOpNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
6715 {
6716 if (maxEndoLead() < 2)
6717 return const_cast<TrinaryOpNode *>(this);
6718 else if (deterministic_model)
6719 {
6720 expr_t arg1subst = arg1->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
6721 expr_t arg2subst = arg2->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
6722 expr_t arg3subst = arg3->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model);
6723 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6724 }
6725 else
6726 return createEndoLeadAuxiliaryVarForMyself(subst_table, neweqs);
6727 }
6728
6729 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6730 TrinaryOpNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6731 {
6732 expr_t arg1subst = arg1->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
6733 expr_t arg2subst = arg2->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
6734 expr_t arg3subst = arg3->substituteEndoLagGreaterThanTwo(subst_table, neweqs);
6735 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6736 }
6737
6738 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const6739 TrinaryOpNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
6740 {
6741 if (maxExoLead() == 0)
6742 return const_cast<TrinaryOpNode *>(this);
6743 else if (deterministic_model)
6744 {
6745 expr_t arg1subst = arg1->substituteExoLead(subst_table, neweqs, deterministic_model);
6746 expr_t arg2subst = arg2->substituteExoLead(subst_table, neweqs, deterministic_model);
6747 expr_t arg3subst = arg3->substituteExoLead(subst_table, neweqs, deterministic_model);
6748 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6749 }
6750 else
6751 return createExoLeadAuxiliaryVarForMyself(subst_table, neweqs);
6752 }
6753
6754 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6755 TrinaryOpNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6756 {
6757 expr_t arg1subst = arg1->substituteExoLag(subst_table, neweqs);
6758 expr_t arg2subst = arg2->substituteExoLag(subst_table, neweqs);
6759 expr_t arg3subst = arg3->substituteExoLag(subst_table, neweqs);
6760 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6761 }
6762
6763 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const6764 TrinaryOpNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
6765 {
6766 expr_t arg1subst = arg1->substituteExpectation(subst_table, neweqs, partial_information_model);
6767 expr_t arg2subst = arg2->substituteExpectation(subst_table, neweqs, partial_information_model);
6768 expr_t arg3subst = arg3->substituteExpectation(subst_table, neweqs, partial_information_model);
6769 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6770 }
6771
6772 expr_t
substituteAdl() const6773 TrinaryOpNode::substituteAdl() const
6774 {
6775 expr_t arg1subst = arg1->substituteAdl();
6776 expr_t arg2subst = arg2->substituteAdl();
6777 expr_t arg3subst = arg3->substituteAdl();
6778 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6779 }
6780
6781 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const6782 TrinaryOpNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
6783 {
6784 expr_t arg1subst = arg1->substituteVarExpectation(subst_table);
6785 expr_t arg2subst = arg2->substituteVarExpectation(subst_table);
6786 expr_t arg3subst = arg3->substituteVarExpectation(subst_table);
6787 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6788 }
6789
6790 void
findDiffNodes(lag_equivalence_table_t & nodes) const6791 TrinaryOpNode::findDiffNodes(lag_equivalence_table_t &nodes) const
6792 {
6793 arg1->findDiffNodes(nodes);
6794 arg2->findDiffNodes(nodes);
6795 arg3->findDiffNodes(nodes);
6796 }
6797
6798 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const6799 TrinaryOpNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
6800 {
6801 arg1->findUnaryOpNodesForAuxVarCreation(nodes);
6802 arg2->findUnaryOpNodesForAuxVarCreation(nodes);
6803 arg3->findUnaryOpNodesForAuxVarCreation(nodes);
6804 }
6805
6806 int
findTargetVariable(int lhs_symb_id) const6807 TrinaryOpNode::findTargetVariable(int lhs_symb_id) const
6808 {
6809 int retval = arg1->findTargetVariable(lhs_symb_id);
6810 if (retval < 0)
6811 retval = arg2->findTargetVariable(lhs_symb_id);
6812 if (retval < 0)
6813 retval = arg3->findTargetVariable(lhs_symb_id);
6814 return retval;
6815 }
6816
6817 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6818 TrinaryOpNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
6819 vector<BinaryOpNode *> &neweqs) const
6820 {
6821 expr_t arg1subst = arg1->substituteDiff(nodes, subst_table, neweqs);
6822 expr_t arg2subst = arg2->substituteDiff(nodes, subst_table, neweqs);
6823 expr_t arg3subst = arg3->substituteDiff(nodes, subst_table, neweqs);
6824 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6825 }
6826
6827 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6828 TrinaryOpNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6829 {
6830 expr_t arg1subst = arg1->substituteUnaryOpNodes(nodes, subst_table, neweqs);
6831 expr_t arg2subst = arg2->substituteUnaryOpNodes(nodes, subst_table, neweqs);
6832 expr_t arg3subst = arg3->substituteUnaryOpNodes(nodes, subst_table, neweqs);
6833 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6834 }
6835
6836 int
countDiffs() const6837 TrinaryOpNode::countDiffs() const
6838 {
6839 return max(arg1->countDiffs(), max(arg2->countDiffs(), arg3->countDiffs()));
6840 }
6841
6842 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)6843 TrinaryOpNode::substitutePacExpectation(const string &name, expr_t subexpr)
6844 {
6845 expr_t arg1subst = arg1->substitutePacExpectation(name, subexpr);
6846 expr_t arg2subst = arg2->substitutePacExpectation(name, subexpr);
6847 expr_t arg3subst = arg3->substitutePacExpectation(name, subexpr);
6848 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6849 }
6850
6851 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const6852 TrinaryOpNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
6853 {
6854 expr_t arg1subst = arg1->differentiateForwardVars(subset, subst_table, neweqs);
6855 expr_t arg2subst = arg2->differentiateForwardVars(subset, subst_table, neweqs);
6856 expr_t arg3subst = arg3->differentiateForwardVars(subset, subst_table, neweqs);
6857 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6858 }
6859
6860 bool
isNumConstNodeEqualTo(double value) const6861 TrinaryOpNode::isNumConstNodeEqualTo(double value) const
6862 {
6863 return false;
6864 }
6865
6866 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const6867 TrinaryOpNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
6868 {
6869 return false;
6870 }
6871
6872 bool
containsPacExpectation(const string & pac_model_name) const6873 TrinaryOpNode::containsPacExpectation(const string &pac_model_name) const
6874 {
6875 return (arg1->containsPacExpectation(pac_model_name) || arg2->containsPacExpectation(pac_model_name) || arg3->containsPacExpectation(pac_model_name));
6876 }
6877
6878 bool
containsEndogenous() const6879 TrinaryOpNode::containsEndogenous() const
6880 {
6881 return (arg1->containsEndogenous() || arg2->containsEndogenous() || arg3->containsEndogenous());
6882 }
6883
6884 bool
containsExogenous() const6885 TrinaryOpNode::containsExogenous() const
6886 {
6887 return (arg1->containsExogenous() || arg2->containsExogenous() || arg3->containsExogenous());
6888 }
6889
6890 expr_t
replaceTrendVar() const6891 TrinaryOpNode::replaceTrendVar() const
6892 {
6893 expr_t arg1subst = arg1->replaceTrendVar();
6894 expr_t arg2subst = arg2->replaceTrendVar();
6895 expr_t arg3subst = arg3->replaceTrendVar();
6896 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6897 }
6898
6899 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const6900 TrinaryOpNode::detrend(int symb_id, bool log_trend, expr_t trend) const
6901 {
6902 expr_t arg1subst = arg1->detrend(symb_id, log_trend, trend);
6903 expr_t arg2subst = arg2->detrend(symb_id, log_trend, trend);
6904 expr_t arg3subst = arg3->detrend(symb_id, log_trend, trend);
6905 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6906 }
6907
6908 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const6909 TrinaryOpNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
6910 {
6911 expr_t arg1subst = arg1->removeTrendLeadLag(trend_symbols_map);
6912 expr_t arg2subst = arg2->removeTrendLeadLag(trend_symbols_map);
6913 expr_t arg3subst = arg3->removeTrendLeadLag(trend_symbols_map);
6914 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6915 }
6916
6917 bool
isInStaticForm() const6918 TrinaryOpNode::isInStaticForm() const
6919 {
6920 return arg1->isInStaticForm() && arg2->isInStaticForm() && arg3->isInStaticForm();
6921 }
6922
6923 bool
isParamTimesEndogExpr() const6924 TrinaryOpNode::isParamTimesEndogExpr() const
6925 {
6926 return arg1->isParamTimesEndogExpr()
6927 || arg2->isParamTimesEndogExpr()
6928 || arg3->isParamTimesEndogExpr();
6929 }
6930
6931 bool
isVarModelReferenced(const string & model_info_name) const6932 TrinaryOpNode::isVarModelReferenced(const string &model_info_name) const
6933 {
6934 return arg1->isVarModelReferenced(model_info_name)
6935 || arg2->isVarModelReferenced(model_info_name)
6936 || arg3->isVarModelReferenced(model_info_name);
6937 }
6938
6939 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const6940 TrinaryOpNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
6941 {
6942 arg1->getEndosAndMaxLags(model_endos_and_lags);
6943 arg2->getEndosAndMaxLags(model_endos_and_lags);
6944 arg3->getEndosAndMaxLags(model_endos_and_lags);
6945 }
6946
6947 expr_t
substituteStaticAuxiliaryVariable() const6948 TrinaryOpNode::substituteStaticAuxiliaryVariable() const
6949 {
6950 expr_t arg1subst = arg1->substituteStaticAuxiliaryVariable();
6951 expr_t arg2subst = arg2->substituteStaticAuxiliaryVariable();
6952 expr_t arg3subst = arg3->substituteStaticAuxiliaryVariable();
6953 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6954 }
6955
6956 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const6957 TrinaryOpNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
6958 {
6959 arg1->findConstantEquations(table);
6960 arg2->findConstantEquations(table);
6961 arg3->findConstantEquations(table);
6962 }
6963
6964 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const6965 TrinaryOpNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
6966 {
6967 expr_t arg1subst = arg1->replaceVarsInEquation(table);
6968 expr_t arg2subst = arg2->replaceVarsInEquation(table);
6969 expr_t arg3subst = arg3->replaceVarsInEquation(table);
6970 return buildSimilarTrinaryOpNode(arg1subst, arg2subst, arg3subst, datatree);
6971 }
6972
AbstractExternalFunctionNode(DataTree & datatree_arg,int idx_arg,int symb_id_arg,vector<expr_t> arguments_arg)6973 AbstractExternalFunctionNode::AbstractExternalFunctionNode(DataTree &datatree_arg,
6974 int idx_arg,
6975 int symb_id_arg,
6976 vector<expr_t> arguments_arg) :
6977 ExprNode{datatree_arg, idx_arg},
6978 symb_id{symb_id_arg},
6979 arguments{move(arguments_arg)}
6980 {
6981 }
6982
6983 void
prepareForDerivation()6984 AbstractExternalFunctionNode::prepareForDerivation()
6985 {
6986 if (preparedForDerivation)
6987 return;
6988
6989 for (auto argument : arguments)
6990 argument->prepareForDerivation();
6991
6992 non_null_derivatives = arguments.at(0)->non_null_derivatives;
6993 for (int i = 1; i < static_cast<int>(arguments.size()); i++)
6994 set_union(non_null_derivatives.begin(),
6995 non_null_derivatives.end(),
6996 arguments.at(i)->non_null_derivatives.begin(),
6997 arguments.at(i)->non_null_derivatives.end(),
6998 inserter(non_null_derivatives, non_null_derivatives.begin()));
6999
7000 preparedForDerivation = true;
7001 }
7002
7003 expr_t
computeDerivative(int deriv_id)7004 AbstractExternalFunctionNode::computeDerivative(int deriv_id)
7005 {
7006 assert(datatree.external_functions_table.getNargs(symb_id) > 0);
7007 vector<expr_t> dargs;
7008 for (auto argument : arguments)
7009 dargs.push_back(argument->getDerivative(deriv_id));
7010 return composeDerivatives(dargs);
7011 }
7012
7013 expr_t
getChainRuleDerivative(int deriv_id,const map<int,expr_t> & recursive_variables)7014 AbstractExternalFunctionNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
7015 {
7016 assert(datatree.external_functions_table.getNargs(symb_id) > 0);
7017 vector<expr_t> dargs;
7018 for (auto argument : arguments)
7019 dargs.push_back(argument->getChainRuleDerivative(deriv_id, recursive_variables));
7020 return composeDerivatives(dargs);
7021 }
7022
7023 unsigned int
compileExternalFunctionArguments(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const7024 AbstractExternalFunctionNode::compileExternalFunctionArguments(ostream &CompileCode, unsigned int &instruction_number,
7025 bool lhs_rhs, const temporary_terms_t &temporary_terms,
7026 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
7027 const deriv_node_temp_terms_t &tef_terms) const
7028 {
7029 for (auto argument : arguments)
7030 argument->compile(CompileCode, instruction_number, lhs_rhs, temporary_terms, map_idx,
7031 dynamic, steady_dynamic, tef_terms);
7032 return (arguments.size());
7033 }
7034
7035 void
collectVARLHSVariable(set<expr_t> & result) const7036 AbstractExternalFunctionNode::collectVARLHSVariable(set<expr_t> &result) const
7037 {
7038 cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
7039 exit(EXIT_FAILURE);
7040 }
7041
7042 void
collectDynamicVariables(SymbolType type_arg,set<pair<int,int>> & result) const7043 AbstractExternalFunctionNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
7044 {
7045 for (auto argument : arguments)
7046 argument->collectDynamicVariables(type_arg, result);
7047 }
7048
7049 void
collectTemporary_terms(const temporary_terms_t & temporary_terms,temporary_terms_inuse_t & temporary_terms_inuse,int Curr_Block) const7050 AbstractExternalFunctionNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
7051 {
7052 if (temporary_terms.find(const_cast<AbstractExternalFunctionNode *>(this)) != temporary_terms.end())
7053 temporary_terms_inuse.insert(idx);
7054 else
7055 for (auto argument : arguments)
7056 argument->collectTemporary_terms(temporary_terms, temporary_terms_inuse, Curr_Block);
7057 }
7058
7059 double
eval(const eval_context_t & eval_context) const7060 AbstractExternalFunctionNode::eval(const eval_context_t &eval_context) const noexcept(false)
7061 {
7062 throw EvalExternalFunctionException();
7063 }
7064
7065 int
maxEndoLead() const7066 AbstractExternalFunctionNode::maxEndoLead() const
7067 {
7068 int val = 0;
7069 for (auto argument : arguments)
7070 val = max(val, argument->maxEndoLead());
7071 return val;
7072 }
7073
7074 int
maxExoLead() const7075 AbstractExternalFunctionNode::maxExoLead() const
7076 {
7077 int val = 0;
7078 for (auto argument : arguments)
7079 val = max(val, argument->maxExoLead());
7080 return val;
7081 }
7082
7083 int
maxEndoLag() const7084 AbstractExternalFunctionNode::maxEndoLag() const
7085 {
7086 int val = 0;
7087 for (auto argument : arguments)
7088 val = max(val, argument->maxEndoLag());
7089 return val;
7090 }
7091
7092 int
maxExoLag() const7093 AbstractExternalFunctionNode::maxExoLag() const
7094 {
7095 int val = 0;
7096 for (auto argument : arguments)
7097 val = max(val, argument->maxExoLag());
7098 return val;
7099 }
7100
7101 int
maxLead() const7102 AbstractExternalFunctionNode::maxLead() const
7103 {
7104 int val = 0;
7105 for (auto argument : arguments)
7106 val = max(val, argument->maxLead());
7107 return val;
7108 }
7109
7110 int
maxLag() const7111 AbstractExternalFunctionNode::maxLag() const
7112 {
7113 int val = 0;
7114 for (auto argument : arguments)
7115 val = max(val, argument->maxLag());
7116 return val;
7117 }
7118
7119 int
maxLagWithDiffsExpanded() const7120 AbstractExternalFunctionNode::maxLagWithDiffsExpanded() const
7121 {
7122 int val = 0;
7123 for (auto argument : arguments)
7124 val = max(val, argument->maxLagWithDiffsExpanded());
7125 return val;
7126 }
7127
7128 expr_t
undiff() const7129 AbstractExternalFunctionNode::undiff() const
7130 {
7131 vector<expr_t> arguments_subst;
7132 for (auto argument : arguments)
7133 arguments_subst.push_back(argument->undiff());
7134 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7135 }
7136
7137 int
VarMinLag() const7138 AbstractExternalFunctionNode::VarMinLag() const
7139 {
7140 int val = 0;
7141 for (auto argument : arguments)
7142 val = min(val, argument->VarMinLag());
7143 return val;
7144 }
7145
7146 int
VarMaxLag(const set<expr_t> & lhs_lag_equiv) const7147 AbstractExternalFunctionNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
7148 {
7149 int max_lag = 0;
7150 for (auto argument : arguments)
7151 max_lag = max(max_lag, argument->VarMaxLag(lhs_lag_equiv));
7152 return max_lag;
7153 }
7154
7155 int
PacMaxLag(int lhs_symb_id) const7156 AbstractExternalFunctionNode::PacMaxLag(int lhs_symb_id) const
7157 {
7158 int val = 0;
7159 for (auto argument : arguments)
7160 val = max(val, argument->PacMaxLag(lhs_symb_id));
7161 return val;
7162 }
7163
7164 int
getPacTargetSymbId(int lhs_symb_id,int undiff_lhs_symb_id) const7165 AbstractExternalFunctionNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
7166 {
7167 int symb_id = -1;
7168 for (auto argument : arguments)
7169 {
7170 symb_id = argument->getPacTargetSymbId(lhs_symb_id, undiff_lhs_symb_id);
7171 if (symb_id >= 0)
7172 return symb_id;
7173 }
7174 return -1;
7175 }
7176
7177 expr_t
decreaseLeadsLags(int n) const7178 AbstractExternalFunctionNode::decreaseLeadsLags(int n) const
7179 {
7180 vector<expr_t> arguments_subst;
7181 for (auto argument : arguments)
7182 arguments_subst.push_back(argument->decreaseLeadsLags(n));
7183 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7184 }
7185
7186 expr_t
decreaseLeadsLagsPredeterminedVariables() const7187 AbstractExternalFunctionNode::decreaseLeadsLagsPredeterminedVariables() const
7188 {
7189 vector<expr_t> arguments_subst;
7190 for (auto argument : arguments)
7191 arguments_subst.push_back(argument->decreaseLeadsLagsPredeterminedVariables());
7192 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7193 }
7194
7195 expr_t
substituteEndoLeadGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const7196 AbstractExternalFunctionNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
7197 {
7198 vector<expr_t> arguments_subst;
7199 for (auto argument : arguments)
7200 arguments_subst.push_back(argument->substituteEndoLeadGreaterThanTwo(subst_table, neweqs, deterministic_model));
7201 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7202 }
7203
7204 expr_t
substituteEndoLagGreaterThanTwo(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7205 AbstractExternalFunctionNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7206 {
7207 vector<expr_t> arguments_subst;
7208 for (auto argument : arguments)
7209 arguments_subst.push_back(argument->substituteEndoLagGreaterThanTwo(subst_table, neweqs));
7210 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7211 }
7212
7213 expr_t
substituteExoLead(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool deterministic_model) const7214 AbstractExternalFunctionNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
7215 {
7216 vector<expr_t> arguments_subst;
7217 for (auto argument : arguments)
7218 arguments_subst.push_back(argument->substituteExoLead(subst_table, neweqs, deterministic_model));
7219 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7220 }
7221
7222 expr_t
substituteExoLag(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7223 AbstractExternalFunctionNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7224 {
7225 vector<expr_t> arguments_subst;
7226 for (auto argument : arguments)
7227 arguments_subst.push_back(argument->substituteExoLag(subst_table, neweqs));
7228 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7229 }
7230
7231 expr_t
substituteExpectation(subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs,bool partial_information_model) const7232 AbstractExternalFunctionNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
7233 {
7234 vector<expr_t> arguments_subst;
7235 for (auto argument : arguments)
7236 arguments_subst.push_back(argument->substituteExpectation(subst_table, neweqs, partial_information_model));
7237 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7238 }
7239
7240 expr_t
substituteAdl() const7241 AbstractExternalFunctionNode::substituteAdl() const
7242 {
7243 vector<expr_t> arguments_subst;
7244 for (auto argument : arguments)
7245 arguments_subst.push_back(argument->substituteAdl());
7246 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7247 }
7248
7249 expr_t
substituteVarExpectation(const map<string,expr_t> & subst_table) const7250 AbstractExternalFunctionNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
7251 {
7252 vector<expr_t> arguments_subst;
7253 for (auto argument : arguments)
7254 arguments_subst.push_back(argument->substituteVarExpectation(subst_table));
7255 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7256 }
7257
7258 void
findDiffNodes(lag_equivalence_table_t & nodes) const7259 AbstractExternalFunctionNode::findDiffNodes(lag_equivalence_table_t &nodes) const
7260 {
7261 for (auto argument : arguments)
7262 argument->findDiffNodes(nodes);
7263 }
7264
7265 void
findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t & nodes) const7266 AbstractExternalFunctionNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
7267 {
7268 for (auto argument : arguments)
7269 argument->findUnaryOpNodesForAuxVarCreation(nodes);
7270 }
7271
7272 int
findTargetVariable(int lhs_symb_id) const7273 AbstractExternalFunctionNode::findTargetVariable(int lhs_symb_id) const
7274 {
7275 for (auto argument : arguments)
7276 if (int retval = argument->findTargetVariable(lhs_symb_id);
7277 retval >= 0)
7278 return retval;
7279 return -1;
7280 }
7281
7282 expr_t
substituteDiff(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7283 AbstractExternalFunctionNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
7284 vector<BinaryOpNode *> &neweqs) const
7285 {
7286 vector<expr_t> arguments_subst;
7287 for (auto argument : arguments)
7288 arguments_subst.push_back(argument->substituteDiff(nodes, subst_table, neweqs));
7289 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7290 }
7291
7292 expr_t
substituteUnaryOpNodes(const lag_equivalence_table_t & nodes,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7293 AbstractExternalFunctionNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7294 {
7295 vector<expr_t> arguments_subst;
7296 for (auto argument : arguments)
7297 arguments_subst.push_back(argument->substituteUnaryOpNodes(nodes, subst_table, neweqs));
7298 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7299 }
7300
7301 int
countDiffs() const7302 AbstractExternalFunctionNode::countDiffs() const
7303 {
7304 int ndiffs = 0;
7305 for (auto argument : arguments)
7306 ndiffs = max(ndiffs, argument->countDiffs());
7307 return ndiffs;
7308 }
7309
7310 expr_t
substitutePacExpectation(const string & name,expr_t subexpr)7311 AbstractExternalFunctionNode::substitutePacExpectation(const string &name, expr_t subexpr)
7312 {
7313 vector<expr_t> arguments_subst;
7314 for (auto argument : arguments)
7315 arguments_subst.push_back(argument->substitutePacExpectation(name, subexpr));
7316 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7317 }
7318
7319 expr_t
differentiateForwardVars(const vector<string> & subset,subst_table_t & subst_table,vector<BinaryOpNode * > & neweqs) const7320 AbstractExternalFunctionNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
7321 {
7322 vector<expr_t> arguments_subst;
7323 for (auto argument : arguments)
7324 arguments_subst.push_back(argument->differentiateForwardVars(subset, subst_table, neweqs));
7325 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7326 }
7327
7328 bool
alreadyWrittenAsTefTerm(int the_symb_id,const deriv_node_temp_terms_t & tef_terms) const7329 AbstractExternalFunctionNode::alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const
7330 {
7331 if (tef_terms.find({ the_symb_id, arguments }) != tef_terms.end())
7332 return true;
7333 return false;
7334 }
7335
7336 int
getIndxInTefTerms(int the_symb_id,const deriv_node_temp_terms_t & tef_terms) const7337 AbstractExternalFunctionNode::getIndxInTefTerms(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const noexcept(false)
7338 {
7339 if (auto it = tef_terms.find({ the_symb_id, arguments });
7340 it != tef_terms.end())
7341 return it->second;
7342 throw UnknownFunctionNameAndArgs();
7343 }
7344
7345 void
computeTemporaryTerms(const pair<int,int> & derivOrder,map<pair<int,int>,temporary_terms_t> & temp_terms_map,map<expr_t,pair<int,pair<int,int>>> & reference_count,bool is_matlab) const7346 AbstractExternalFunctionNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
7347 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
7348 map<expr_t, pair<int, pair<int, int>>> &reference_count,
7349 bool is_matlab) const
7350 {
7351 /* All external function nodes are declared as temporary terms.
7352
7353 Given that temporary terms are separated in several functions (residuals,
7354 jacobian, …), we must make sure that all temporary terms derived from a
7355 given external function call are assigned just after that call.
7356
7357 As a consequence, we need to “promote” some terms to a previous level (in
7358 the sense that residuals come before jacobian), if a temporary term
7359 corresponding to the same external function call is present in that
7360 previous level. */
7361
7362 for (auto &tt : temp_terms_map)
7363 if (auto it = find_if(tt.second.cbegin(), tt.second.cend(), sameTefTermPredicate());
7364 it != tt.second.cend())
7365 {
7366 tt.second.insert(const_cast<AbstractExternalFunctionNode *>(this));
7367 return;
7368 }
7369
7370 temp_terms_map[derivOrder].insert(const_cast<AbstractExternalFunctionNode *>(this));
7371 }
7372
7373 bool
isNumConstNodeEqualTo(double value) const7374 AbstractExternalFunctionNode::isNumConstNodeEqualTo(double value) const
7375 {
7376 return false;
7377 }
7378
7379 bool
isVariableNodeEqualTo(SymbolType type_arg,int variable_id,int lag_arg) const7380 AbstractExternalFunctionNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
7381 {
7382 return false;
7383 }
7384
7385 bool
containsPacExpectation(const string & pac_model_name) const7386 AbstractExternalFunctionNode::containsPacExpectation(const string &pac_model_name) const
7387 {
7388 for (auto argument : arguments)
7389 if (argument->containsPacExpectation(pac_model_name))
7390 return true;
7391 return false;
7392 }
7393
7394 bool
containsEndogenous() const7395 AbstractExternalFunctionNode::containsEndogenous() const
7396 {
7397 for (auto argument : arguments)
7398 if (argument->containsEndogenous())
7399 return true;
7400 return false;
7401 }
7402
7403 bool
containsExogenous() const7404 AbstractExternalFunctionNode::containsExogenous() const
7405 {
7406 for (auto argument : arguments)
7407 if (argument->containsExogenous())
7408 return true;
7409 return false;
7410 }
7411
7412 expr_t
replaceTrendVar() const7413 AbstractExternalFunctionNode::replaceTrendVar() const
7414 {
7415 vector<expr_t> arguments_subst;
7416 for (auto argument : arguments)
7417 arguments_subst.push_back(argument->replaceTrendVar());
7418 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7419 }
7420
7421 expr_t
detrend(int symb_id,bool log_trend,expr_t trend) const7422 AbstractExternalFunctionNode::detrend(int symb_id, bool log_trend, expr_t trend) const
7423 {
7424 vector<expr_t> arguments_subst;
7425 for (auto argument : arguments)
7426 arguments_subst.push_back(argument->detrend(symb_id, log_trend, trend));
7427 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7428 }
7429
7430 expr_t
removeTrendLeadLag(const map<int,expr_t> & trend_symbols_map) const7431 AbstractExternalFunctionNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
7432 {
7433 vector<expr_t> arguments_subst;
7434 for (auto argument : arguments)
7435 arguments_subst.push_back(argument->removeTrendLeadLag(trend_symbols_map));
7436 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7437 }
7438
7439 bool
isInStaticForm() const7440 AbstractExternalFunctionNode::isInStaticForm() const
7441 {
7442 for (auto argument : arguments)
7443 if (!argument->isInStaticForm())
7444 return false;
7445 return true;
7446 }
7447
7448 bool
isParamTimesEndogExpr() const7449 AbstractExternalFunctionNode::isParamTimesEndogExpr() const
7450 {
7451 return false;
7452 }
7453
7454 bool
isVarModelReferenced(const string & model_info_name) const7455 AbstractExternalFunctionNode::isVarModelReferenced(const string &model_info_name) const
7456 {
7457 for (auto argument : arguments)
7458 if (!argument->isVarModelReferenced(model_info_name))
7459 return true;
7460 return false;
7461 }
7462
7463 void
getEndosAndMaxLags(map<string,int> & model_endos_and_lags) const7464 AbstractExternalFunctionNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
7465 {
7466 for (auto argument : arguments)
7467 argument->getEndosAndMaxLags(model_endos_and_lags);
7468 }
7469
7470 pair<int, expr_t>
normalizeEquation(int var_endo,vector<tuple<int,expr_t,expr_t>> & List_of_Op_RHS) const7471 AbstractExternalFunctionNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
7472 {
7473 vector<pair<bool, expr_t>> V_arguments;
7474 vector<expr_t> V_expr_t;
7475 bool present = false;
7476 for (auto argument : arguments)
7477 {
7478 V_arguments.emplace_back(argument->normalizeEquation(var_endo, List_of_Op_RHS));
7479 present = present || V_arguments[V_arguments.size()-1].first;
7480 V_expr_t.push_back(V_arguments[V_arguments.size()-1].second);
7481 }
7482 if (!present)
7483 return { 0, datatree.AddExternalFunction(symb_id, V_expr_t) };
7484 else
7485 return { 2, nullptr };
7486 }
7487
7488 void
writeExternalFunctionArguments(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const7489 AbstractExternalFunctionNode::writeExternalFunctionArguments(ostream &output, ExprNodeOutputType output_type,
7490 const temporary_terms_t &temporary_terms,
7491 const temporary_terms_idxs_t &temporary_terms_idxs,
7492 const deriv_node_temp_terms_t &tef_terms) const
7493 {
7494 for (auto it = arguments.begin(); it != arguments.end(); ++it)
7495 {
7496 if (it != arguments.begin())
7497 output << ",";
7498
7499 (*it)->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7500 }
7501 }
7502
7503 void
writeJsonASTExternalFunctionArguments(ostream & output) const7504 AbstractExternalFunctionNode::writeJsonASTExternalFunctionArguments(ostream &output) const
7505 {
7506 int i = 0;
7507 output << "{";
7508 for (auto it = arguments.begin(); it != arguments.end(); ++it, i++)
7509 {
7510 if (it != arguments.begin())
7511 output << ",";
7512
7513 output << R"("arg)" << i << R"(" : )";
7514 (*it)->writeJsonAST(output);
7515 }
7516 output << "}";
7517 }
7518
7519 void
writeJsonExternalFunctionArguments(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7520 AbstractExternalFunctionNode::writeJsonExternalFunctionArguments(ostream &output,
7521 const temporary_terms_t &temporary_terms,
7522 const deriv_node_temp_terms_t &tef_terms,
7523 bool isdynamic) const
7524 {
7525 for (auto it = arguments.begin(); it != arguments.end(); ++it)
7526 {
7527 if (it != arguments.begin())
7528 output << ",";
7529
7530 (*it)->writeJsonOutput(output, temporary_terms, tef_terms, isdynamic);
7531 }
7532 }
7533
7534 void
writePrhs(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms,const string & ending) const7535 AbstractExternalFunctionNode::writePrhs(ostream &output, ExprNodeOutputType output_type,
7536 const temporary_terms_t &temporary_terms,
7537 const temporary_terms_idxs_t &temporary_terms_idxs,
7538 const deriv_node_temp_terms_t &tef_terms, const string &ending) const
7539 {
7540 output << "mxArray *prhs"<< ending << "[nrhs"<< ending << "];" << endl;
7541 int i = 0;
7542 for (auto argument : arguments)
7543 {
7544 output << "prhs" << ending << "[" << i++ << "] = mxCreateDoubleScalar("; // All external_function arguments are scalars
7545 argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7546 output << ");" << endl;
7547 }
7548 }
7549
7550 bool
containsExternalFunction() const7551 AbstractExternalFunctionNode::containsExternalFunction() const
7552 {
7553 return true;
7554 }
7555
7556 expr_t
substituteStaticAuxiliaryVariable() const7557 AbstractExternalFunctionNode::substituteStaticAuxiliaryVariable() const
7558 {
7559 vector<expr_t> arguments_subst;
7560 for (auto argument : arguments)
7561 arguments_subst.push_back(argument->substituteStaticAuxiliaryVariable());
7562 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7563 }
7564
7565 void
findConstantEquations(map<VariableNode *,NumConstNode * > & table) const7566 AbstractExternalFunctionNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
7567 {
7568 for (auto argument : arguments)
7569 argument->findConstantEquations(table);
7570 }
7571
7572 expr_t
replaceVarsInEquation(map<VariableNode *,NumConstNode * > & table) const7573 AbstractExternalFunctionNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
7574 {
7575 vector<expr_t> arguments_subst;
7576 for (auto argument : arguments)
7577 arguments_subst.push_back(argument->replaceVarsInEquation(table));
7578 return buildSimilarExternalFunctionNode(arguments_subst, datatree);
7579 }
7580
ExternalFunctionNode(DataTree & datatree_arg,int idx_arg,int symb_id_arg,const vector<expr_t> & arguments_arg)7581 ExternalFunctionNode::ExternalFunctionNode(DataTree &datatree_arg,
7582 int idx_arg,
7583 int symb_id_arg,
7584 const vector<expr_t> &arguments_arg) :
7585 AbstractExternalFunctionNode{datatree_arg, idx_arg, symb_id_arg, arguments_arg}
7586 {
7587 }
7588
7589 expr_t
composeDerivatives(const vector<expr_t> & dargs)7590 ExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
7591 {
7592 vector<expr_t> dNodes;
7593 for (int i = 0; i < static_cast<int>(dargs.size()); i++)
7594 dNodes.push_back(datatree.AddTimes(dargs.at(i),
7595 datatree.AddFirstDerivExternalFunction(symb_id, arguments, i+1)));
7596
7597 expr_t theDeriv = datatree.Zero;
7598 for (auto &dNode : dNodes)
7599 theDeriv = datatree.AddPlus(theDeriv, dNode);
7600 return theDeriv;
7601 }
7602
7603 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const7604 ExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
7605 temporary_terms_t &temporary_terms,
7606 map<expr_t, pair<int, int>> &first_occurence,
7607 int Curr_block,
7608 vector< vector<temporary_terms_t>> &v_temporary_terms,
7609 int equation) const
7610 {
7611 expr_t this2 = const_cast<ExternalFunctionNode *>(this);
7612 temporary_terms.insert(this2);
7613 first_occurence[this2] = { Curr_block, equation };
7614 v_temporary_terms[Curr_block][equation].insert(this2);
7615 }
7616
7617 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const7618 ExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number,
7619 bool lhs_rhs, const temporary_terms_t &temporary_terms,
7620 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
7621 const deriv_node_temp_terms_t &tef_terms) const
7622 {
7623 if (temporary_terms.find(const_cast<ExternalFunctionNode *>(this)) != temporary_terms.end())
7624 {
7625 if (dynamic)
7626 {
7627 auto ii = map_idx.find(idx);
7628 FLDT_ fldt(ii->second);
7629 fldt.write(CompileCode, instruction_number);
7630 }
7631 else
7632 {
7633 auto ii = map_idx.find(idx);
7634 FLDST_ fldst(ii->second);
7635 fldst.write(CompileCode, instruction_number);
7636 }
7637 return;
7638 }
7639
7640 if (!lhs_rhs)
7641 {
7642 FLDTEF_ fldtef(getIndxInTefTerms(symb_id, tef_terms));
7643 fldtef.write(CompileCode, instruction_number);
7644 }
7645 else
7646 {
7647 FSTPTEF_ fstptef(getIndxInTefTerms(symb_id, tef_terms));
7648 fstptef.write(CompileCode, instruction_number);
7649 }
7650 }
7651
7652 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const7653 ExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
7654 bool lhs_rhs, const temporary_terms_t &temporary_terms,
7655 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
7656 deriv_node_temp_terms_t &tef_terms) const
7657 {
7658 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7659 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7660
7661 for (auto argument : arguments)
7662 argument->compileExternalFunctionOutput(CompileCode, instruction_number, lhs_rhs, temporary_terms,
7663 map_idx, dynamic, steady_dynamic, tef_terms);
7664
7665 if (!alreadyWrittenAsTefTerm(symb_id, tef_terms))
7666 {
7667 tef_terms[{ symb_id, arguments }] = static_cast<int>(tef_terms.size());
7668 int indx = getIndxInTefTerms(symb_id, tef_terms);
7669 int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
7670 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7671
7672 unsigned int nb_output_arguments = 0;
7673 if (symb_id == first_deriv_symb_id
7674 && symb_id == second_deriv_symb_id)
7675 nb_output_arguments = 3;
7676 else if (symb_id == first_deriv_symb_id)
7677 nb_output_arguments = 2;
7678 else
7679 nb_output_arguments = 1;
7680 unsigned int nb_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms,
7681 map_idx, dynamic, steady_dynamic, tef_terms);
7682
7683 FCALL_ fcall(nb_output_arguments, nb_input_arguments, datatree.symbol_table.getName(symb_id), indx);
7684 switch (nb_output_arguments)
7685 {
7686 case 1:
7687 fcall.set_function_type(ExternalFunctionType::withoutDerivative);
7688 break;
7689 case 2:
7690 fcall.set_function_type(ExternalFunctionType::withFirstDerivative);
7691 break;
7692 case 3:
7693 fcall.set_function_type(ExternalFunctionType::withFirstAndSecondDerivative);
7694 break;
7695 }
7696 fcall.write(CompileCode, instruction_number);
7697 FSTPTEF_ fstptef(indx);
7698 fstptef.write(CompileCode, instruction_number);
7699 }
7700 }
7701
7702 void
writeJsonAST(ostream & output) const7703 ExternalFunctionNode::writeJsonAST(ostream &output) const
7704 {
7705 output << R"({"node_type" : "ExternalFunctionNode", )"
7706 << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "args" : [)";
7707 writeJsonASTExternalFunctionArguments(output);
7708 output << "]}";
7709 }
7710
7711 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7712 ExternalFunctionNode::writeJsonOutput(ostream &output,
7713 const temporary_terms_t &temporary_terms,
7714 const deriv_node_temp_terms_t &tef_terms,
7715 bool isdynamic) const
7716 {
7717 if (temporary_terms.find(const_cast<ExternalFunctionNode *>(this)) != temporary_terms.end())
7718 {
7719 output << "T" << idx;
7720 return;
7721 }
7722
7723 try
7724 {
7725 int tef_idx = getIndxInTefTerms(symb_id, tef_terms);
7726 output << "TEF_" << tef_idx;
7727 }
7728 catch (UnknownFunctionNameAndArgs &)
7729 {
7730 // When writing the JSON output at parsing pass, we don’t use TEF terms
7731 output << datatree.symbol_table.getName(symb_id) << "(";
7732 writeJsonExternalFunctionArguments(output, temporary_terms, tef_terms, isdynamic);
7733 output << ")";
7734 }
7735 }
7736
7737 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const7738 ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
7739 const temporary_terms_t &temporary_terms,
7740 const temporary_terms_idxs_t &temporary_terms_idxs,
7741 const deriv_node_temp_terms_t &tef_terms) const
7742 {
7743 if (output_type == ExprNodeOutputType::matlabOutsideModel || output_type == ExprNodeOutputType::steadyStateFile
7744 || output_type == ExprNodeOutputType::juliaSteadyStateFile
7745 || output_type == ExprNodeOutputType::epilogueFile
7746 || isLatexOutput(output_type))
7747 {
7748 string name = isLatexOutput(output_type) ? datatree.symbol_table.getTeXName(symb_id)
7749 : datatree.symbol_table.getName(symb_id);
7750 output << name << "(";
7751 writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7752 output << ")";
7753 return;
7754 }
7755
7756 if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
7757 return;
7758
7759 if (isCOutput(output_type))
7760 output << "*";
7761 output << "TEF_" << getIndxInTefTerms(symb_id, tef_terms);
7762 }
7763
7764 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const7765 ExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
7766 const temporary_terms_t &temporary_terms,
7767 const temporary_terms_idxs_t &temporary_terms_idxs,
7768 deriv_node_temp_terms_t &tef_terms) const
7769 {
7770 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7771 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7772
7773 for (auto argument : arguments)
7774 argument->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7775
7776 if (!alreadyWrittenAsTefTerm(symb_id, tef_terms))
7777 {
7778 tef_terms[{ symb_id, arguments }] = static_cast<int>(tef_terms.size());
7779 int indx = getIndxInTefTerms(symb_id, tef_terms);
7780 int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
7781 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7782
7783 if (isCOutput(output_type))
7784 {
7785 stringstream ending;
7786 ending << "_tef_" << getIndxInTefTerms(symb_id, tef_terms);
7787 if (symb_id == first_deriv_symb_id
7788 && symb_id == second_deriv_symb_id)
7789 output << "int nlhs" << ending.str() << " = 3;" << endl
7790 << "double *TEF_" << indx << ", "
7791 << "*TEFD_" << indx << ", "
7792 << "*TEFDD_" << indx << ";" << endl;
7793 else if (symb_id == first_deriv_symb_id)
7794 output << "int nlhs" << ending.str() << " = 2;" << endl
7795 << "double *TEF_" << indx << ", "
7796 << "*TEFD_" << indx << "; " << endl;
7797 else
7798 output << "int nlhs" << ending.str() << " = 1;" << endl
7799 << "double *TEF_" << indx << ";" << endl;
7800
7801 output << "mxArray *plhs" << ending.str()<< "[nlhs"<< ending.str() << "];" << endl;
7802 output << "int nrhs" << ending.str()<< " = " << arguments.size() << ";" << endl;
7803 writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
7804
7805 output << "mexCallMATLAB("
7806 << "nlhs" << ending.str() << ", "
7807 << "plhs" << ending.str() << ", "
7808 << "nrhs" << ending.str() << ", "
7809 << "prhs" << ending.str() << R"(, ")"
7810 << datatree.symbol_table.getName(symb_id) << R"(");)" << endl;
7811
7812 if (symb_id == first_deriv_symb_id
7813 && symb_id == second_deriv_symb_id)
7814 output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl
7815 << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl
7816 << "TEFDD_" << indx << " = mxGetPr(plhs" << ending.str() << "[2]);" << endl
7817 << "int TEFDD_" << indx << "_nrows = (int)mxGetM(plhs" << ending.str()<< "[2]);" << endl;
7818 else if (symb_id == first_deriv_symb_id)
7819 output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl
7820 << "TEFD_" << indx << " = mxGetPr(plhs" << ending.str() << "[1]);" << endl;
7821 else
7822 output << "TEF_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
7823 }
7824 else
7825 {
7826 if (symb_id == first_deriv_symb_id
7827 && symb_id == second_deriv_symb_id)
7828 output << "[TEF_" << indx << ", TEFD_"<< indx << ", TEFDD_"<< indx << "] = ";
7829 else if (symb_id == first_deriv_symb_id)
7830 output << "[TEF_" << indx << ", TEFD_"<< indx << "] = ";
7831 else
7832 output << "TEF_" << indx << " = ";
7833
7834 output << datatree.symbol_table.getName(symb_id) << "(";
7835 writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
7836 output << ");" << endl;
7837 }
7838 }
7839 }
7840
7841 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7842 ExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
7843 const temporary_terms_t &temporary_terms,
7844 deriv_node_temp_terms_t &tef_terms,
7845 bool isdynamic) const
7846 {
7847 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7848 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7849
7850 for (auto argument : arguments)
7851 argument->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
7852
7853 if (!alreadyWrittenAsTefTerm(symb_id, tef_terms))
7854 {
7855 tef_terms[{ symb_id, arguments }] = static_cast<int>(tef_terms.size());
7856 int indx = getIndxInTefTerms(symb_id, tef_terms);
7857 int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
7858 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7859
7860 stringstream ef;
7861 ef << R"({"external_function": {)"
7862 << R"("external_function_term": "TEF_)" << indx << R"(")";
7863
7864 if (symb_id == first_deriv_symb_id)
7865 ef << R"(, "external_function_term_d": "TEFD_)" << indx << R"(")";
7866
7867 if (symb_id == second_deriv_symb_id)
7868 ef << R"(, "external_function_term_dd": "TEFDD_)" << indx << R"(")";
7869
7870 ef << R"(, "value": ")" << datatree.symbol_table.getName(symb_id) << "(";
7871 writeJsonExternalFunctionArguments(ef, temporary_terms, tef_terms, isdynamic);
7872 ef << R"lit()"}})lit";
7873 efout.push_back(ef.str());
7874 }
7875 }
7876
7877 expr_t
toStatic(DataTree & static_datatree) const7878 ExternalFunctionNode::toStatic(DataTree &static_datatree) const
7879 {
7880 vector<expr_t> static_arguments;
7881 for (auto argument : arguments)
7882 static_arguments.push_back(argument->toStatic(static_datatree));
7883 return static_datatree.AddExternalFunction(symb_id, static_arguments);
7884 }
7885
7886 void
computeXrefs(EquationInfo & ei) const7887 ExternalFunctionNode::computeXrefs(EquationInfo &ei) const
7888 {
7889 vector<expr_t> dynamic_arguments;
7890 for (auto argument : arguments)
7891 argument->computeXrefs(ei);
7892 }
7893
7894 expr_t
clone(DataTree & datatree) const7895 ExternalFunctionNode::clone(DataTree &datatree) const
7896 {
7897 vector<expr_t> dynamic_arguments;
7898 for (auto argument : arguments)
7899 dynamic_arguments.push_back(argument->clone(datatree));
7900 return datatree.AddExternalFunction(symb_id, dynamic_arguments);
7901 }
7902
7903 expr_t
buildSimilarExternalFunctionNode(vector<expr_t> & alt_args,DataTree & alt_datatree) const7904 ExternalFunctionNode::buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, DataTree &alt_datatree) const
7905 {
7906 return alt_datatree.AddExternalFunction(symb_id, alt_args);
7907 }
7908
7909 function<bool (expr_t)>
sameTefTermPredicate() const7910 ExternalFunctionNode::sameTefTermPredicate() const
7911 {
7912 return [this](expr_t e) {
7913 auto e2 = dynamic_cast<ExternalFunctionNode *>(e);
7914 return (e2 != nullptr && e2->symb_id == symb_id);
7915 };
7916 }
7917
FirstDerivExternalFunctionNode(DataTree & datatree_arg,int idx_arg,int top_level_symb_id_arg,const vector<expr_t> & arguments_arg,int inputIndex_arg)7918 FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatree_arg,
7919 int idx_arg,
7920 int top_level_symb_id_arg,
7921 const vector<expr_t> &arguments_arg,
7922 int inputIndex_arg) :
7923 AbstractExternalFunctionNode{datatree_arg, idx_arg, top_level_symb_id_arg, arguments_arg},
7924 inputIndex{inputIndex_arg}
7925 {
7926 }
7927
7928 void
computeTemporaryTerms(map<expr_t,int> & reference_count,temporary_terms_t & temporary_terms,map<expr_t,pair<int,int>> & first_occurence,int Curr_block,vector<vector<temporary_terms_t>> & v_temporary_terms,int equation) const7929 FirstDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
7930 temporary_terms_t &temporary_terms,
7931 map<expr_t, pair<int, int>> &first_occurence,
7932 int Curr_block,
7933 vector< vector<temporary_terms_t>> &v_temporary_terms,
7934 int equation) const
7935 {
7936 expr_t this2 = const_cast<FirstDerivExternalFunctionNode *>(this);
7937 temporary_terms.insert(this2);
7938 first_occurence[this2] = { Curr_block, equation };
7939 v_temporary_terms[Curr_block][equation].insert(this2);
7940 }
7941
7942 expr_t
composeDerivatives(const vector<expr_t> & dargs)7943 FirstDerivExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
7944 {
7945 vector<expr_t> dNodes;
7946 for (int i = 0; i < static_cast<int>(dargs.size()); i++)
7947 dNodes.push_back(datatree.AddTimes(dargs.at(i),
7948 datatree.AddSecondDerivExternalFunction(symb_id, arguments, inputIndex, i+1)));
7949 expr_t theDeriv = datatree.Zero;
7950 for (auto &dNode : dNodes)
7951 theDeriv = datatree.AddPlus(theDeriv, dNode);
7952 return theDeriv;
7953 }
7954
7955 void
writeJsonAST(ostream & output) const7956 FirstDerivExternalFunctionNode::writeJsonAST(ostream &output) const
7957 {
7958 output << R"({"node_type" : "FirstDerivExternalFunctionNode", )"
7959 << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "args" : [)";
7960 writeJsonASTExternalFunctionArguments(output);
7961 output << "]}";
7962 }
7963
7964 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const7965 FirstDerivExternalFunctionNode::writeJsonOutput(ostream &output,
7966 const temporary_terms_t &temporary_terms,
7967 const deriv_node_temp_terms_t &tef_terms,
7968 bool isdynamic) const
7969 {
7970 // If current node is a temporary term
7971 if (temporary_terms.find(const_cast<FirstDerivExternalFunctionNode *>(this)) != temporary_terms.end())
7972 {
7973 output << "T" << idx;
7974 return;
7975 }
7976
7977 const int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
7978 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
7979
7980 const int tmpIndx = inputIndex - 1;
7981
7982 if (first_deriv_symb_id == symb_id)
7983 output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms)
7984 << "[" << tmpIndx << "]";
7985 else if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
7986 output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
7987 else
7988 output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
7989 << "[" << tmpIndx << "]";
7990 }
7991
7992 void
writeOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,const deriv_node_temp_terms_t & tef_terms) const7993 FirstDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
7994 const temporary_terms_t &temporary_terms,
7995 const temporary_terms_idxs_t &temporary_terms_idxs,
7996 const deriv_node_temp_terms_t &tef_terms) const
7997 {
7998 assert(output_type != ExprNodeOutputType::matlabOutsideModel);
7999
8000 if (isLatexOutput(output_type))
8001 {
8002 output << R"(\frac{\partial )" << datatree.symbol_table.getTeXName(symb_id)
8003 << R"(}{\partial )" << inputIndex << "}(";
8004 writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8005 output << ")";
8006 return;
8007 }
8008
8009 if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
8010 return;
8011
8012 const int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8013 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8014
8015 const int tmpIndx = inputIndex - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
8016
8017 if (first_deriv_symb_id == symb_id)
8018 output << "TEFD_" << getIndxInTefTerms(symb_id, tef_terms)
8019 << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndx << RIGHT_ARRAY_SUBSCRIPT(output_type);
8020 else if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8021 {
8022 if (isCOutput(output_type))
8023 output << "*";
8024 output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
8025 }
8026 else
8027 output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
8028 << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndx << RIGHT_ARRAY_SUBSCRIPT(output_type);
8029 }
8030
8031 void
compile(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,const deriv_node_temp_terms_t & tef_terms) const8032 FirstDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number,
8033 bool lhs_rhs, const temporary_terms_t &temporary_terms,
8034 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8035 const deriv_node_temp_terms_t &tef_terms) const
8036 {
8037 if (temporary_terms.find(const_cast<FirstDerivExternalFunctionNode *>(this)) != temporary_terms.end())
8038 {
8039 if (dynamic)
8040 {
8041 auto ii = map_idx.find(idx);
8042 FLDT_ fldt(ii->second);
8043 fldt.write(CompileCode, instruction_number);
8044 }
8045 else
8046 {
8047 auto ii = map_idx.find(idx);
8048 FLDST_ fldst(ii->second);
8049 fldst.write(CompileCode, instruction_number);
8050 }
8051 return;
8052 }
8053 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8054 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8055
8056 if (!lhs_rhs)
8057 {
8058 FLDTEFD_ fldtefd(getIndxInTefTerms(symb_id, tef_terms), inputIndex);
8059 fldtefd.write(CompileCode, instruction_number);
8060 }
8061 else
8062 {
8063 FSTPTEFD_ fstptefd(getIndxInTefTerms(symb_id, tef_terms), inputIndex);
8064 fstptefd.write(CompileCode, instruction_number);
8065 }
8066 }
8067
8068 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const8069 FirstDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
8070 const temporary_terms_t &temporary_terms,
8071 const temporary_terms_idxs_t &temporary_terms_idxs,
8072 deriv_node_temp_terms_t &tef_terms) const
8073 {
8074 assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8075 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8076 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8077
8078 /* For a node with derivs provided by the user function, call the method
8079 on the non-derived node */
8080 if (first_deriv_symb_id == symb_id)
8081 {
8082 expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8083 parent->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs,
8084 tef_terms);
8085 return;
8086 }
8087
8088 if (alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
8089 return;
8090
8091 if (isCOutput(output_type))
8092 if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8093 {
8094 stringstream ending;
8095 ending << "_tefd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex;
8096 output << "int nlhs" << ending.str() << " = 1;" << endl
8097 << "double *TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << ";" << endl
8098 << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8099 << "int nrhs" << ending.str() << " = 3;" << endl
8100 << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl
8101 << "mwSize dims" << ending.str() << "[2];" << endl;
8102
8103 output << "dims" << ending.str() << "[0] = 1;" << endl
8104 << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl;
8105
8106 output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
8107 << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex << ");"<< endl
8108 << "prhs" << ending.str() << "[2] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl;
8109
8110 int i = 0;
8111 for (auto argument : arguments)
8112 {
8113 output << "mxSetCell(prhs" << ending.str() << "[2], "
8114 << i++ << ", "
8115 << "mxCreateDoubleScalar("; // All external_function arguments are scalars
8116 argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8117 output << "));" << endl;
8118 }
8119
8120 output << "mexCallMATLAB("
8121 << "nlhs" << ending.str() << ", "
8122 << "plhs" << ending.str() << ", "
8123 << "nrhs" << ending.str() << ", "
8124 << "prhs" << ending.str() << R"(, ")"
8125 << R"(jacob_element");)" << endl;
8126
8127 output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex
8128 << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8129 }
8130 else
8131 {
8132 tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8133 int indx = getIndxInTefTerms(first_deriv_symb_id, tef_terms);
8134 stringstream ending;
8135 ending << "_tefd_def_" << indx;
8136 output << "int nlhs" << ending.str() << " = 1;" << endl
8137 << "double *TEFD_def_" << indx << ";" << endl
8138 << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8139 << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl;
8140 writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
8141
8142 output << "mexCallMATLAB("
8143 << "nlhs" << ending.str() << ", "
8144 << "plhs" << ending.str() << ", "
8145 << "nrhs" << ending.str() << ", "
8146 << "prhs" << ending.str() << R"(, ")"
8147 << datatree.symbol_table.getName(first_deriv_symb_id) << R"(");)" << endl;
8148
8149 output << "TEFD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8150 }
8151 else
8152 {
8153 if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8154 output << "TEFD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << " = jacob_element('"
8155 << datatree.symbol_table.getName(symb_id) << "'," << inputIndex << ",{";
8156 else
8157 {
8158 tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8159 output << "TEFD_def_" << getIndxInTefTerms(first_deriv_symb_id, tef_terms)
8160 << " = " << datatree.symbol_table.getName(first_deriv_symb_id) << "(";
8161 }
8162
8163 writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8164
8165 if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8166 output << "}";
8167 output << ");" << endl;
8168 }
8169 }
8170
8171 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const8172 FirstDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
8173 const temporary_terms_t &temporary_terms,
8174 deriv_node_temp_terms_t &tef_terms,
8175 bool isdynamic) const
8176 {
8177 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8178 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8179
8180 /* For a node with derivs provided by the user function, call the method
8181 on the non-derived node */
8182 if (first_deriv_symb_id == symb_id)
8183 {
8184 expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8185 parent->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
8186 return;
8187 }
8188
8189 if (alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
8190 return;
8191
8192 stringstream ef;
8193 if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8194 ef << R"({"first_deriv_external_function": {)"
8195 << R"("external_function_term": "TEFD_fdd_)" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex << R"(")"
8196 << R"(, "analytic_derivative": false)"
8197 << R"(, "wrt": )" << inputIndex
8198 << R"(, "value": ")" << datatree.symbol_table.getName(symb_id) << "(";
8199 else
8200 {
8201 tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8202 ef << R"({"first_deriv_external_function": {)"
8203 << R"("external_function_term": "TEFD_def_)" << getIndxInTefTerms(first_deriv_symb_id, tef_terms) << R"(")"
8204 << R"(, "analytic_derivative": true)"
8205 << R"(, "value": ")" << datatree.symbol_table.getName(first_deriv_symb_id) << "(";
8206 }
8207
8208 writeJsonExternalFunctionArguments(ef, temporary_terms, tef_terms, isdynamic);
8209 ef << R"lit()"}})lit";
8210 efout.push_back(ef.str());
8211 }
8212
8213 void
compileExternalFunctionOutput(ostream & CompileCode,unsigned int & instruction_number,bool lhs_rhs,const temporary_terms_t & temporary_terms,const map_idx_t & map_idx,bool dynamic,bool steady_dynamic,deriv_node_temp_terms_t & tef_terms) const8214 FirstDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
8215 bool lhs_rhs, const temporary_terms_t &temporary_terms,
8216 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8217 deriv_node_temp_terms_t &tef_terms) const
8218 {
8219 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8220 assert(first_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8221
8222 if (first_deriv_symb_id == symb_id || alreadyWrittenAsTefTerm(first_deriv_symb_id, tef_terms))
8223 return;
8224
8225 unsigned int nb_add_input_arguments = compileExternalFunctionArguments(CompileCode, instruction_number, lhs_rhs, temporary_terms,
8226 map_idx, dynamic, steady_dynamic, tef_terms);
8227 if (first_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8228 {
8229 unsigned int nb_input_arguments = 0;
8230 unsigned int nb_output_arguments = 1;
8231 unsigned int indx = getIndxInTefTerms(symb_id, tef_terms);
8232 FCALL_ fcall(nb_output_arguments, nb_input_arguments, "jacob_element", indx);
8233 fcall.set_arg_func_name(datatree.symbol_table.getName(symb_id));
8234 fcall.set_row(inputIndex);
8235 fcall.set_nb_add_input_arguments(nb_add_input_arguments);
8236 fcall.set_function_type(ExternalFunctionType::numericalFirstDerivative);
8237 fcall.write(CompileCode, instruction_number);
8238 FSTPTEFD_ fstptefd(indx, inputIndex);
8239 fstptefd.write(CompileCode, instruction_number);
8240 }
8241 else
8242 {
8243 tef_terms[{ first_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8244 int indx = getIndxInTefTerms(symb_id, tef_terms);
8245 int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8246 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8247
8248 unsigned int nb_output_arguments = 1;
8249
8250 FCALL_ fcall(nb_output_arguments, nb_add_input_arguments, datatree.symbol_table.getName(first_deriv_symb_id), indx);
8251 fcall.set_function_type(ExternalFunctionType::firstDerivative);
8252 fcall.write(CompileCode, instruction_number);
8253 FSTPTEFD_ fstptefd(indx, inputIndex);
8254 fstptefd.write(CompileCode, instruction_number);
8255 }
8256 }
8257
8258 expr_t
8259 FirstDerivExternalFunctionNode::clone(DataTree &datatree) const
8260 {
8261 vector<expr_t> dynamic_arguments;
8262 for (auto argument : arguments)
8263 dynamic_arguments.push_back(argument->clone(datatree));
8264 return datatree.AddFirstDerivExternalFunction(symb_id, dynamic_arguments,
8265 inputIndex);
8266 }
8267
8268 expr_t
8269 FirstDerivExternalFunctionNode::buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, DataTree &alt_datatree) const
8270 {
8271 return alt_datatree.AddFirstDerivExternalFunction(symb_id, alt_args, inputIndex);
8272 }
8273
8274 expr_t
8275 FirstDerivExternalFunctionNode::toStatic(DataTree &static_datatree) const
8276 {
8277 vector<expr_t> static_arguments;
8278 for (auto argument : arguments)
8279 static_arguments.push_back(argument->toStatic(static_datatree));
8280 return static_datatree.AddFirstDerivExternalFunction(symb_id, static_arguments,
8281 inputIndex);
8282 }
8283
8284 void
8285 FirstDerivExternalFunctionNode::computeXrefs(EquationInfo &ei) const
8286 {
8287 vector<expr_t> dynamic_arguments;
8288 for (auto argument : arguments)
8289 argument->computeXrefs(ei);
8290 }
8291
8292 function<bool (expr_t)>
8293 FirstDerivExternalFunctionNode::sameTefTermPredicate() const
8294 {
8295 int first_deriv_symb_id = datatree.external_functions_table.getFirstDerivSymbID(symb_id);
8296 if (first_deriv_symb_id == symb_id)
8297 return [this](expr_t e) {
8298 auto e2 = dynamic_cast<ExternalFunctionNode *>(e);
8299 return (e2 && e2->symb_id == symb_id);
8300 };
8301 else
8302 return [this](expr_t e) {
8303 auto e2 = dynamic_cast<FirstDerivExternalFunctionNode *>(e);
8304 return (e2 && e2->symb_id == symb_id);
8305 };
8306 }
8307
8308 SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datatree_arg,
8309 int idx_arg,
8310 int top_level_symb_id_arg,
8311 const vector<expr_t> &arguments_arg,
8312 int inputIndex1_arg,
8313 int inputIndex2_arg) :
8314 AbstractExternalFunctionNode{datatree_arg, idx_arg, top_level_symb_id_arg, arguments_arg},
8315 inputIndex1{inputIndex1_arg},
8316 inputIndex2{inputIndex2_arg}
8317 {
8318 }
8319
8320 void
8321 SecondDerivExternalFunctionNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
8322 temporary_terms_t &temporary_terms,
8323 map<expr_t, pair<int, int>> &first_occurence,
8324 int Curr_block,
8325 vector< vector<temporary_terms_t>> &v_temporary_terms,
8326 int equation) const
8327 {
8328 expr_t this2 = const_cast<SecondDerivExternalFunctionNode *>(this);
8329 temporary_terms.insert(this2);
8330 first_occurence[this2] = { Curr_block, equation };
8331 v_temporary_terms[Curr_block][equation].insert(this2);
8332 }
8333
8334 expr_t
8335 SecondDerivExternalFunctionNode::composeDerivatives(const vector<expr_t> &dargs)
8336 {
8337 cerr << "ERROR: third order derivatives of external functions are not implemented" << endl;
8338 exit(EXIT_FAILURE);
8339 }
8340
8341 void
8342 SecondDerivExternalFunctionNode::writeJsonAST(ostream &output) const
8343 {
8344 output << R"({"node_type" : "SecondDerivExternalFunctionNode", )"
8345 << R"("name" : ")" << datatree.symbol_table.getName(symb_id) << R"(", "args" : [)";
8346 writeJsonASTExternalFunctionArguments(output);
8347 output << "]}";
8348 }
8349
8350 void
writeJsonOutput(ostream & output,const temporary_terms_t & temporary_terms,const deriv_node_temp_terms_t & tef_terms,bool isdynamic) const8351 SecondDerivExternalFunctionNode::writeJsonOutput(ostream &output,
8352 const temporary_terms_t &temporary_terms,
8353 const deriv_node_temp_terms_t &tef_terms,
8354 bool isdynamic) const
8355 {
8356 // If current node is a temporary term
8357 if (temporary_terms.find(const_cast<SecondDerivExternalFunctionNode *>(this)) != temporary_terms.end())
8358 {
8359 output << "T" << idx;
8360 return;
8361 }
8362
8363 const int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8364 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8365
8366 const int tmpIndex1 = inputIndex1 - 1;
8367 const int tmpIndex2 = inputIndex2 - 1;
8368
8369 if (second_deriv_symb_id == symb_id)
8370 output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms)
8371 << "[" << tmpIndex1 << "," << tmpIndex2 << "]";
8372 else if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8373 output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
8374 else
8375 output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8376 << "[" << tmpIndex1 << "," << tmpIndex2 << "]";
8377 }
8378
8379 void
8380 SecondDerivExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
8381 const temporary_terms_t &temporary_terms,
8382 const temporary_terms_idxs_t &temporary_terms_idxs,
8383 const deriv_node_temp_terms_t &tef_terms) const
8384 {
8385 assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8386
8387 if (isLatexOutput(output_type))
8388 {
8389 output << R"(\frac{\partial^2 )" << datatree.symbol_table.getTeXName(symb_id)
8390 << R"(}{\partial )" << inputIndex1 << R"(\partial )" << inputIndex2 << "}(";
8391 writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8392 output << ")";
8393 return;
8394 }
8395
8396 if (checkIfTemporaryTermThenWrite(output, output_type, temporary_terms, temporary_terms_idxs))
8397 return;
8398
8399 const int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8400 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8401
8402 const int tmpIndex1 = inputIndex1 - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
8403 const int tmpIndex2 = inputIndex2 - 1 + ARRAY_SUBSCRIPT_OFFSET(output_type);
8404
8405 int indx = getIndxInTefTerms(symb_id, tef_terms);
8406 if (second_deriv_symb_id == symb_id)
8407 if (isCOutput(output_type))
8408 output << "TEFDD_" << indx
8409 << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << " * TEFDD_" << indx << "_nrows + "
8410 << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8411 else
8412 output << "TEFDD_" << getIndxInTefTerms(symb_id, tef_terms)
8413 << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << "," << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8414 else if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8415 {
8416 if (isCOutput(output_type))
8417 output << "*";
8418 output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
8419 }
8420 else
8421 if (isCOutput(output_type))
8422 output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8423 << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << " * PROBLEM_" << indx << "_nrows"
8424 << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8425 else
8426 output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8427 << LEFT_ARRAY_SUBSCRIPT(output_type) << tmpIndex1 << "," << tmpIndex2 << RIGHT_ARRAY_SUBSCRIPT(output_type);
8428 }
8429
8430 void
writeExternalFunctionOutput(ostream & output,ExprNodeOutputType output_type,const temporary_terms_t & temporary_terms,const temporary_terms_idxs_t & temporary_terms_idxs,deriv_node_temp_terms_t & tef_terms) const8431 SecondDerivExternalFunctionNode::writeExternalFunctionOutput(ostream &output, ExprNodeOutputType output_type,
8432 const temporary_terms_t &temporary_terms,
8433 const temporary_terms_idxs_t &temporary_terms_idxs,
8434 deriv_node_temp_terms_t &tef_terms) const
8435 {
8436 assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8437 int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8438 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8439
8440 /* For a node with derivs provided by the user function, call the method
8441 on the non-derived node */
8442 if (second_deriv_symb_id == symb_id)
8443 {
8444 expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8445 parent->writeExternalFunctionOutput(output, output_type, temporary_terms, temporary_terms_idxs,
8446 tef_terms);
8447 return;
8448 }
8449
8450 if (alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms))
8451 return;
8452
8453 if (isCOutput(output_type))
8454 if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8455 {
8456 stringstream ending;
8457 ending << "_tefdd_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2;
8458 output << "int nlhs" << ending.str() << " = 1;" << endl
8459 << "double *TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << ";" << endl
8460 << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8461 << "int nrhs" << ending.str() << " = 4;" << endl
8462 << "mxArray *prhs" << ending.str() << "[nrhs"<< ending.str() << "];" << endl
8463 << "mwSize dims" << ending.str() << "[2];" << endl;
8464
8465 output << "dims" << ending.str() << "[0] = 1;" << endl
8466 << "dims" << ending.str() << "[1] = " << arguments.size() << ";" << endl;
8467
8468 output << "prhs" << ending.str() << R"([0] = mxCreateString(")" << datatree.symbol_table.getName(symb_id) << R"(");)" << endl
8469 << "prhs" << ending.str() << "[1] = mxCreateDoubleScalar(" << inputIndex1 << ");"<< endl
8470 << "prhs" << ending.str() << "[2] = mxCreateDoubleScalar(" << inputIndex2 << ");"<< endl
8471 << "prhs" << ending.str() << "[3] = mxCreateCellArray(2, dims" << ending.str() << ");"<< endl;
8472
8473 int i = 0;
8474 for (auto argument : arguments)
8475 {
8476 output << "mxSetCell(prhs" << ending.str() << "[3], "
8477 << i++ << ", "
8478 << "mxCreateDoubleScalar("; // All external_function arguments are scalars
8479 argument->writeOutput(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8480 output << "));" << endl;
8481 }
8482
8483 output << "mexCallMATLAB("
8484 << "nlhs" << ending.str() << ", "
8485 << "plhs" << ending.str() << ", "
8486 << "nrhs" << ending.str() << ", "
8487 << "prhs" << ending.str() << R"(, ")"
8488 << R"(hess_element");)" << endl;
8489
8490 output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
8491 << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8492 }
8493 else
8494 {
8495 tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8496 int indx = getIndxInTefTerms(second_deriv_symb_id, tef_terms);
8497 stringstream ending;
8498 ending << "_tefdd_def_" << indx;
8499
8500 output << "int nlhs" << ending.str() << " = 1;" << endl
8501 << "double *TEFDD_def_" << indx << ";" << endl
8502 << "mxArray *plhs" << ending.str() << "[nlhs"<< ending.str() << "];" << endl
8503 << "int nrhs" << ending.str() << " = " << arguments.size() << ";" << endl;
8504 writePrhs(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms, ending.str());
8505
8506 output << "mexCallMATLAB("
8507 << "nlhs" << ending.str() << ", "
8508 << "plhs" << ending.str() << ", "
8509 << "nrhs" << ending.str() << ", "
8510 << "prhs" << ending.str() << R"(, ")"
8511 << datatree.symbol_table.getName(second_deriv_symb_id) << R"(");)" << endl;
8512
8513 output << "TEFDD_def_" << indx << " = mxGetPr(plhs" << ending.str() << "[0]);" << endl;
8514 }
8515 else
8516 {
8517 if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8518 output << "TEFDD_fdd_" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2
8519 << " = hess_element('" << datatree.symbol_table.getName(symb_id) << "',"
8520 << inputIndex1 << "," << inputIndex2 << ",{";
8521 else
8522 {
8523 tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8524 output << "TEFDD_def_" << getIndxInTefTerms(second_deriv_symb_id, tef_terms)
8525 << " = " << datatree.symbol_table.getName(second_deriv_symb_id) << "(";
8526 }
8527
8528 writeExternalFunctionArguments(output, output_type, temporary_terms, temporary_terms_idxs, tef_terms);
8529
8530 if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8531 output << "}";
8532 output << ");" << endl;
8533 }
8534 }
8535
8536 void
writeJsonExternalFunctionOutput(vector<string> & efout,const temporary_terms_t & temporary_terms,deriv_node_temp_terms_t & tef_terms,bool isdynamic) const8537 SecondDerivExternalFunctionNode::writeJsonExternalFunctionOutput(vector<string> &efout,
8538 const temporary_terms_t &temporary_terms,
8539 deriv_node_temp_terms_t &tef_terms,
8540 bool isdynamic) const
8541 {
8542 int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8543 assert(second_deriv_symb_id != ExternalFunctionsTable::IDSetButNoNameProvided);
8544
8545 /* For a node with derivs provided by the user function, call the method
8546 on the non-derived node */
8547 if (second_deriv_symb_id == symb_id)
8548 {
8549 expr_t parent = datatree.AddExternalFunction(symb_id, arguments);
8550 parent->writeJsonExternalFunctionOutput(efout, temporary_terms, tef_terms, isdynamic);
8551 return;
8552 }
8553
8554 if (alreadyWrittenAsTefTerm(second_deriv_symb_id, tef_terms))
8555 return;
8556
8557 stringstream ef;
8558 if (second_deriv_symb_id == ExternalFunctionsTable::IDNotSet)
8559 ef << R"({"second_deriv_external_function": {)"
8560 << R"("external_function_term": "TEFDD_fdd_)" << getIndxInTefTerms(symb_id, tef_terms) << "_" << inputIndex1 << "_" << inputIndex2 << R"(")"
8561 << R"(, "analytic_derivative": false)"
8562 << R"(, "wrt1": )" << inputIndex1
8563 << R"(, "wrt2": )" << inputIndex2
8564 << R"(, "value": ")" << datatree.symbol_table.getName(symb_id) << "(";
8565 else
8566 {
8567 tef_terms[{ second_deriv_symb_id, arguments }] = static_cast<int>(tef_terms.size());
8568 ef << R"({"second_deriv_external_function": {)"
8569 << R"("external_function_term": "TEFDD_def_)" << getIndxInTefTerms(second_deriv_symb_id, tef_terms) << R"(")"
8570 << R"(, "analytic_derivative": true)"
8571 << R"(, "value": ")" << datatree.symbol_table.getName(second_deriv_symb_id) << "(";
8572 }
8573
8574 writeJsonExternalFunctionArguments(ef, temporary_terms, tef_terms, isdynamic);
8575 ef << R"lit()"}})lit" << endl;
8576 efout.push_back(ef.str());
8577 }
8578
8579 expr_t
8580 SecondDerivExternalFunctionNode::clone(DataTree &datatree) const
8581 {
8582 vector<expr_t> dynamic_arguments;
8583 for (auto argument : arguments)
8584 dynamic_arguments.push_back(argument->clone(datatree));
8585 return datatree.AddSecondDerivExternalFunction(symb_id, dynamic_arguments,
8586 inputIndex1, inputIndex2);
8587 }
8588
8589 expr_t
8590 SecondDerivExternalFunctionNode::buildSimilarExternalFunctionNode(vector<expr_t> &alt_args, DataTree &alt_datatree) const
8591 {
8592 return alt_datatree.AddSecondDerivExternalFunction(symb_id, alt_args, inputIndex1, inputIndex2);
8593 }
8594
8595 expr_t
8596 SecondDerivExternalFunctionNode::toStatic(DataTree &static_datatree) const
8597 {
8598 vector<expr_t> static_arguments;
8599 for (auto argument : arguments)
8600 static_arguments.push_back(argument->toStatic(static_datatree));
8601 return static_datatree.AddSecondDerivExternalFunction(symb_id, static_arguments,
8602 inputIndex1, inputIndex2);
8603 }
8604
8605 void
8606 SecondDerivExternalFunctionNode::computeXrefs(EquationInfo &ei) const
8607 {
8608 vector<expr_t> dynamic_arguments;
8609 for (auto argument : arguments)
8610 argument->computeXrefs(ei);
8611 }
8612
8613 void
8614 SecondDerivExternalFunctionNode::compile(ostream &CompileCode, unsigned int &instruction_number,
8615 bool lhs_rhs, const temporary_terms_t &temporary_terms,
8616 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8617 const deriv_node_temp_terms_t &tef_terms) const
8618 {
8619 cerr << "SecondDerivExternalFunctionNode::compile: not implemented." << endl;
8620 exit(EXIT_FAILURE);
8621 }
8622
8623 void
8624 SecondDerivExternalFunctionNode::compileExternalFunctionOutput(ostream &CompileCode, unsigned int &instruction_number,
8625 bool lhs_rhs, const temporary_terms_t &temporary_terms,
8626 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8627 deriv_node_temp_terms_t &tef_terms) const
8628 {
8629 cerr << "SecondDerivExternalFunctionNode::compileExternalFunctionOutput: not implemented." << endl;
8630 exit(EXIT_FAILURE);
8631 }
8632
8633 function<bool (expr_t)>
8634 SecondDerivExternalFunctionNode::sameTefTermPredicate() const
8635 {
8636 int second_deriv_symb_id = datatree.external_functions_table.getSecondDerivSymbID(symb_id);
8637 if (second_deriv_symb_id == symb_id)
8638 return [this](expr_t e) {
8639 auto e2 = dynamic_cast<ExternalFunctionNode *>(e);
8640 return (e2 && e2->symb_id == symb_id);
8641 };
8642 else
8643 return [this](expr_t e) {
8644 auto e2 = dynamic_cast<SecondDerivExternalFunctionNode *>(e);
8645 return (e2 && e2->symb_id == symb_id);
8646 };
8647 }
8648
8649 VarExpectationNode::VarExpectationNode(DataTree &datatree_arg,
8650 int idx_arg,
8651 string model_name_arg) :
8652 ExprNode{datatree_arg, idx_arg},
8653 model_name{move(model_name_arg)}
8654 {
8655 }
8656
8657 void
8658 VarExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
8659 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
8660 map<expr_t, pair<int, pair<int, int>>> &reference_count,
8661 bool is_matlab) const
8662 {
8663 cerr << "VarExpectationNode::computeTemporaryTerms not implemented." << endl;
8664 exit(EXIT_FAILURE);
8665 }
8666
8667 void
8668 VarExpectationNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
8669 temporary_terms_t &temporary_terms,
8670 map<expr_t, pair<int, int>> &first_occurence,
8671 int Curr_block,
8672 vector< vector<temporary_terms_t>> &v_temporary_terms,
8673 int equation) const
8674 {
8675 cerr << "VarExpectationNode::computeTemporaryTerms not implemented." << endl;
8676 exit(EXIT_FAILURE);
8677 }
8678
8679 expr_t
8680 VarExpectationNode::toStatic(DataTree &static_datatree) const
8681 {
8682 cerr << "VarExpectationNode::toStatic not implemented." << endl;
8683 exit(EXIT_FAILURE);
8684 }
8685
8686 expr_t
8687 VarExpectationNode::clone(DataTree &datatree) const
8688 {
8689 return datatree.AddVarExpectation(model_name);
8690 }
8691
8692 void
8693 VarExpectationNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
8694 const temporary_terms_t &temporary_terms,
8695 const temporary_terms_idxs_t &temporary_terms_idxs,
8696 const deriv_node_temp_terms_t &tef_terms) const
8697 {
8698 assert(output_type != ExprNodeOutputType::matlabOutsideModel);
8699
8700 if (isLatexOutput(output_type))
8701 {
8702 output << "VAR_EXPECTATION(" << model_name << ')';
8703 return;
8704 }
8705
8706 cerr << "VarExpectationNode::writeOutput not implemented for non-LaTeX." << endl;
8707 exit(EXIT_FAILURE);
8708 }
8709
8710 int
8711 VarExpectationNode::maxEndoLead() const
8712 {
8713 cerr << "VarExpectationNode::maxEndoLead not implemented." << endl;
8714 exit(EXIT_FAILURE);
8715 }
8716
8717 int
8718 VarExpectationNode::maxExoLead() const
8719 {
8720 cerr << "VarExpectationNode::maxExoLead not implemented." << endl;
8721 exit(EXIT_FAILURE);
8722 }
8723
8724 int
8725 VarExpectationNode::maxEndoLag() const
8726 {
8727 cerr << "VarExpectationNode::maxEndoLead not implemented." << endl;
8728 exit(EXIT_FAILURE);
8729 }
8730
8731 int
8732 VarExpectationNode::maxExoLag() const
8733 {
8734 cerr << "VarExpectationNode::maxExoLead not implemented." << endl;
8735 exit(EXIT_FAILURE);
8736 }
8737
8738 int
8739 VarExpectationNode::maxLead() const
8740 {
8741 cerr << "VarExpectationNode::maxLead not implemented." << endl;
8742 exit(EXIT_FAILURE);
8743 }
8744
8745 int
8746 VarExpectationNode::maxLag() const
8747 {
8748 cerr << "VarExpectationNode::maxLag not implemented." << endl;
8749 exit(EXIT_FAILURE);
8750 }
8751
8752 int
8753 VarExpectationNode::maxLagWithDiffsExpanded() const
8754 {
8755 /* This node will be substituted by lagged variables, so in theory we should
8756 return a strictly positive value. But from here this value is not easy to
8757 compute.
8758 We return 0, because currently this function is only called from
8759 DynamicModel::setLeadsLagsOrig(), and the maximum lag will nevertheless be
8760 correctly computed because the maximum lag of the VAR will be taken into
8761 account via the corresponding equations. */
8762 return 0;
8763 }
8764
8765 expr_t
8766 VarExpectationNode::undiff() const
8767 {
8768 cerr << "VarExpectationNode::undiff not implemented." << endl;
8769 exit(EXIT_FAILURE);
8770 }
8771
8772 int
8773 VarExpectationNode::VarMinLag() const
8774 {
8775 cerr << "VarExpectationNode::VarMinLag not implemented." << endl;
8776 exit(EXIT_FAILURE);
8777 }
8778
8779 int
8780 VarExpectationNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
8781 {
8782 cerr << "VarExpectationNode::VarMaxLag not implemented." << endl;
8783 exit(EXIT_FAILURE);
8784 }
8785
8786 int
8787 VarExpectationNode::PacMaxLag(int lhs_symb_id) const
8788 {
8789 cerr << "VarExpectationNode::PacMaxLag not implemented." << endl;
8790 exit(EXIT_FAILURE);
8791 }
8792
8793 int
8794 VarExpectationNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
8795 {
8796 return -1;
8797 }
8798
8799 expr_t
8800 VarExpectationNode::decreaseLeadsLags(int n) const
8801 {
8802 cerr << "VarExpectationNode::decreaseLeadsLags not implemented." << endl;
8803 exit(EXIT_FAILURE);
8804 }
8805
8806 void
8807 VarExpectationNode::prepareForDerivation()
8808 {
8809 cerr << "VarExpectationNode::prepareForDerivation not implemented." << endl;
8810 exit(EXIT_FAILURE);
8811 }
8812
8813 expr_t
8814 VarExpectationNode::computeDerivative(int deriv_id)
8815 {
8816 cerr << "VarExpectationNode::computeDerivative not implemented." << endl;
8817 exit(EXIT_FAILURE);
8818 }
8819
8820 expr_t
8821 VarExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
8822 {
8823 cerr << "VarExpectationNode::getChainRuleDerivative not implemented." << endl;
8824 exit(EXIT_FAILURE);
8825 }
8826
8827 bool
8828 VarExpectationNode::containsExternalFunction() const
8829 {
8830 return false;
8831 }
8832
8833 double
8834 VarExpectationNode::eval(const eval_context_t &eval_context) const noexcept(false)
8835 {
8836 throw EvalException();
8837 }
8838
8839 int
8840 VarExpectationNode::countDiffs() const
8841 {
8842 cerr << "VarExpectationNode::countDiffs not implemented." << endl;
8843 exit(EXIT_FAILURE);
8844 }
8845
8846 void
8847 VarExpectationNode::computeXrefs(EquationInfo &ei) const
8848 {
8849 }
8850
8851 void
8852 VarExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
8853 {
8854 cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
8855 exit(EXIT_FAILURE);
8856 }
8857
8858 void
8859 VarExpectationNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
8860 {
8861 }
8862
8863 void
8864 VarExpectationNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
8865 {
8866 cerr << "VarExpectationNode::collectTemporary_terms not implemented." << endl;
8867 exit(EXIT_FAILURE);
8868 }
8869
8870 void
8871 VarExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_number,
8872 bool lhs_rhs, const temporary_terms_t &temporary_terms,
8873 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
8874 const deriv_node_temp_terms_t &tef_terms) const
8875 {
8876 cerr << "VarExpectationNode::compile not implemented." << endl;
8877 exit(EXIT_FAILURE);
8878 }
8879
8880 pair<int, expr_t>
8881 VarExpectationNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
8882 {
8883 cerr << "VarExpectationNode::normalizeEquation not implemented." << endl;
8884 exit(EXIT_FAILURE);
8885 }
8886
8887 expr_t
8888 VarExpectationNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
8889 {
8890 cerr << "VarExpectationNode::substituteEndoLeadGreaterThanTwo not implemented." << endl;
8891 exit(EXIT_FAILURE);
8892 }
8893
8894 expr_t
8895 VarExpectationNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
8896 {
8897 cerr << "VarExpectationNode::substituteEndoLagGreaterThanTwo not implemented." << endl;
8898 exit(EXIT_FAILURE);
8899 }
8900
8901 expr_t
8902 VarExpectationNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
8903 {
8904 cerr << "VarExpectationNode::substituteExoLead not implemented." << endl;
8905 exit(EXIT_FAILURE);
8906 }
8907
8908 expr_t
8909 VarExpectationNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
8910 {
8911 cerr << "VarExpectationNode::substituteExoLag not implemented." << endl;
8912 exit(EXIT_FAILURE);
8913 }
8914
8915 expr_t
8916 VarExpectationNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
8917 {
8918 return const_cast<VarExpectationNode *>(this);
8919 }
8920
8921 expr_t
8922 VarExpectationNode::substituteAdl() const
8923 {
8924 return const_cast<VarExpectationNode *>(this);
8925 }
8926
8927 expr_t
8928 VarExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
8929 {
8930 auto it = subst_table.find(model_name);
8931 if (it == subst_table.end())
8932 {
8933 cerr << "ERROR: unknown model '" << model_name << "' used in var_expectation expression" << endl;
8934 exit(EXIT_FAILURE);
8935 }
8936 return it->second;
8937 }
8938
8939 void
8940 VarExpectationNode::findDiffNodes(lag_equivalence_table_t &nodes) const
8941 {
8942 }
8943
8944 void
8945 VarExpectationNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
8946 {
8947 }
8948
8949 int
8950 VarExpectationNode::findTargetVariable(int lhs_symb_id) const
8951 {
8952 return -1;
8953 }
8954
8955 expr_t
8956 VarExpectationNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
8957 vector<BinaryOpNode *> &neweqs) const
8958 {
8959 return const_cast<VarExpectationNode *>(this);
8960 }
8961
8962 expr_t
8963 VarExpectationNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
8964 {
8965 return const_cast<VarExpectationNode *>(this);
8966 }
8967
8968 expr_t
8969 VarExpectationNode::substitutePacExpectation(const string &name, expr_t subexpr)
8970 {
8971 return const_cast<VarExpectationNode *>(this);
8972 }
8973
8974 expr_t
8975 VarExpectationNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
8976 {
8977 cerr << "VarExpectationNode::differentiateForwardVars not implemented." << endl;
8978 exit(EXIT_FAILURE);
8979 }
8980
8981 bool
8982 VarExpectationNode::containsPacExpectation(const string &pac_model_name) const
8983 {
8984 return false;
8985 }
8986
8987 bool
8988 VarExpectationNode::containsEndogenous() const
8989 {
8990 cerr << "VarExpectationNode::containsEndogenous not implemented." << endl;
8991 exit(EXIT_FAILURE);
8992 }
8993
8994 bool
8995 VarExpectationNode::containsExogenous() const
8996 {
8997 cerr << "VarExpectationNode::containsExogenous not implemented." << endl;
8998 exit(EXIT_FAILURE);
8999 }
9000
9001 bool
9002 VarExpectationNode::isNumConstNodeEqualTo(double value) const
9003 {
9004 return false;
9005 }
9006
9007 expr_t
9008 VarExpectationNode::decreaseLeadsLagsPredeterminedVariables() const
9009 {
9010 cerr << "VarExpectationNode::decreaseLeadsLagsPredeterminedVariables not implemented." << endl;
9011 exit(EXIT_FAILURE);
9012 }
9013
9014 bool
9015 VarExpectationNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
9016 {
9017 return false;
9018 }
9019
9020 expr_t
9021 VarExpectationNode::replaceTrendVar() const
9022 {
9023 cerr << "VarExpectationNode::replaceTrendVar not implemented." << endl;
9024 exit(EXIT_FAILURE);
9025 }
9026
9027 expr_t
9028 VarExpectationNode::detrend(int symb_id, bool log_trend, expr_t trend) const
9029 {
9030 cerr << "VarExpectationNode::detrend not implemented." << endl;
9031 exit(EXIT_FAILURE);
9032 }
9033
9034 expr_t
9035 VarExpectationNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
9036 {
9037 cerr << "VarExpectationNode::removeTrendLeadLag not implemented." << endl;
9038 exit(EXIT_FAILURE);
9039 }
9040
9041 bool
9042 VarExpectationNode::isInStaticForm() const
9043 {
9044 cerr << "VarExpectationNode::isInStaticForm not implemented." << endl;
9045 exit(EXIT_FAILURE);
9046 }
9047
9048 bool
9049 VarExpectationNode::isVarModelReferenced(const string &model_info_name) const
9050 {
9051 /* TODO: should check here whether the var_expectation_model is equal to the
9052 argument; we probably need a VarModelTable class to do that elegantly */
9053 return false;
9054 }
9055
9056 void
9057 VarExpectationNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
9058 {
9059 }
9060
9061 bool
9062 VarExpectationNode::isParamTimesEndogExpr() const
9063 {
9064 return false;
9065 }
9066
9067 expr_t
9068 VarExpectationNode::substituteStaticAuxiliaryVariable() const
9069 {
9070 return const_cast<VarExpectationNode *>(this);
9071 }
9072
9073 void
9074 VarExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
9075 {
9076 return;
9077 }
9078
9079 expr_t
9080 VarExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
9081 {
9082 return const_cast<VarExpectationNode *>(this);
9083 }
9084
9085 void
9086 VarExpectationNode::writeJsonAST(ostream &output) const
9087 {
9088 output << R"({"node_type" : "VarExpectationNode", )"
9089 << R"("name" : ")" << model_name << R"("})";
9090 }
9091
9092 void
9093 VarExpectationNode::writeJsonOutput(ostream &output,
9094 const temporary_terms_t &temporary_terms,
9095 const deriv_node_temp_terms_t &tef_terms,
9096 bool isdynamic) const
9097 {
9098 output << "var_expectation("
9099 << "model_name = " << model_name
9100 << ")";
9101 }
9102
9103 PacExpectationNode::PacExpectationNode(DataTree &datatree_arg,
9104 int idx_arg,
9105 string model_name_arg) :
9106 ExprNode{datatree_arg, idx_arg},
9107 model_name{move(model_name_arg)}
9108 {
9109 }
9110
9111 void
9112 PacExpectationNode::computeTemporaryTerms(const pair<int, int> &derivOrder,
9113 map<pair<int, int>, temporary_terms_t> &temp_terms_map,
9114 map<expr_t, pair<int, pair<int, int>>> &reference_count,
9115 bool is_matlab) const
9116 {
9117 temp_terms_map[derivOrder].insert(const_cast<PacExpectationNode *>(this));
9118 }
9119
9120 void
9121 PacExpectationNode::computeTemporaryTerms(map<expr_t, int> &reference_count,
9122 temporary_terms_t &temporary_terms,
9123 map<expr_t, pair<int, int>> &first_occurence,
9124 int Curr_block,
9125 vector< vector<temporary_terms_t>> &v_temporary_terms,
9126 int equation) const
9127 {
9128 expr_t this2 = const_cast<PacExpectationNode *>(this);
9129 temporary_terms.insert(this2);
9130 first_occurence[this2] = { Curr_block, equation };
9131 v_temporary_terms[Curr_block][equation].insert(this2);
9132 }
9133
9134 expr_t
9135 PacExpectationNode::toStatic(DataTree &static_datatree) const
9136 {
9137 return static_datatree.AddPacExpectation(string(model_name));
9138 }
9139
9140 expr_t
9141 PacExpectationNode::clone(DataTree &datatree) const
9142 {
9143 return datatree.AddPacExpectation(string(model_name));
9144 }
9145
9146 void
9147 PacExpectationNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
9148 const temporary_terms_t &temporary_terms,
9149 const temporary_terms_idxs_t &temporary_terms_idxs,
9150 const deriv_node_temp_terms_t &tef_terms) const
9151 {
9152 assert(output_type != ExprNodeOutputType::matlabOutsideModel);
9153 if (isLatexOutput(output_type))
9154 {
9155 output << "PAC_EXPECTATION" << LEFT_PAR(output_type) << model_name << RIGHT_PAR(output_type);
9156 return;
9157 }
9158 }
9159
9160 int
9161 PacExpectationNode::maxEndoLead() const
9162 {
9163 return 0;
9164 }
9165
9166 int
9167 PacExpectationNode::maxExoLead() const
9168 {
9169 return 0;
9170 }
9171
9172 int
9173 PacExpectationNode::maxEndoLag() const
9174 {
9175 return 0;
9176 }
9177
9178 int
9179 PacExpectationNode::maxExoLag() const
9180 {
9181 return 0;
9182 }
9183
9184 int
9185 PacExpectationNode::maxLead() const
9186 {
9187 return 0;
9188 }
9189
9190 int
9191 PacExpectationNode::maxLag() const
9192 {
9193 return 0;
9194 }
9195
9196 int
9197 PacExpectationNode::maxLagWithDiffsExpanded() const
9198 {
9199 // Same comment as in VarExpectationNode::maxLagWithDiffsExpanded()
9200 return 0;
9201 }
9202
9203 expr_t
9204 PacExpectationNode::undiff() const
9205 {
9206 return const_cast<PacExpectationNode *>(this);
9207 }
9208
9209 int
9210 PacExpectationNode::VarMinLag() const
9211 {
9212 return 1;
9213 }
9214
9215 int
9216 PacExpectationNode::VarMaxLag(const set<expr_t> &lhs_lag_equiv) const
9217 {
9218 return 0;
9219 }
9220
9221 int
9222 PacExpectationNode::PacMaxLag(int lhs_symb_id) const
9223 {
9224 return 0;
9225 }
9226
9227 int
9228 PacExpectationNode::getPacTargetSymbId(int lhs_symb_id, int undiff_lhs_symb_id) const
9229 {
9230 return -1;
9231 }
9232
9233 expr_t
9234 PacExpectationNode::decreaseLeadsLags(int n) const
9235 {
9236 return const_cast<PacExpectationNode *>(this);
9237 }
9238
9239 void
9240 PacExpectationNode::prepareForDerivation()
9241 {
9242 cerr << "PacExpectationNode::prepareForDerivation: shouldn't arrive here." << endl;
9243 exit(EXIT_FAILURE);
9244 }
9245
9246 expr_t
9247 PacExpectationNode::computeDerivative(int deriv_id)
9248 {
9249 cerr << "PacExpectationNode::computeDerivative: shouldn't arrive here." << endl;
9250 exit(EXIT_FAILURE);
9251 }
9252
9253 expr_t
9254 PacExpectationNode::getChainRuleDerivative(int deriv_id, const map<int, expr_t> &recursive_variables)
9255 {
9256 cerr << "PacExpectationNode::getChainRuleDerivative: shouldn't arrive here." << endl;
9257 exit(EXIT_FAILURE);
9258 }
9259
9260 bool
9261 PacExpectationNode::containsExternalFunction() const
9262 {
9263 return false;
9264 }
9265
9266 double
9267 PacExpectationNode::eval(const eval_context_t &eval_context) const noexcept(false)
9268 {
9269 throw EvalException();
9270 }
9271
9272 void
9273 PacExpectationNode::computeXrefs(EquationInfo &ei) const
9274 {
9275 }
9276
9277 void
9278 PacExpectationNode::collectVARLHSVariable(set<expr_t> &result) const
9279 {
9280 cerr << "ERROR: you can only have variables or unary ops on LHS of VAR" << endl;
9281 exit(EXIT_FAILURE);
9282 }
9283
9284 void
9285 PacExpectationNode::collectDynamicVariables(SymbolType type_arg, set<pair<int, int>> &result) const
9286 {
9287 }
9288
9289 void
9290 PacExpectationNode::collectTemporary_terms(const temporary_terms_t &temporary_terms, temporary_terms_inuse_t &temporary_terms_inuse, int Curr_Block) const
9291 {
9292 if (temporary_terms.find(const_cast<PacExpectationNode *>(this)) != temporary_terms.end())
9293 temporary_terms_inuse.insert(idx);
9294 }
9295
9296 void
9297 PacExpectationNode::compile(ostream &CompileCode, unsigned int &instruction_number,
9298 bool lhs_rhs, const temporary_terms_t &temporary_terms,
9299 const map_idx_t &map_idx, bool dynamic, bool steady_dynamic,
9300 const deriv_node_temp_terms_t &tef_terms) const
9301 {
9302 cerr << "PacExpectationNode::compile not implemented." << endl;
9303 exit(EXIT_FAILURE);
9304 }
9305
9306 int
9307 PacExpectationNode::countDiffs() const
9308 {
9309 return 0;
9310 }
9311
9312 pair<int, expr_t>
9313 PacExpectationNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>> &List_of_Op_RHS) const
9314 {
9315 cerr << "PacExpectationNode::normalizeEquation not implemented." << endl;
9316 exit(EXIT_FAILURE);
9317 }
9318
9319 expr_t
9320 PacExpectationNode::substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
9321 {
9322 return const_cast<PacExpectationNode *>(this);
9323 }
9324
9325 expr_t
9326 PacExpectationNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
9327 {
9328 return const_cast<PacExpectationNode *>(this);
9329 }
9330
9331 expr_t
9332 PacExpectationNode::substituteExoLead(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool deterministic_model) const
9333 {
9334 return const_cast<PacExpectationNode *>(this);
9335 }
9336
9337 expr_t
9338 PacExpectationNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
9339 {
9340 return const_cast<PacExpectationNode *>(this);
9341 }
9342
9343 expr_t
9344 PacExpectationNode::substituteExpectation(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs, bool partial_information_model) const
9345 {
9346 return const_cast<PacExpectationNode *>(this);
9347 }
9348
9349 expr_t
9350 PacExpectationNode::substituteAdl() const
9351 {
9352 return const_cast<PacExpectationNode *>(this);
9353 }
9354
9355 expr_t
9356 PacExpectationNode::substituteVarExpectation(const map<string, expr_t> &subst_table) const
9357 {
9358 return const_cast<PacExpectationNode *>(this);
9359 }
9360
9361 void
9362 PacExpectationNode::findDiffNodes(lag_equivalence_table_t &nodes) const
9363 {
9364 }
9365
9366 void
9367 PacExpectationNode::findUnaryOpNodesForAuxVarCreation(lag_equivalence_table_t &nodes) const
9368 {
9369 }
9370
9371 int
9372 PacExpectationNode::findTargetVariable(int lhs_symb_id) const
9373 {
9374 return -1;
9375 }
9376
9377 expr_t
9378 PacExpectationNode::substituteDiff(const lag_equivalence_table_t &nodes, subst_table_t &subst_table,
9379 vector<BinaryOpNode *> &neweqs) const
9380 {
9381 return const_cast<PacExpectationNode *>(this);
9382 }
9383
9384 expr_t
9385 PacExpectationNode::substituteUnaryOpNodes(const lag_equivalence_table_t &nodes, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
9386 {
9387 return const_cast<PacExpectationNode *>(this);
9388 }
9389
9390 expr_t
9391 PacExpectationNode::differentiateForwardVars(const vector<string> &subset, subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const
9392 {
9393 return const_cast<PacExpectationNode *>(this);
9394 }
9395
9396 bool
9397 PacExpectationNode::containsPacExpectation(const string &pac_model_name) const
9398 {
9399 if (pac_model_name.empty())
9400 return true;
9401 else
9402 return pac_model_name == model_name;
9403 }
9404
9405 bool
9406 PacExpectationNode::containsEndogenous() const
9407 {
9408 return true;
9409 }
9410
9411 bool
9412 PacExpectationNode::containsExogenous() const
9413 {
9414 return false;
9415 }
9416
9417 bool
9418 PacExpectationNode::isNumConstNodeEqualTo(double value) const
9419 {
9420 return false;
9421 }
9422
9423 expr_t
9424 PacExpectationNode::decreaseLeadsLagsPredeterminedVariables() const
9425 {
9426 return const_cast<PacExpectationNode *>(this);
9427 }
9428
9429 bool
9430 PacExpectationNode::isVariableNodeEqualTo(SymbolType type_arg, int variable_id, int lag_arg) const
9431 {
9432 return false;
9433 }
9434
9435 expr_t
9436 PacExpectationNode::replaceTrendVar() const
9437 {
9438 return const_cast<PacExpectationNode *>(this);
9439 }
9440
9441 expr_t
9442 PacExpectationNode::detrend(int symb_id, bool log_trend, expr_t trend) const
9443 {
9444 return const_cast<PacExpectationNode *>(this);
9445 }
9446
9447 expr_t
9448 PacExpectationNode::removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const
9449 {
9450 return const_cast<PacExpectationNode *>(this);
9451 }
9452
9453 bool
9454 PacExpectationNode::isInStaticForm() const
9455 {
9456 return false;
9457 }
9458
9459 bool
9460 PacExpectationNode::isVarModelReferenced(const string &model_info_name) const
9461 {
9462 return model_name == model_info_name;
9463 }
9464
9465 void
9466 PacExpectationNode::getEndosAndMaxLags(map<string, int> &model_endos_and_lags) const
9467 {
9468 }
9469
9470 expr_t
9471 PacExpectationNode::substituteStaticAuxiliaryVariable() const
9472 {
9473 return const_cast<PacExpectationNode *>(this);
9474 }
9475
9476 void
9477 PacExpectationNode::findConstantEquations(map<VariableNode *, NumConstNode *> &table) const
9478 {
9479 return;
9480 }
9481
9482 expr_t
9483 PacExpectationNode::replaceVarsInEquation(map<VariableNode *, NumConstNode *> &table) const
9484 {
9485 return const_cast<PacExpectationNode *>(this);
9486 }
9487
9488 void
9489 PacExpectationNode::writeJsonAST(ostream &output) const
9490 {
9491 output << R"({"node_type" : "PacExpectationNode", )"
9492 << R"("name" : ")" << model_name << R"("})";
9493 }
9494
9495 void
9496 PacExpectationNode::writeJsonOutput(ostream &output,
9497 const temporary_terms_t &temporary_terms,
9498 const deriv_node_temp_terms_t &tef_terms,
9499 bool isdynamic) const
9500 {
9501 output << "pac_expectation("
9502 << "model_name = " << model_name
9503 << ")";
9504 }
9505
9506 bool
9507 PacExpectationNode::isParamTimesEndogExpr() const
9508 {
9509 return false;
9510 }
9511
9512 expr_t
9513 PacExpectationNode::substitutePacExpectation(const string &name, expr_t subexpr)
9514 {
9515 if (model_name != name)
9516 return const_cast<PacExpectationNode *>(this);
9517 return subexpr;
9518 }
9519
9520 void
9521 ExprNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const
9522 {
9523 terms.emplace_back(const_cast<ExprNode *>(this), current_sign);
9524 }
9525
9526 void
9527 UnaryOpNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const
9528 {
9529 if (op_code == UnaryOpcode::uminus)
9530 arg->decomposeAdditiveTerms(terms, -current_sign);
9531 else
9532 ExprNode::decomposeAdditiveTerms(terms, current_sign);
9533 }
9534
9535 void
9536 BinaryOpNode::decomposeAdditiveTerms(vector<pair<expr_t, int>> &terms, int current_sign) const
9537 {
9538 if (op_code == BinaryOpcode::plus || op_code == BinaryOpcode::minus)
9539 {
9540 arg1->decomposeAdditiveTerms(terms, current_sign);
9541 if (op_code == BinaryOpcode::plus)
9542 arg2->decomposeAdditiveTerms(terms, current_sign);
9543 else
9544 arg2->decomposeAdditiveTerms(terms, -current_sign);
9545 }
9546 else
9547 ExprNode::decomposeAdditiveTerms(terms, current_sign);
9548 }
9549
9550 tuple<int, int, int, double>
9551 ExprNode::matchVariableTimesConstantTimesParam(bool variable_obligatory) const
9552 {
9553 int variable_id = -1, lag = 0, param_id = -1;
9554 double constant = 1.0;
9555 matchVTCTPHelper(variable_id, lag, param_id, constant, false);
9556 if (variable_obligatory && variable_id == -1)
9557 throw MatchFailureException{"No variable in this expression"};
9558 return {variable_id, lag, param_id, constant};
9559 }
9560
9561 void
9562 ExprNode::matchVTCTPHelper(int &var_id, int &lag, int ¶m_id, double &constant, bool at_denominator) const
9563 {
9564 throw MatchFailureException{"Expression not allowed in linear combination of variables"};
9565 }
9566
9567 void
9568 NumConstNode::matchVTCTPHelper(int &var_id, int &lag, int ¶m_id, double &constant, bool at_denominator) const
9569 {
9570 double myvalue = eval({});
9571 if (at_denominator)
9572 constant /= myvalue;
9573 else
9574 constant *= myvalue;
9575 }
9576
9577 void
9578 VariableNode::matchVTCTPHelper(int &var_id, int &lag, int ¶m_id, double &constant, bool at_denominator) const
9579 {
9580 if (at_denominator)
9581 throw MatchFailureException{"A variable or parameter cannot appear at denominator"};
9582
9583 SymbolType type = get_type();
9584 if (type == SymbolType::endogenous || type == SymbolType::exogenous)
9585 {
9586 if (var_id != -1)
9587 throw MatchFailureException{"More than one variable in this expression"};
9588 var_id = symb_id;
9589 lag = this->lag;
9590 }
9591 else if (type == SymbolType::parameter)
9592 {
9593 if (param_id != -1)
9594 throw MatchFailureException{"More than one parameter in this expression"};
9595 param_id = symb_id;
9596 }
9597 else
9598 throw MatchFailureException{"Symbol " + datatree.symbol_table.getName(symb_id) + " not allowed here"};
9599 }
9600
9601 void
9602 UnaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int ¶m_id, double &constant, bool at_denominator) const
9603 {
9604 if (op_code == UnaryOpcode::uminus)
9605 {
9606 constant = -constant;
9607 arg->matchVTCTPHelper(var_id, lag, param_id, constant, at_denominator);
9608 }
9609 else
9610 throw MatchFailureException{"Operator not allowed in this expression"};
9611 }
9612
9613 void
9614 BinaryOpNode::matchVTCTPHelper(int &var_id, int &lag, int ¶m_id, double &constant, bool at_denominator) const
9615 {
9616 if (op_code == BinaryOpcode::times || op_code == BinaryOpcode::divide)
9617 {
9618 arg1->matchVTCTPHelper(var_id, lag, param_id, constant, at_denominator);
9619 if (op_code == BinaryOpcode::times)
9620 arg2->matchVTCTPHelper(var_id, lag, param_id, constant, at_denominator);
9621 else
9622 arg2->matchVTCTPHelper(var_id, lag, param_id, constant, !at_denominator);
9623 }
9624 else
9625 throw MatchFailureException{"Operator not allowed in this expression"};
9626 }
9627
9628 vector<tuple<int, int, int, double>>
9629 ExprNode::matchLinearCombinationOfVariables(bool variable_obligatory_in_each_term) const
9630 {
9631 vector<pair<expr_t, int>> terms;
9632 decomposeAdditiveTerms(terms);
9633
9634 vector<tuple<int, int, int, double>> result;
9635
9636 for (const auto &it : terms)
9637 {
9638 expr_t term = it.first;
9639 int sign = it.second;
9640 auto m = term->matchVariableTimesConstantTimesParam(variable_obligatory_in_each_term);
9641 get<3>(m) *= sign;
9642 result.push_back(m);
9643 }
9644 return result;
9645 }
9646
9647 pair<int, vector<tuple<int, int, int, double>>>
9648 ExprNode::matchParamTimesLinearCombinationOfVariables() const
9649 {
9650 auto bopn = dynamic_cast<const BinaryOpNode *>(this);
9651 if (!bopn || bopn->op_code != BinaryOpcode::times)
9652 throw MatchFailureException{"Not a multiplicative expression"};
9653
9654 expr_t param = bopn->arg1, lincomb = bopn->arg2;
9655
9656 auto is_param = [](expr_t e) {
9657 auto vn = dynamic_cast<VariableNode *>(e);
9658 return vn && vn->get_type() == SymbolType::parameter;
9659 };
9660
9661 if (!is_param(param))
9662 {
9663 swap(param, lincomb);
9664 if (!is_param(param))
9665 throw MatchFailureException{"No parameter on either side of the multiplication"};
9666 }
9667
9668 return { dynamic_cast<VariableNode *>(param)->symb_id, lincomb->matchLinearCombinationOfVariables() };
9669 }
9670