1 //
2 //  SourceTargetCodeGen.cpp
3 //  MNNCodegen
4 //
5 //  Created by MNN on 2020/11/27.
6 //
7 
#include "cpu/CPUAst.hpp"
#include <limits>
#include <sstream>
10 
11 using namespace AST;
codegen(SourceTarget * target)12 std::string PrototypeAST::codegen(SourceTarget *target) {
13     std::stringstream ss;
14     ss << target->getIndent();
15     ss << "void " << Name << "(";
16     ss << "float** inputs, float** outputs";
17     ss << ")\n";
18     return ss.str();
19 }
20 
codegen(SourceTarget * target)21 std::string FunctionAST::codegen(SourceTarget* target) {
22     std::stringstream ss;
23     ss << Proto->codegen(target) << "{\n";
24     target->addIndent();
25     ss << Body->codegen(target);
26     target->subIndent();
27     ss << "}\n";
28     return ss.str();
29 }
30 
codegen(SourceTarget * target)31 std::string ListExprAST::codegen(SourceTarget* target) {
32     std::stringstream ss;
33     for (auto& expr : exprs) {
34         ss << expr->codegen(target);
35     }
36     return ss.str();
37 }
38 
codegen(SourceTarget * target)39 std::string VarExprAST::codegen(SourceTarget* target) {
40 }
41 
codegen(SourceTarget * target)42 std::string ForExprAST::codegen(SourceTarget* target) {
43     std::stringstream ss;
44     ss << target->getIndent() << "for (int ";
45     ss << VarName << " = " << Start->codegen(target) << "; ";
46     ss << VarName << " < " << End->codegen(target) << "; ";
47     ss << VarName << " += " << Step->codegen(target) << ") {\n";
48     target->addIndent();
49     ss << Body->codegen(target);
50     target->subIndent();
51     ss << target->getIndent() << "}\n";
52     return ss.str();
53 }
54 
codegen(SourceTarget * target)55 std::string IfExprAST::codegen(SourceTarget* target) {
56 }
57 
codegen(SourceTarget * target)58 std::string CallExprAST::codegen(SourceTarget* target) {
59 }
60 
codegen(SourceTarget * target)61 std::string AssignExprAST::codegen(SourceTarget* target) {
62     std::stringstream ss;
63     ss << target->getIndent() << LHS->codegen(target) << " = " << RHS->codegen(target) << ";\n";
64     return ss.str();
65 }
66 
codegen(SourceTarget * target)67 std::string BinaryExprAST::codegen(SourceTarget *target) {
68     std::stringstream ss;
69     auto l = LHS->codegen(target);
70     auto r = RHS->codegen(target);
71     switch (Op) {
72         case MNN::BinaryOpOperation_ADD:
73             ss << "(" << l << " + " << r << ")";
74             break;
75         case MNN::BinaryOpOperation_SUB:
76             ss << "(" << l << " - " << r << ")";
77             break;
78         case MNN::BinaryOpOperation_MUL:
79             ss << "(" << l << " * " << r << ")";
80             break;
81         case MNN::BinaryOpOperation_DIV:
82         case MNN::BinaryOpOperation_REALDIV:
83             ss << "(" << l << " / " << r << ")";
84             break;
85         case MNN::BinaryOpOperation_FLOORDIV:
86             ss << "floor(" << l << " / " << r << ")";
87             break;
88         case MNN::BinaryOpOperation_POW:
89             ss << "pow(" << l << ", " << r << ")";
90             break;
91         case MNN::BinaryOpOperation_MINIMUM:
92             ss << "fmin(" << l << ", " << r << ")";
93             break;
94         case MNN::BinaryOpOperation_MAXIMUM:
95             ss << "fmax(" << l << ", " << r << ")";
96             break;
97         case MNN::BinaryOpOperation_GREATER:
98             ss << "(" << l << " > " << r << ")";
99             break;
100         case MNN::BinaryOpOperation_GREATER_EQUAL:
101             ss << "(" << l << " >= " << r << ")";
102             break;
103         case MNN::BinaryOpOperation_LESS:
104             ss << "(" << l << " < " << r << ")";
105             break;
106         case MNN::BinaryOpOperation_LESS_EQUAL:
107             ss << "(" << l << " <= " << r << ")";
108             break;
109         case MNN::BinaryOpOperation_EQUAL:
110             ss << "(" << l << " == " << r << ")";
111             break;
112         default:
113             MNN_ASSERT(false);
114     }
115     return ss.str();
116 }
117 
codegen(SourceTarget * target)118 std::string ReluExprAST::codegen(SourceTarget *target) {
119     std::stringstream ss;
120     auto x = Operand->codegen(target);
121     if (maxVal == 0.f) {
122         // slope = minVal
123         // relu(x) = ((x < 0) * slope * x + (x >= 0) * x)
124         ss << "((" << x << " < 0 ) * " << minVal << " * " << x << " + (" << x << " >= 0 ) * " << x << ")";
125     } else {
126         // relu6(x) = min(max(x, minv), maxv)
127         ss << "fmin(fmax(" << x << ", " << minVal << "), " << maxVal << ")";
128     }
129     return ss.str();
130 }
131 
codegen(SourceTarget * target)132 std::string UnaryExprAST::codegen(SourceTarget *target) {
133     std::stringstream ss;
134     auto x = Operand->codegen(target);
135     switch (Op) {
136         case MNN::UnaryOpOperation_ABS:
137             ss << "abs(" << x << ")";
138             break;
139         case MNN::UnaryOpOperation_FLOOR:
140             ss << "floor(" << x << ")";
141             break;
142         case MNN::UnaryOpOperation_CEIL:
143             ss << "ceil(" << x << ")";
144             break;
145         case MNN::UnaryOpOperation_SQRT:
146             ss << "sqrt(" << x << ")";
147             break;
148         case MNN::UnaryOpOperation_EXP:
149             ss << "exp(" << x << ")";
150             break;
151         case MNN::UnaryOpOperation_LOG:
152             ss << "log(" << x << ")";
153             break;
154         case MNN::UnaryOpOperation_SIN:
155             ss << "sin(" << x << ")";
156             break;
157         case MNN::UnaryOpOperation_COS:
158             ss << "cos(" << x << ")";
159             break;
160         case MNN::UnaryOpOperation_ROUND:
161             ss << "round(" << x << ")";
162             break;
163         case MNN::UnaryOpOperation_NEG:
164             ss << "(-" << x << ")";
165             break;
166         case MNN::UnaryOpOperation_SQUARE:
167             ss << "(" << x << " * " << x << ")";
168             break;
169         case MNN::UnaryOpOperation_RSQRT:
170             ss << "(1.f / sqrt(" << x << "))";
171             break;
172         case MNN::UnaryOpOperation_RECIPROCAL:
173             ss << "(1.f / " << x << ")";
174             break;
175         case MNN::UnaryOpOperation_SIGMOID:
176             ss << "(1.f / (1.f + exp(-" << x << ")))";
177             break;
178         case MNN::UnaryOpOperation_TANH:
179             ss << "tanh(" << x << ")";
180             break;
181         default:
182             MNN_ASSERT(false);
183     }
184     return ss.str();
185 }
186 
codegen(SourceTarget * target)187 std::string SubscriptExprAST::codegen(SourceTarget *target) {
188     std::stringstream ss;
189     ss << Base->codegen(target) << "[" << Offset->codegen(target) << "]";
190     return ss.str();
191 }
192 
codegen(SourceTarget * target)193 std::string VariableExprAST::codegen(SourceTarget *target) {
194     std::stringstream ss;
195     ss << Name;
196     return ss.str();
197 }
198 
codegen(SourceTarget * target)199 std::string NumberExprAST::codegen(SourceTarget *target) {
200     std::stringstream ss;
201     switch (mType) {
202         case FP32:
203             ss << mVal.f32Val;
204             break;
205         case FP64:
206             ss << mVal.f64Val;
207             break;
208         case INT32:
209             ss << mVal.i32Val;
210             break;
211         case INT64:
212             ss << mVal.i64Val;
213             break;
214         default:
215             return nullptr;
216     }
217     return ss.str();
218 }
219