1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file codegen_opengl.cc
22 *
 * We are targeting OpenGL 3.3; the emitted shaders are GLSL ES 3.00
 * ("#version 300 es"). A more recent OpenGL version is deliberately not
 * targeted so that the generated code stays compatible with WebGL 2.
25 */
26 #include <tvm/packed_func_ext.h>
27 #include <vector>
28 #include <string>
29 #include <utility>
30 #include <unordered_map>
31 #include "codegen_opengl.h"
32 #include "build_common.h"
33 #include "../runtime/thread_storage_scope.h"
34
35 namespace tvm {
36 namespace codegen {
37
CodeGenOpenGL()38 CodeGenOpenGL::CodeGenOpenGL()
39 : output_(nullptr), output_iter_var_(nullptr) {}
40
InitFuncState(LoweredFunc f)41 void CodeGenOpenGL::InitFuncState(LoweredFunc f) {
42 CodeGenC::InitFuncState(f);
43 output_ = nullptr;
44 inputs_.clear();
45 output_iter_var_ = nullptr;
46 thread_extent_var_ = "";
47 this->decl_stream.str("");
48 this->stream.str("");
49 }
50
AddFunction(LoweredFunc f)51 void CodeGenOpenGL::AddFunction(LoweredFunc f) {
52 // clear previous generated state.
53 this->InitFuncState(f);
54
55 this->decl_stream << "#version 300 es\n";
56 this->decl_stream << "precision highp float;\n";
57
58 // skip the first underscore, so SSA variable starts from _1
59 GetUniqueName("_");
60 // add to alloc buffer type.
61 for (const auto& kv : f->handle_data_type) {
62 RegisterHandleType(kv.first.get(), kv.second.type());
63 }
64
65 // Allocate argument names. Store in `var_idmap_`.
66 for (auto arg : f->args) {
67 auto arg_name = GetUniqueName(arg.get()->name_hint);
68 var_idmap_[arg.get()] = arg_name;
69 }
70
71 thread_extent_var_ = GetUniqueName("thread_extent");
72 this->decl_stream << "uniform int " << thread_extent_var_ << ";\n";
73
74 this->stream << "void main() {\n";
75
76 int func_scope = this->BeginScope();
77 this->PrintStmt(f->body);
78 this->EndScope(func_scope);
79
80 this->PrintIndent();
81 this->stream << "}\n\n";
82
83 // Declare arguments.
84 for (auto arg : f->args) {
85 if (this->inputs_.find(arg.get()) != this->inputs_.cend()) {
86 // Declare input texture.
87 // Format:
88 // - Float: "uniform sampler2D {name};"
89 // - Int: "uniform isampler2D {name};"
90 // - UInt: "uniform usampler2D {name};"
91
92 auto arg_name = GetVarID(arg.get());
93
94 auto type_it = this->handle_data_type_.find(arg.get());
95 CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type.";
96 auto type = Type2TVMType(type_it->second);
97 CHECK_EQ(type.lanes, 1) << "Vector type not supported.";
98
99 switch (type.code) {
100 case kDLInt:
101 this->decl_stream << "uniform isampler2D " << arg_name << ";\n";
102 break;
103 case kDLUInt:
104 this->decl_stream << "uniform usampler2D " << arg_name << ";\n";
105 break;
106 case kDLFloat:
107 this->decl_stream << "uniform sampler2D " << arg_name << ";\n";
108 break;
109 default:
110 LOG(FATAL) << "Unsupported type code.";
111 }
112
113 } else if (this->output_ == arg.get()) {
114 // Declare output texture.
115 // Format: "out {type} {name};"
116
117 auto arg_name = GetVarID(arg.get());
118
119 auto type_it = this->handle_data_type_.find(arg.get());
120 CHECK(type_it != this->handle_data_type_.cend()) << "Cannot find type.";
121 auto type = type_it->second;
122
123 this->decl_stream << "out ";
124 PrintType(type, this->decl_stream);
125 this->decl_stream << " " << arg_name << ";\n";
126
127 } else {
128 // Declare uniform value.
129 // Format: "uniform {type} {name};"
130
131 auto arg_name = GetVarID(arg.get());
132 auto type = arg.get()->type;
133
134 this->decl_stream << "uniform ";
135 PrintType(type, this->decl_stream);
136 this->decl_stream << " " << arg_name << ";\n";
137 }
138 }
139
140 std::vector<std::string> arg_names;
141 std::vector<runtime::OpenGLArgKind> arg_kinds;
142 for (auto arg : f->args) {
143 std::string name = GetVarID(arg.get());
144
145 runtime::OpenGLArgKind kind;
146 if (inputs_.find(arg.get()) != inputs_.cend()) {
147 kind = runtime::OpenGLArgKind::kInputTexture;
148 } else if (output_ == arg.get()) {
149 kind = runtime::OpenGLArgKind::kOutputTexture;
150 } else {
151 kind = runtime::OpenGLArgKind::kUniform;
152 }
153
154 arg_names.push_back(name);
155 arg_kinds.push_back(kind);
156 }
157
158 shaders_[f->name] = runtime::OpenGLShader(
159 this->decl_stream.str() + this->stream.str(),
160 std::move(arg_names), std::move(arg_kinds),
161 this->thread_extent_var_);
162 }
163
Finish()164 std::unordered_map<std::string, runtime::OpenGLShader> CodeGenOpenGL::Finish() {
165 return shaders_;
166 }
167
BindThreadIndex(const IterVar & iv)168 void CodeGenOpenGL::BindThreadIndex(const IterVar& iv) {
169 CHECK_EQ(iv->thread_tag, "threadIdx.x") << "Must be threadIdx.x";
170 CHECK(var_idmap_.find(iv->var.get()) == var_idmap_.end())
171 << "Only support one thread iter var";
172 CHECK(output_iter_var_ == nullptr) << "Only support one thread iter var";
173
174 var_idmap_[iv->var.get()] = iv->thread_tag;
175 output_iter_var_ = iv->var.get();
176
177 // Declare threadIdx local variable.
178 this->PrintIndent();
179 this->stream << "ivec2 threadIdx = ivec2(" << runtime::kTextureRowSize
180 << " * int(gl_FragCoord.y) + int(gl_FragCoord.x), 0);\n";
181
182 // Return directly if threadIdx.x >= thread_extent.
183 this->PrintIndent();
184 this->stream << "if (threadIdx.x >= " << thread_extent_var_ << ") {\n";
185 this->PrintIndent();
186 this->stream << " return;\n";
187 this->PrintIndent();
188 this->stream << "}\n";
189 }
190
VisitStmt_(const Store * op)191 void CodeGenOpenGL::VisitStmt_(const Store* op) {
192 LOG(FATAL) << "Store statement not supported in OpenGL."
193 << " Texture store should be a Call statement.";
194 }
195
196 // texelFetch(tex, ivec2(idx & kTextureRowMask, idx >> kTextureRowBits), 0).r
TexelFetch(const Variable * buffer,Expr index)197 std::string CodeGenOpenGL::TexelFetch(const Variable* buffer, Expr index) {
198 std::ostringstream os;
199 os << "texelFetch(" << GetVarID(buffer) << ", ivec2(int(";
200 PrintExpr(index, os);
201 os << ") & " << runtime::kTextureRowMask << ", int(";
202 PrintExpr(index, os);
203 os << ") >> " << runtime::kTextureRowBits << "), 0).r";
204 return os.str();
205 }
206
207 // Print a reference expression to a buffer.
208 // Format: texelFetch(buffer, index, 0).r
GetBufferRef(Type t,const Variable * buffer,Expr index)209 std::string CodeGenOpenGL::GetBufferRef(
210 Type t, const Variable* buffer, Expr index) {
211 CHECK_EQ(t.lanes(), 1) << "Vector type not supported.";
212 CHECK(HandleTypeMatch(buffer, t)) << "Type mismatch not supported.";
213
214 if (buffer == this->output_) {
215 // This is the output texture.
216 return GetVarID(buffer);
217 } else {
218 // This is an input texture.
219 this->inputs_.insert(buffer);
220 return TexelFetch(buffer, index);
221 }
222 }
223
PrintType(Type t,std::ostream & os)224 void CodeGenOpenGL::PrintType(Type t, std::ostream& os) {
225 switch (t.code()) {
226 case kDLInt:
227 CHECK_EQ(t.bits(), 32) << "Only support 32-bit int.";
228 os << "int";
229 break;
230 case kDLUInt:
231 CHECK_EQ(t.bits(), 32) << "Only support 32-bit uint.";
232 os << "uint";
233 break;
234 case kDLFloat:
235 CHECK_EQ(t.bits(), 32) << "Only support 32-bit float.";
236 os << "float";
237 break;
238 default:
239 LOG(FATAL) << "Unsupported type code.";
240 }
241 }
242
243 // Codegen for immediate values
244
VisitExpr_(const IntImm * op,std::ostream & os)245 void CodeGenOpenGL::VisitExpr_(const IntImm* op, std::ostream& os) {
246 CHECK_EQ(op->type, Int(32)) << "GLSL 3.0 only supports 32-bit ints.";
247 CodeGenC::VisitExpr_(op, os);
248 }
249
VisitExpr_(const UIntImm * op,std::ostream & os)250 void CodeGenOpenGL::VisitExpr_(const UIntImm* op, std::ostream& os) {
251 CHECK_EQ(op->type, UInt(32)) << "GLSL 3.0 only supports 32-bit uints.";
252 CodeGenC::VisitExpr_(op, os);
253 }
254
VisitExpr_(const FloatImm * op,std::ostream & os)255 void CodeGenOpenGL::VisitExpr_(const FloatImm* op, std::ostream& os) {
256 CHECK_EQ(op->type, Float(32)) << "GLSL 3.0 only supports 32-bit floats.";
257 CodeGenC::VisitExpr_(op, os);
258 }
259
VisitExpr_(const StringImm *,std::ostream & os)260 void CodeGenOpenGL::VisitExpr_(const StringImm*, std::ostream& os) {
261 LOG(FATAL) << "GLSL 3.0 doesn't support strings.";
262 }
263
VisitStmt_(const Evaluate * op)264 void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
265 auto call = op->value.as<Call>();
266 if (call == nullptr || call->name != Call::glsl_texture_store) {
267 // Fallback to normal logic.
268 CodeGenC::VisitStmt_(op);
269 }
270
271 CHECK_EQ(call->args.size(), 2);
272 auto buffer = call->args[0].as<Variable>();
273 auto value = call->args[1];
274
275 // Doesn't support store to vector.
276 auto type = value.type();
277 CHECK_EQ(type.lanes(), 1)
278 << "Vectorized store not implemented, type = " << type;
279
280 CHECK(inputs_.find(buffer) == inputs_.cend())
281 << "Texture has been read from before. Must not store to it.";
282 if (output_ == nullptr) {
283 output_ = buffer; // Record that this texture is the output.
284 } else {
285 CHECK(output_ == buffer) << "GLSL can only write to 1 texture.";
286 }
287
288 this->PrintIndent();
289 this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
290 }
291
BuildOpenGL(Array<LoweredFunc> funcs)292 runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
293 bool output_ssa = false;
294 CodeGenOpenGL cg;
295 cg.Init(output_ssa);
296 for (LoweredFunc f : funcs) {
297 cg.AddFunction(f);
298 }
299 auto shaders = cg.Finish();
300 return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs));
301 }
302
303 TVM_REGISTER_API("codegen.build_opengl")
304 .set_body_typed(BuildOpenGL);
305
306 } // namespace codegen
307 } // namespace tvm
308