1 #ifndef STAN_LANG_GENERATOR_GENERATE_WRITE_ARRAY_METHOD_HPP
2 #define STAN_LANG_GENERATOR_GENERATE_WRITE_ARRAY_METHOD_HPP
3 
4 #include <stan/lang/ast.hpp>
5 #include <stan/lang/generator/constants.hpp>
6 #include <stan/lang/generator/generate_block_var.hpp>
7 #include <stan/lang/generator/generate_catch_throw_located.hpp>
8 #include <stan/lang/generator/generate_comment.hpp>
9 #include <stan/lang/generator/generate_indent.hpp>
10 #include <stan/lang/generator/generate_read_transform_params.hpp>
11 #include <stan/lang/generator/generate_statements.hpp>
12 #include <stan/lang/generator/generate_try.hpp>
13 #include <stan/lang/generator/generate_validate_block_var.hpp>
14 #include <stan/lang/generator/generate_write_block_var.hpp>
15 #include <stan/lang/generator/generate_void_statement.hpp>
16 
17 #include <ostream>
18 #include <string>
19 
20 namespace stan {
21   namespace lang {
22 
23     /**
24      * Generate the <code>write_array</code> method for the specified
25      * program, with specified model name to the specified stream.
26      *
27      * @param[in] prog program from which to generate
28      * @param[in] model_name name of model
29      * @param[in,out] o stream for generating
30      */
generate_write_array_method(const program & prog,const std::string & model_name,std::ostream & o)31     void generate_write_array_method(const program& prog,
32                                      const std::string& model_name,
33                                      std::ostream& o) {
34       o << INDENT << "template <typename RNG>" << EOL;
35       o << INDENT << "void write_array(RNG& base_rng__," << EOL;
36       o << INDENT << "                 std::vector<double>& params_r__," << EOL;
37       o << INDENT << "                 std::vector<int>& params_i__," << EOL;
38       o << INDENT << "                 std::vector<double>& vars__," << EOL;
39       o << INDENT << "                 bool include_tparams__ = true," << EOL;
40       o << INDENT << "                 bool include_gqs__ = true," << EOL;
41       o << INDENT
42         << "                 std::ostream* pstream__ = 0) const {" << EOL;
43       o << INDENT2 << "typedef double local_scalar_t__;" << EOL2;
44 
45       o << INDENT2 << "vars__.resize(0);" << EOL;
46       o << INDENT2
47         << "stan::io::reader<local_scalar_t__> in__(params_r__, params_i__);"
48         << EOL;
49       o << INDENT2 << "static const char* function__ = \""
50         << model_name << "_namespace::write_array\";" << EOL;
51       generate_void_statement("function__", 2, o);
52       o << EOL;
53 
54       generate_comment("read-transform, write parameters", 2, o);
55       generate_read_transform_params(prog.parameter_decl_, 2, o);
56 
57       o << INDENT2 <<  "double lp__ = 0.0;" << EOL;
58       generate_void_statement("lp__", 2, o);
59       o << INDENT2 << "stan::math::accumulator<double> lp_accum__;" << EOL2;
60       o << INDENT2
61         << "local_scalar_t__ DUMMY_VAR__"
62         << "(std::numeric_limits<double>::quiet_NaN());"
63         << EOL;
64       o << INDENT2 << "(void) DUMMY_VAR__;  // suppress unused var warning"
65         << EOL2;
66       o << INDENT2 << "if (!include_tparams__ && !include_gqs__) return;"
67         << EOL2;
68 
69       generate_try(2, o);
70       if (prog.derived_decl_.first.size() > 0) {
71         generate_comment("declare and define transformed parameters", 3, o);
72         for (size_t i = 0; i < prog.derived_decl_.first.size(); ++i) {
73           generate_indent(3, o);
74           o << "current_statement_begin__ = "
75             <<  prog.derived_decl_.first[i].begin_line_ << ";"
76             << EOL;
77           generate_block_var(prog.derived_decl_.first[i], "double", 3, o);
78           o << EOL;
79         }
80       }
81 
82       if (prog.derived_decl_.second.size() > 0) {
83         generate_comment("do transformed parameters statements", 3, o);
84         generate_statements(prog.derived_decl_.second, 3, o);
85         o << EOL;
86       }
87 
88       o << INDENT3 << "if (!include_gqs__ && !include_tparams__) return;"
89         << EOL;
90 
91       if (prog.derived_decl_.first.size() > 0) {
92         generate_comment("validate transformed parameters", 3, o);
93         o << INDENT3
94           << "const char* function__ = \"validate transformed params\";"
95           << EOL;
96         o << INDENT3
97           << "(void) function__;  // dummy to suppress unused var warning"
98           << EOL;
99         o << EOL;
100 
101         for (size_t i = 0; i < prog.derived_decl_.first.size(); ++i) {
102           block_var_decl bvd = prog.derived_decl_.first[i];
103           if (bvd.type().innermost_type().is_constrained()) {
104             generate_indent(3, o);
105             o << "current_statement_begin__ = "
106               <<  bvd.begin_line_ << ";" << EOL;
107             generate_validate_block_var(bvd, 3, o);
108           }
109         }
110 
111         generate_comment("write transformed parameters", 3, o);
112         o << INDENT3 << "if (include_tparams__) {" << EOL;
113         for (size_t i = 0; i < prog.derived_decl_.first.size(); ++i) {
114           generate_write_block_var(prog.derived_decl_.first[i], 4, o);
115         }
116         o << INDENT3 << "}" << EOL;
117       }
118 
119       o << INDENT3 << "if (!include_gqs__) return;"
120         << EOL;
121       if (prog.generated_decl_.first.size() > 0) {
122         generate_comment("declare and define generated quantities", 3, o);
123         for (size_t i = 0; i < prog.generated_decl_.first.size(); ++i) {
124           generate_indent(3, o);
125           o << "current_statement_begin__ = "
126             <<  prog.generated_decl_.first[i].begin_line_ << ";"
127             << EOL;
128           generate_block_var(prog.generated_decl_.first[i], "double", 3, o);
129           o << EOL;
130         }
131       }
132 
133       if (prog.generated_decl_.second.size() > 0) {
134         generate_comment("generated quantities statements", 3, o);
135         generate_statements(prog.generated_decl_.second, 3, o);
136         o << EOL;
137       }
138 
139       if (prog.generated_decl_.first.size() > 0) {
140         generate_comment("validate, write generated quantities", 3, o);
141         for (size_t i = 0; i < prog.generated_decl_.first.size(); ++i) {
142           generate_indent(3, o);
143           o << "current_statement_begin__ = "
144             <<  prog.generated_decl_.first[i].begin_line_ << ";"
145             << EOL;
146           generate_validate_block_var(prog.generated_decl_.first[i], 3, o);
147           generate_write_block_var(prog.generated_decl_.first[i], 3, o);
148           o << EOL;
149         }
150       }
151       generate_catch_throw_located(2, o);
152 
153       o << INDENT << "}" << EOL2;
154 
155       o << INDENT << "template <typename RNG>" << EOL;
156       o << INDENT << "void write_array(RNG& base_rng," << EOL;
157       o << INDENT
158         << "                 Eigen::Matrix<double,Eigen::Dynamic,1>& params_r,"
159         << EOL;
160       o << INDENT
161         << "                 Eigen::Matrix<double,Eigen::Dynamic,1>& vars,"
162         << EOL;
163       o << INDENT << "                 bool include_tparams = true," << EOL;
164       o << INDENT << "                 bool include_gqs = true," << EOL;
165       o << INDENT
166         << "                 std::ostream* pstream = 0) const {" << EOL;
167       o << INDENT
168         << "  std::vector<double> params_r_vec(params_r.size());" << EOL;
169       o << INDENT << "  for (int i = 0; i < params_r.size(); ++i)" << EOL;
170       o << INDENT << "    params_r_vec[i] = params_r(i);" << EOL;
171       o << INDENT << "  std::vector<double> vars_vec;" << EOL;
172       o << INDENT << "  std::vector<int> params_i_vec;" << EOL;
173       o << INDENT
174         << "  write_array(base_rng, params_r_vec, params_i_vec, "
175         << "vars_vec, include_tparams, include_gqs, pstream);" << EOL;
176       o << INDENT << "  vars.resize(vars_vec.size());" << EOL;
177       o << INDENT << "  for (int i = 0; i < vars.size(); ++i)" << EOL;
178       o << INDENT << "    vars(i) = vars_vec[i];" << EOL;
179       o << INDENT << "}" << EOL2;
180     }
181 
182   }
183 }
184 #endif
185