1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file codegen_opengl.cc
22  *
 * We target OpenGL 3.3 rather than a more recent version of OpenGL in order
 * to stay compatible with WebGL 2.
25  */
26 #include <tvm/packed_func_ext.h>
27 #include <vector>
28 #include <string>
29 #include <utility>
30 #include <unordered_map>
31 #include "codegen_opengl.h"
32 #include "build_common.h"
33 #include "../runtime/thread_storage_scope.h"
34 
35 namespace tvm {
36 namespace codegen {
37 
CodeGenOpenGL()38 CodeGenOpenGL::CodeGenOpenGL()
39     : output_(nullptr), output_iter_var_(nullptr) {}
40 
InitFuncState(LoweredFunc f)41 void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
42   CodeGenC::InitFuncState(f);
43   output_ = nullptr;
44   inputs_.clear();
45   output_iter_var_ = nullptr;
46   thread_extent_var_ = "";
47   this->decl_stream.str("");
48   this->stream.str("");
49 }
50 
AddFunction(LoweredFunc f)51 void CodeGenOpenGL::AddFunction(LoweredFunc f) {
52   // clear previous generated state.
53   this->InitFuncState(f);
54 
55   this->decl_stream << "#version 300 es\n";
56   this->decl_stream << "precision highp float;\n";
57 
58   // skip the first underscore, so SSA variable starts from _1
59   GetUniqueName("_");
60   // add to alloc buffer type.
61   for (const auto& kv : f->handle_data_type) {
62     RegisterHandleType(kv.first.get(), kv.second.type());
63   }
64 
65   // Allocate argument names. Store in `var_idmap_`.
66   for (auto arg : f->args) {
67     auto arg_name = GetUniqueName(arg.get()->name_hint);
68     var_idmap_[arg.get()] = arg_name;
69   }
70 
71   thread_extent_var_ = GetUniqueName("thread_extent");
72   this->decl_stream << "uniform int " << thread_extent_var_ << ";\n";
73 
74   this->stream << "void main() {\n";
75 
76   int func_scope = this->BeginScope();
77   this->PrintStmt(f->body);
78   this->EndScope(func_scope);
79 
80   this->PrintIndent();
81   this->stream << "}\n\n";
82 
83   // Declare arguments.
84   for (auto arg : f->args) {
85     if (this->inputs_.find(arg.get()) != this->inputs_.cend()) {
86       // Declare input texture.
87       // Format:
88       // - Float: "uniform sampler2D {name};"
89       // - Int: "uniform isampler2D {name};"
90       // - UInt: "uniform usampler2D {name};"
91 
92       auto arg_name = GetVarID(arg.get());
93 
94       auto type_it = this->handle_data_type_.find(arg.get());
95       CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type.";
96       auto type = Type2TVMType(type_it->second);
97       CHECK_EQ(type.lanes, 1) << "Vector type not supported.";
98 
99       switch (type.code) {
100         case kDLInt:
101           this->decl_stream << "uniform isampler2D " << arg_name << ";\n";
102           break;
103         case kDLUInt:
104           this->decl_stream << "uniform usampler2D " << arg_name << ";\n";
105           break;
106         case kDLFloat:
107           this->decl_stream << "uniform sampler2D " << arg_name << ";\n";
108           break;
109         default:
110           LOG(FATAL) << "Unsupported type code.";
111       }
112 
113     } else if (this->output_ == arg.get()) {
114       // Declare output texture.
115       // Format: "out {type} {name};"
116 
117       auto arg_name = GetVarID(arg.get());
118 
119       auto type_it = this->handle_data_type_.find(arg.get());
120       CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type.";
121       auto type = type_it->second;
122 
123       this->decl_stream << "out ";
124       PrintType(type, this->decl_stream);
125       this->decl_stream << " " << arg_name << ";\n";
126 
127     } else {
128       // Declare uniform value.
129       // Format: "uniform {type} {name};"
130 
131       auto arg_name = GetVarID(arg.get());
132       auto type = arg.get()->type;
133 
134       this->decl_stream << "uniform ";
135       PrintType(type, this->decl_stream);
136       this->decl_stream << " " << arg_name << ";\n";
137     }
138   }
139 
140   std::vector<std::string> arg_names;
141   std::vector<runtime::OpenGLArgKind> arg_kinds;
142   for (auto arg : f->args) {
143     std::string name = GetVarID(arg.get());
144 
145     runtime::OpenGLArgKind kind;
146     if (inputs_.find(arg.get()) != inputs_.cend()) {
147       kind = runtime::OpenGLArgKind::kInputTexture;
148     } else if (output_ == arg.get()) {
149       kind = runtime::OpenGLArgKind::kOutputTexture;
150     } else {
151       kind = runtime::OpenGLArgKind::kUniform;
152     }
153 
154     arg_names.push_back(name);
155     arg_kinds.push_back(kind);
156   }
157 
158   shaders_[f->name] = runtime::OpenGLShader(
159       this->decl_stream.str() + this->stream.str(),
160       std::move(arg_names), std::move(arg_kinds),
161       this->thread_extent_var_);
162 }
163 
Finish()164 std::unordered_map<std::string, runtime::OpenGLShader> CodeGenOpenGL::Finish() {
165   return shaders_;
166 }
167 
BindThreadIndex(const IterVar & iv)168 void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
169   CHECK_EQ(iv->thread_tag, "threadIdx.x") << "Must be threadIdx.x";
170   CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end())
171     << "Only support one thread iter var";
172   CHECK(output_iter_var_ == nullptr) << "Only support one thread iter var";
173 
174   var_idmap_[iv->var.get()] = iv->thread_tag;
175   output_iter_var_ = iv->var.get();
176 
177   // Declare threadIdx local variable.
178   this->PrintIndent();
179   this->stream << "ivec2 threadIdx = ivec2(" << runtime::kTextureRowSize
180                << " * int(gl_FragCoord.y) + int(gl_FragCoord.x), 0);\n";
181 
182   // Return directly if threadIdx.x >= thread_extent.
183   this->PrintIndent();
184   this->stream << "if (threadIdx.x >= " << thread_extent_var_ << ") {\n";
185   this->PrintIndent();
186   this->stream << "  return;\n";
187   this->PrintIndent();
188   this->stream << "}\n";
189 }
190 
VisitStmt_(const Store * op)191 void CodeGenOpenGL::VisitStmt_(const Store* op) {
192   LOG(FATAL) << "Store statement not supported in OpenGL."
193              << " Texture store should be a Call statement.";
194 }
195 
196 // texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
TexelFetch(const Variable * buffer,Expr index)197 std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
198   std::ostringstream os;
199   os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
200   PrintExpr(index, os);
201   os << ") & " << runtime::kTextureRowMask << ", int(";
202   PrintExpr(index, os);
203   os << ") >> " << runtime::kTextureRowBits << "), 0).r";
204   return os.str();
205 }
206 
207 // Print a reference expression to a buffer.
208 // Format: texelFetch(buffer, index, 0).r
GetBufferRef(Type t,const Variable * buffer,Expr index)209 std::string CodeGenOpenGL::GetBufferRef(
210     Type t, const Variable* buffer, Expr index) {
211   CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
212   CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
213 
214   if (buffer == this->output_) {
215     // This is the output texture.
216     return GetVarID(buffer);
217   } else {
218     // This is an input texture.
219     this->inputs_.insert(buffer);
220     return TexelFetch(buffer, index);
221   }
222 }
223 
PrintType(Type t,std::ostream & os)224 void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
225   switch (t.code()) {
226     case kDLInt:
227       CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
228       os << "int";
229       break;
230     case kDLUInt:
231       CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint.";
232       os << "uint";
233       break;
234     case kDLFloat:
235       CHECK_EQ(t.bits(), 32) << "Only support 32-bit float.";
236       os << "float";
237       break;
238     default:
239       LOG(FATAL) << "Unsupported type code.";
240   }
241 }
242 
243 // Codegen for immediate values
244 
VisitExpr_(const IntImm * op,std::ostream & os)245 void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) {
246   CHECK_EQ(op->type, Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
247   CodeGenC::VisitExpr_(op, os);
248 }
249 
VisitExpr_(const UIntImm * op,std::ostream & os)250 void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) {
251   CHECK_EQ(op->type, UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
252   CodeGenC::VisitExpr_(op, os);
253 }
254 
VisitExpr_(const FloatImm * op,std::ostream & os)255 void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) {
256   CHECK_EQ(op->type, Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
257   CodeGenC::VisitExpr_(op, os);
258 }
259 
VisitExpr_(const StringImm *,std::ostream & os)260 void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) {
261   LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
262 }
263 
VisitStmt_(const Evaluate * op)264 void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
265   auto call = op->value.as<Call>();
266   if (call == nullptr || call->name != Call::glsl_texture_store) {
267     // Fallback to normal logic.
268     CodeGenC::VisitStmt_(op);
269   }
270 
271   CHECK_EQ(call->args.size(), 2);
272   auto buffer = call->args[0].as<Variable>();
273   auto value = call->args[1];
274 
275   // Doesn't support store to vector.
276   auto type = value.type();
277   CHECK_EQ(type.lanes(), 1)
278     << "Vectorized store not implemented, type = " << type;
279 
280   CHECK(inputs_.find(buffer) == inputs_.cend())
281     << "Texture has been read from before. Must not store to it.";
282   if (output_ == nullptr) {
283     output_ = buffer;  // Record that this texture is the output.
284   } else {
285     CHECK(output_ == buffer) << "GLSL can only write to 1 texture.";
286   }
287 
288   this->PrintIndent();
289   this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
290 }
291 
BuildOpenGL(Array<LoweredFunc> funcs)292 runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
293   bool output_ssa = false;
294   CodeGenOpenGL cg;
295   cg.Init(output_ssa);
296   for (LoweredFunc f : funcs) {
297     cg.AddFunction(f);
298   }
299   auto shaders = cg.Finish();
300   return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs));
301 }
302 
303 TVM_REGISTER_API("codegen.build_opengl")
304 .set_body_typed(BuildOpenGL);
305 
306 }  // namespace codegen
307 }  // namespace tvm
308