1 // 2 // SourceTargetCodeGen.cpp 3 // MNNCodegen 4 // 5 // Created by MNN on 2020/11/27. 6 // 7 8 #include "cpu/CPUAst.hpp" 9 #include <sstream> 10 11 using namespace AST; codegen(SourceTarget * target)12std::string PrototypeAST::codegen(SourceTarget *target) { 13 std::stringstream ss; 14 ss << target->getIndent(); 15 ss << "void " << Name << "("; 16 ss << "float** inputs, float** outputs"; 17 ss << ")\n"; 18 return ss.str(); 19 } 20 codegen(SourceTarget * target)21std::string FunctionAST::codegen(SourceTarget* target) { 22 std::stringstream ss; 23 ss << Proto->codegen(target) << "{\n"; 24 target->addIndent(); 25 ss << Body->codegen(target); 26 target->subIndent(); 27 ss << "}\n"; 28 return ss.str(); 29 } 30 codegen(SourceTarget * target)31std::string ListExprAST::codegen(SourceTarget* target) { 32 std::stringstream ss; 33 for (auto& expr : exprs) { 34 ss << expr->codegen(target); 35 } 36 return ss.str(); 37 } 38 codegen(SourceTarget * target)39std::string VarExprAST::codegen(SourceTarget* target) { 40 } 41 codegen(SourceTarget * target)42std::string ForExprAST::codegen(SourceTarget* target) { 43 std::stringstream ss; 44 ss << target->getIndent() << "for (int "; 45 ss << VarName << " = " << Start->codegen(target) << "; "; 46 ss << VarName << " < " << End->codegen(target) << "; "; 47 ss << VarName << " += " << Step->codegen(target) << ") {\n"; 48 target->addIndent(); 49 ss << Body->codegen(target); 50 target->subIndent(); 51 ss << target->getIndent() << "}\n"; 52 return ss.str(); 53 } 54 codegen(SourceTarget * target)55std::string IfExprAST::codegen(SourceTarget* target) { 56 } 57 codegen(SourceTarget * target)58std::string CallExprAST::codegen(SourceTarget* target) { 59 } 60 codegen(SourceTarget * target)61std::string AssignExprAST::codegen(SourceTarget* target) { 62 std::stringstream ss; 63 ss << target->getIndent() << LHS->codegen(target) << " = " << RHS->codegen(target) << ";\n"; 64 return ss.str(); 65 } 66 codegen(SourceTarget * target)67std::string BinaryExprAST::codegen(SourceTarget *target) { 68 std::stringstream ss; 69 auto l = LHS->codegen(target); 70 auto r = RHS->codegen(target); 71 switch (Op) { 72 case MNN::BinaryOpOperation_ADD: 73 ss << "(" << l << " + " << r << ")"; 74 break; 75 case MNN::BinaryOpOperation_SUB: 76 ss << "(" << l << " - " << r << ")"; 77 break; 78 case MNN::BinaryOpOperation_MUL: 79 ss << "(" << l << " * " << r << ")"; 80 break; 81 case MNN::BinaryOpOperation_DIV: 82 case MNN::BinaryOpOperation_REALDIV: 83 ss << "(" << l << " / " << r << ")"; 84 break; 85 case MNN::BinaryOpOperation_FLOORDIV: 86 ss << "floor(" << l << " / " << r << ")"; 87 break; 88 case MNN::BinaryOpOperation_POW: 89 ss << "pow(" << l << ", " << r << ")"; 90 break; 91 case MNN::BinaryOpOperation_MINIMUM: 92 ss << "fmin(" << l << ", " << r << ")"; 93 break; 94 case MNN::BinaryOpOperation_MAXIMUM: 95 ss << "fmax(" << l << ", " << r << ")"; 96 break; 97 case MNN::BinaryOpOperation_GREATER: 98 ss << "(" << l << " > " << r << ")"; 99 break; 100 case MNN::BinaryOpOperation_GREATER_EQUAL: 101 ss << "(" << l << " >= " << r << ")"; 102 break; 103 case MNN::BinaryOpOperation_LESS: 104 ss << "(" << l << " < " << r << ")"; 105 break; 106 case MNN::BinaryOpOperation_LESS_EQUAL: 107 ss << "(" << l << " <= " << r << ")"; 108 break; 109 case MNN::BinaryOpOperation_EQUAL: 110 ss << "(" << l << " == " << r << ")"; 111 break; 112 default: 113 MNN_ASSERT(false); 114 } 115 return ss.str(); 116 } 117 codegen(SourceTarget * target)118std::string ReluExprAST::codegen(SourceTarget *target) { 119 std::stringstream ss; 120 auto x = Operand->codegen(target); 121 if (maxVal == 0.f) { 122 // slope = minVal 123 // relu(x) = ((x < 0) * slope * x + (x >= 0) * x) 124 ss << "((" << x << " < 0 ) * " << minVal << " * " << x << " + (" << x << " >= 0 ) * " << x << ")"; 125 } else { 126 // relu6(x) = min(max(x, minv), maxv) 127 ss << "fmin(fmax(" << x << ", " << minVal << "), " << maxVal << ")"; 128 } 129 return ss.str(); 130 } 131 codegen(SourceTarget * target)132std::string UnaryExprAST::codegen(SourceTarget *target) { 133 std::stringstream ss; 134 auto x = Operand->codegen(target); 135 switch (Op) { 136 case MNN::UnaryOpOperation_ABS: 137 ss << "abs(" << x << ")"; 138 break; 139 case MNN::UnaryOpOperation_FLOOR: 140 ss << "floor(" << x << ")"; 141 break; 142 case MNN::UnaryOpOperation_CEIL: 143 ss << "ceil(" << x << ")"; 144 break; 145 case MNN::UnaryOpOperation_SQRT: 146 ss << "sqrt(" << x << ")"; 147 break; 148 case MNN::UnaryOpOperation_EXP: 149 ss << "exp(" << x << ")"; 150 break; 151 case MNN::UnaryOpOperation_LOG: 152 ss << "log(" << x << ")"; 153 break; 154 case MNN::UnaryOpOperation_SIN: 155 ss << "sin(" << x << ")"; 156 break; 157 case MNN::UnaryOpOperation_COS: 158 ss << "cos(" << x << ")"; 159 break; 160 case MNN::UnaryOpOperation_ROUND: 161 ss << "round(" << x << ")"; 162 break; 163 case MNN::UnaryOpOperation_NEG: 164 ss << "(-" << x << ")"; 165 break; 166 case MNN::UnaryOpOperation_SQUARE: 167 ss << "(" << x << " * " << x << ")"; 168 break; 169 case MNN::UnaryOpOperation_RSQRT: 170 ss << "(1.f / sqrt(" << x << "))"; 171 break; 172 case MNN::UnaryOpOperation_RECIPROCAL: 173 ss << "(1.f / " << x << ")"; 174 break; 175 case MNN::UnaryOpOperation_SIGMOID: 176 ss << "(1.f / (1.f + exp(-" << x << ")))"; 177 break; 178 case MNN::UnaryOpOperation_TANH: 179 ss << "tanh(" << x << ")"; 180 break; 181 default: 182 MNN_ASSERT(false); 183 } 184 return ss.str(); 185 } 186 codegen(SourceTarget * target)187std::string SubscriptExprAST::codegen(SourceTarget *target) { 188 std::stringstream ss; 189 ss << Base->codegen(target) << "[" << Offset->codegen(target) << "]"; 190 return ss.str(); 191 } 192 codegen(SourceTarget * target)193std::string VariableExprAST::codegen(SourceTarget *target) { 194 std::stringstream ss; 195 ss << Name; 196 return ss.str(); 197 } 198 codegen(SourceTarget * target)199std::string NumberExprAST::codegen(SourceTarget *target) { 200 std::stringstream ss; 201 switch (mType) { 202 case FP32: 203 ss << mVal.f32Val; 204 break; 205 case FP64: 206 ss << mVal.f64Val; 207 break; 208 case INT32: 209 ss << mVal.i32Val; 210 break; 211 case INT64: 212 ss << mVal.i64Val; 213 break; 214 default: 215 return nullptr; 216 } 217 return ss.str(); 218 } 219