1import collections
2import numbers
3import os
4import statement_types
5from sig_utils import parse_array, non_differentiable_args, special_arg_values
6
class CodeGenerator:
    """
    This class generates C++ to test Stan functions.

    Helper methods append statement_types.CppStatement objects to ``code_list``
    in call order; :meth:`cpp` renders the accumulated statements as C++ source.
    """
    def __init__(self):
        # Counter used to make generated variable names unique ("0", "1", ...)
        self.name_counter = 0
        # Ordered list of statements emitted so far
        self.code_list = []

    def _add_statement(self, statement):
        """
        Add a statement to the code generator

        :param statement: An object of type statement_types.CppStatement
        :returns: The same statement, so it can be passed as an argument to later statements
        :raises TypeError: If statement does not inherit from statement_types.CppStatement
        """
        if not isinstance(statement, statement_types.CppStatement):
            raise TypeError("Argument to CodeGenerator._add_statement must be an instance of an object that inherits from CppStatement")

        self.code_list.append(statement)
        return statement

    def _get_next_name_suffix(self):
        """
        Get the next available variable-name suffix and advance the counter

        :returns: The current counter value as a string ("0", then "1", ...)
        """
        suffix = repr(self.name_counter)
        self.name_counter += 1
        return suffix

    def cpp(self):
        """Generate and return the c++ code corresponding to the list of statements in the code generator, joined with os.linesep"""
        return os.linesep.join(statement.cpp() for statement in self.code_list)

    @staticmethod
    def _lookup_special_arg(function_name, n):
        """
        Look up the entry for argument n of function_name in sig_utils.special_arg_values

        :param function_name: Stan function name
        :param n: Zero-based argument position
        :returns: The special entry (a type name string or an initial value), or None if there is none
        """
        try:
            return special_arg_values[function_name][n]
        except (KeyError, IndexError):
            # Function has no special arguments, or its list does not cover position n
            return None

    @staticmethod
    def _build_scalar_argument(overload, inner_type, suffix, size, value, previous_args):
        """
        Create (without adding it) the non-array variable statement for one argument

        :param overload: Overload (Prim/Fwd/Rev/etc.) as a string
        :param inner_type: Stan type name of the argument with any array nesting removed
        :param suffix: Unique suffix for the variable name
        :param size: Size of matrix-like arguments
        :param value: Initial value, or None to use the statement's default
        :param previous_args: Argument statements already built for this signature
        :returns: The new statement
        :raises Exception: If inner_type is not supported
        """
        if inner_type == "int":
            return statement_types.IntVariable("int" + suffix, value)
        if inner_type == "real":
            return statement_types.RealVariable(overload, "real" + suffix, value)
        if inner_type in ("vector", "row_vector", "matrix"):
            return statement_types.MatrixVariable(overload, "matrix" + suffix, inner_type, size, value)
        if inner_type == "rng":
            return statement_types.RngVariable("rng" + suffix)
        if inner_type == "ostream_ptr":
            return statement_types.OStreamVariable("ostream" + suffix)
        if inner_type == "scalar_return_type":
            # The return type is computed from all of the arguments that precede it
            return statement_types.ReturnTypeTVariable("ret_type" + suffix, *previous_args)
        if inner_type == "simplex":
            return statement_types.SimplexVariable(overload, "simplex" + suffix, size, value)
        if inner_type == "positive_definite_matrix":
            return statement_types.PositiveDefiniteMatrixVariable(overload, "positive_definite_matrix" + suffix, size, value)
        if inner_type == "(vector, vector, data array[] real, data array[] int) => vector":
            return statement_types.AlgebraSolverFunctorVariable("functor" + suffix)
        if inner_type == "(real, vector, ostream_ptr, vector) => vector":
            return statement_types.OdeFunctorVariable("functor" + suffix)
        raise Exception("Inner type " + inner_type + " not supported")

    def build_arguments(self, signature_parser, arg_overloads, size):
        """
        Generate argument variables for each of the arguments in the given signature_parser
        with the given overloads in arg_overloads and with the given size

        :param signature_parser: An instance of SignatureParser
        :param arg_overloads: A list of argument overloads (Prim/Fwd/Rev/etc.) as strings
        :param size: Size of matrix-like arguments. This is not used for array arguments (which will effectively all be size 1)
        :returns: List of the generated argument statements in signature order, plus a trailing rng argument if signature_parser.is_rng()
        :raises Exception: If an argument's inner type is not supported
        """
        # collections.Iterable was removed in Python 3.10; the ABC lives in collections.abc
        from collections.abc import Iterable

        function_name = signature_parser.function_name
        non_differentiable = non_differentiable_args.get(function_name, [])

        arg_list = []
        for n, (overload, stan_arg) in enumerate(zip(arg_overloads, signature_parser.stan_args)):
            suffix = self._get_next_name_suffix()

            number_nested_arrays, inner_type = parse_array(stan_arg)

            # Integer and explicitly non-differentiable arguments are always Prim
            if inner_type == "int" or n in non_differentiable:
                overload = "Prim"

            # By default the variable value is None and a default will be substituted
            value = None

            # Special arguments either replace the type (str) or supply a value
            special_arg = self._lookup_special_arg(function_name, n)
            if isinstance(special_arg, str):
                inner_type = special_arg
            elif special_arg is not None:
                value = special_arg

            if number_nested_arrays > 0 and isinstance(value, Iterable):
                # Array initializers supplied through sig_utils.special_arg_values
                arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size=1, value=value)
            else:
                arg = self._build_scalar_argument(overload, inner_type, suffix, size, value, arg_list)
                if number_nested_arrays > 0:
                    # Emit the scalar first, then pass it as the array's value
                    self._add_statement(arg)
                    arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size=1, value=arg)

            arg_list.append(self._add_statement(arg))

        if signature_parser.is_rng():
            arg_list.append(self._add_statement(statement_types.RngVariable("rng" + self._get_next_name_suffix())))

        return arg_list

    def add(self, arg1, arg2):
        """
        Generate code for arg1 + arg2

        :param arg1: First argument
        :param arg2: Second argument
        :returns: The statement holding the sum
        """
        return self._add_statement(statement_types.FunctionCall("stan::math::add", "sum_of_sums" + self._get_next_name_suffix(), arg1, arg2))

    def convert_to_expression(self, arg, size = None):
        """
        Generate code to convert arg to an expression type of given size. If size is None, use the argument size

        :param arg: Argument to convert to expression
        :param size: Size of the expression, or None to use the argument size
        :returns: The expression variable statement
        """
        return self._add_statement(statement_types.ExpressionVariable(arg.name + "_expr" + self._get_next_name_suffix(), arg, size))

    def expect_adj_eq(self, arg1, arg2):
        """
        Generate code that checks that the adjoints of arg1 and arg2 are equal

        :param arg1: First argument
        :param arg2: Second argument
        """
        return self._add_statement(statement_types.FunctionCall("stan::test::expect_adj_eq", None, arg1, arg2))

    def expect_eq(self, arg1, arg2):
        """
        Generate code that checks that values of arg1 and arg2 are equal

        :param arg1: First argument
        :param arg2: Second argument
        """
        return self._add_statement(statement_types.FunctionCall("EXPECT_STAN_EQ", None, arg1, arg2))

    def expect_leq_one(self, arg):
        """
        Generate code to check that arg is less than or equal to one

        :param arg: Argument to check
        """
        # The constant 1 is emitted as its own int variable so it can be passed to EXPECT_LE
        one = self._add_statement(statement_types.IntVariable("int" + self._get_next_name_suffix(), 1))
        return self._add_statement(statement_types.FunctionCall("EXPECT_LE", None, arg, one))

    def function_call_assign(self, cpp_function_name, *args):
        """
        Generate code to call the c++ function given by cpp_function_name with given args and assign the result to another variable

        :param cpp_function_name: c++ function name to call
        :param args: list of arguments to pass to function
        :returns: The statement holding the result
        """
        return self._add_statement(statement_types.FunctionCall(cpp_function_name, "result" + self._get_next_name_suffix(), *args))

    def grad(self, arg):
        """
        Generate code to call stan::test::grad(arg) (equivalent of arg.grad())

        :param arg: Argument to call grad on
        """
        return self._add_statement(statement_types.FunctionCall("stan::test::grad", None, arg))

    def recover_memory(self):
        """Generate code to call stan::math::recover_memory()"""
        return self._add_statement(statement_types.FunctionCall("stan::math::recover_memory", None))

    def recursive_sum(self, arg):
        """
        Generate code that repeatedly sums arg until all that is left is a scalar

        :param arg: Argument to sum
        :returns: The statement holding the summed result
        """
        return self._add_statement(statement_types.FunctionCall("stan::test::recursive_sum", "summed_result" + self._get_next_name_suffix(), arg))

    def to_var_value(self, arg):
        """
        Generate code to convert arg to a varmat

        :param arg: Argument to convert to varmat
        :returns: The statement holding the converted value
        """
        return self._add_statement(statement_types.FunctionCall("stan::math::to_var_value", arg.name + "_varmat" + self._get_next_name_suffix(), arg))
187