import collections
import collections.abc
import numbers
import os
import statement_types
from sig_utils import parse_array, non_differentiable_args, special_arg_values


class CodeGenerator:
    """
    This class generates C++ to test Stan functions.

    Statements are accumulated, in order, in ``code_list`` and rendered to C++
    by ``cpp``. Generated variable names get unique numeric suffixes from a
    per-instance counter (see ``_get_next_name_suffix``).
    """

    def __init__(self):
        # Counter used to build unique variable-name suffixes
        self.name_counter = 0
        # Ordered list of statement_types.CppStatement objects added so far
        self.code_list = []

    def _add_statement(self, statement):
        """
        Add a statement to the code generator

        :param statement: An object of type statement_types.CppStatement
        :return: The statement that was added (so callers can keep a reference to it)
        :raises TypeError: If statement does not inherit from statement_types.CppStatement
        """
        if not isinstance(statement, statement_types.CppStatement):
            raise TypeError("Argument to CodeGenerator._add_statement must be an instance of an object that inherits from CppStatement")

        self.code_list.append(statement)
        return statement

    def _get_next_name_suffix(self):
        """Return the next unused variable-name suffix ("0", "1", ...) and advance the counter"""
        suffix = repr(self.name_counter)
        self.name_counter += 1
        return suffix

    def cpp(self):
        """Generate and return the c++ code corresponding to the list of statements in the code generator"""
        return os.linesep.join(statement.cpp() for statement in self.code_list)

    def build_arguments(self, signature_parser, arg_overloads, size):
        """
        Generate argument variables for each of the arguments in the given signature_parser
        with the given overloads in arg_overloads and with the given size

        :param signature_parser: An instance of SignatureParser
        :param arg_overloads: A list of argument overloads (Prim/Fwd/Rev/etc.) as strings
        :param size: Size of matrix-like arguments. This is not used for array arguments (which will effectively all be size 1)
        :return: List of the generated argument statements (plus a trailing rng variable for _rng functions)
        :raises Exception: If an argument's inner type is not supported
        """
        arg_list = []
        function_name = signature_parser.function_name
        # Argument positions of this function that must never be autodiff types
        non_differentiable = non_differentiable_args.get(function_name, [])

        for n, (overload, stan_arg) in enumerate(zip(arg_overloads, signature_parser.stan_args)):
            suffix = self._get_next_name_suffix()

            number_nested_arrays, inner_type = parse_array(stan_arg)

            # Check if argument is differentiable
            if inner_type == "int" or n in non_differentiable:
                overload = "Prim"

            # By default the variable value is None and a default will be substituted
            value = None

            # Check for special arguments (constrained variables or types):
            # a string entry replaces the argument's type, any other non-None
            # entry is used as the argument's value
            try:
                special_arg = special_arg_values[function_name][n]
            except KeyError:
                special_arg = None
            if isinstance(special_arg, str):
                inner_type = special_arg
            elif special_arg is not None:
                value = special_arg

            # The first case here is used for the array initializers in sig_utils.special_arg_values
            # Everything else uses the second case
            if number_nested_arrays > 0 and isinstance(value, collections.abc.Iterable):
                arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size = 1, value = value)
            else:
                arg = self._make_element_variable(overload, inner_type, suffix, size, value, arg_list)

                if number_nested_arrays > 0:
                    # Declare the element variable first, then wrap it in an array variable
                    self._add_statement(arg)
                    arg = statement_types.ArrayVariable(overload, "array" + suffix, number_nested_arrays, inner_type, size = 1, value = arg)

            arg_list.append(self._add_statement(arg))

        if signature_parser.is_rng():
            arg_list.append(self._add_statement(statement_types.RngVariable("rng" + self._get_next_name_suffix())))

        return arg_list

    def _make_element_variable(self, overload, inner_type, suffix, size, value, arg_list):
        """
        Create (without adding) the non-array variable statement for one argument of the given inner type

        :param overload: Argument overload (Prim/Fwd/Rev/etc.) as a string
        :param inner_type: Stan type of the argument with any array dimensions stripped
        :param suffix: Unique name suffix for the variable
        :param size: Size of matrix-like arguments
        :param value: Explicit value for the variable, or None to use the default
        :param arg_list: Arguments generated so far (used by "scalar_return_type")
        :raises Exception: If inner_type is not supported
        """
        if inner_type == "int":
            return statement_types.IntVariable("int" + suffix, value)
        if inner_type == "real":
            return statement_types.RealVariable(overload, "real" + suffix, value)
        if inner_type in ("vector", "row_vector", "matrix"):
            return statement_types.MatrixVariable(overload, "matrix" + suffix, inner_type, size, value)
        if inner_type == "rng":
            return statement_types.RngVariable("rng" + suffix)
        if inner_type == "ostream_ptr":
            return statement_types.OStreamVariable("ostream" + suffix)
        if inner_type == "scalar_return_type":
            # Built from every argument generated before this one
            return statement_types.ReturnTypeTVariable("ret_type" + suffix, *arg_list)
        if inner_type == "simplex":
            return statement_types.SimplexVariable(overload, "simplex" + suffix, size, value)
        if inner_type == "positive_definite_matrix":
            return statement_types.PositiveDefiniteMatrixVariable(overload, "positive_definite_matrix" + suffix, size, value)
        if inner_type == "(vector, vector, data array[] real, data array[] int) => vector":
            return statement_types.AlgebraSolverFunctorVariable("functor" + suffix)
        if inner_type == "(real, vector, ostream_ptr, vector) => vector":
            return statement_types.OdeFunctorVariable("functor" + suffix)
        raise Exception("Inner type " + inner_type + " not supported")

    def add(self, arg1, arg2):
        """
        Generate code for arg1 + arg2

        :param arg1: First argument
        :param arg2: Second argument
        """
        return self._add_statement(statement_types.FunctionCall("stan::math::add", "sum_of_sums" + self._get_next_name_suffix(), arg1, arg2))

    def convert_to_expression(self, arg, size = None):
        """
        Generate code to convert arg to an expression type of given size. If size is None, use the argument size

        :param arg: Argument to convert to expression
        :param size: Size of the resulting expression, or None to use the argument size
        """
        return self._add_statement(statement_types.ExpressionVariable(arg.name + "_expr" + self._get_next_name_suffix(), arg, size))

    def expect_adj_eq(self, arg1, arg2):
        """
        Generate code that checks that the adjoints of arg1 and arg2 are equal

        :param arg1: First argument
        :param arg2: Second argument
        """
        return self._add_statement(statement_types.FunctionCall("stan::test::expect_adj_eq", None, arg1, arg2))

    def expect_eq(self, arg1, arg2):
        """
        Generate code that checks that values of arg1 and arg2 are equal

        :param arg1: First argument
        :param arg2: Second argument
        """
        return self._add_statement(statement_types.FunctionCall("EXPECT_STAN_EQ", None, arg1, arg2))

    def expect_leq_one(self, arg):
        """
        Generate code to check that arg is less than or equal to one

        :param arg: Argument to check
        """
        # Materialize the constant 1 as an int variable so it can be passed to EXPECT_LE
        one = self._add_statement(statement_types.IntVariable("int" + self._get_next_name_suffix(), 1))
        return self._add_statement(statement_types.FunctionCall("EXPECT_LE", None, arg, one))

    def function_call_assign(self, cpp_function_name, *args):
        """
        Generate code to call the c++ function given by cpp_function_name with given args and assign the result to another variable

        :param cpp_function_name: c++ function name to call
        :param args: list of arguments to pass to function
        """
        return self._add_statement(statement_types.FunctionCall(cpp_function_name, "result" + self._get_next_name_suffix(), *args))

    def grad(self, arg):
        """
        Generate code to call stan::test::grad(arg) (equivalent of arg.grad())

        :param arg: Argument to call grad on
        """
        return self._add_statement(statement_types.FunctionCall("stan::test::grad", None, arg))

    def recover_memory(self):
        """Generate code to call stan::math::recover_memory()"""
        return self._add_statement(statement_types.FunctionCall("stan::math::recover_memory", None))

    def recursive_sum(self, arg):
        """
        Generate code that repeatedly sums arg until all that is left is a scalar

        :param arg: Argument to sum
        """
        return self._add_statement(statement_types.FunctionCall("stan::test::recursive_sum", "summed_result" + self._get_next_name_suffix(), arg))

    def to_var_value(self, arg):
        """
        Generate code to convert arg to a varmat

        :param arg: Argument to convert to varmat
        """
        return self._add_statement(statement_types.FunctionCall("stan::math::to_var_value", arg.name + "_varmat" + self._get_next_name_suffix(), arg))