1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file codegen_spirv.cc
22 * \brief Generate SPIRV block
23 */
24 #include "codegen_spirv.h"
25
26 #include <tvm/runtime/container.h>
27 #include <tvm/tir/builtin.h>
28 #include <tvm/tir/expr.h>
29 #include <tvm/tir/op.h>
30
31 #include <string>
32
33 namespace tvm {
34 namespace codegen {
35
// Lower a PrimFunc into a complete SPIR-V module and return the binary words.
std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
  this->InitFuncState();
  // Buffer arguments are emitted without aliasing information, so the
  // function must carry the no-alias attribute.
  CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model";
  std::vector<Var> pod_args;
  uint32_t num_buffer = 0;

  for (Var arg : f->params) {
    DataType t = arg.dtype();
    if (t.is_handle()) {
      // Handle (pointer) arguments become buffer bindings; the pointer must
      // carry a PrimType annotation so the element type is known.
      if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
        auto* prim = ptr->element_type.as<PrimTypeNode>();
        CHECK(prim);
        DataType value_type = prim->dtype;
        // Descriptor set 0; binding index is the position among buffer args.
        spirv::Value arg_value =
            builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer);
        storage_info_[arg.get()].UpdateContentType(value_type);
        var_map_[arg.get()] = arg_value;
      } else {
        LOG(FATAL) << "require all handles to be typed";
      }
      ++num_buffer;
    } else {
      // Scalar (POD) arguments are collected and passed via push constants.
      pod_args.push_back(arg);
    }
  }
  spirv::Value func_ptr = builder_->NewFunction();
  builder_->StartFunction(func_ptr);

  // All the POD arguments are passed in through PushConstant
  if (pod_args.size() != 0) {
    std::vector<spirv::SType> value_types;
    for (size_t i = 0; i < pod_args.size(); ++i) {
      value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
    }
    spirv::Value ptr = builder_->DeclarePushConstant(value_types);
    for (size_t i = 0; i < pod_args.size(); ++i) {
      // The i-th push-constant member maps to the i-th POD parameter.
      spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
      var_map_[pod_args[i].get()] = value;
    }
  }
  this->VisitStmt(f->body);
  // workgroup_size_ was filled in while visiting thread-extent attributes.
  builder_->SetLocalSize(func_ptr, workgroup_size_);
  builder_->MakeInst(spv::OpReturn);
  builder_->MakeInst(spv::OpFunctionEnd);

  builder_->CommitKernelFunction(func_ptr, name);

  return builder_->Finalize();
}
85
InitFuncState()86 void CodeGenSPIRV::InitFuncState() {
87 std::fill(workgroup_size_, workgroup_size_ + 3, 1);
88 var_map_.clear();
89 storage_info_.clear();
90 analyzer_.reset(new arith::Analyzer());
91 builder_.reset(new spirv::IRBuilder());
92 builder_->InitHeader();
93 }
94
// Materialize the SPIR-V value of a thread-axis IterVar, recording the
// workgroup size for rank-1 (local) axes as a side effect.
spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
  runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
  spirv::Value v;
  if (ts.rank == 1) {
    // Local (within-workgroup) axis: its extent fixes the workgroup size,
    // which SPIR-V requires to be a compile-time constant.
    v = builder_->GetLocalID(ts.dim_index);
    auto* sizeptr = extent.as<tir::IntImmNode>();
    CHECK(sizeptr) << "SPIRV only allows constant thread group size "
                   << " get " << extent;
    CHECK_LT(ts.dim_index, 3);
    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
  } else {
    // Workgroup-level axis; its extent need not be constant.
    v = builder_->GetWorkgroupID(ts.dim_index);
  }
  // Cast the builtin index to the dtype of the iteration variable.
  return builder_->Cast(builder_->GetSType(iv->var.dtype()), v);
}
110
CreateStorageSync(const CallNode * op)111 spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
112 const std::string& sync = op->args[0].as<StringImmNode>()->value;
113 spirv::Value value;
114 if (sync == "warp") {
115 return value;
116 } else if (sync == "shared") {
117 auto type_int = builder_->GetSType(DataType::Int(32));
118 builder_->MakeInst(
119 spv::OpControlBarrier,
120 builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
121 builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
122 builder_->IntImm(type_int,
123 static_cast<int64_t>(spv::MemorySemanticsSequentiallyConsistentMask |
124 spv::MemorySemanticsWorkgroupMemoryMask)));
125 } else {
126 LOG(FATAL) << "Do not support sync " << sync;
127 }
128 return value;
129 }
130
VisitExpr_(const VarNode * op)131 spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) {
132 auto it = var_map_.find(op);
133 CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint;
134 return it->second;
135 }
136
VisitExpr_(const IntImmNode * op)137 spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) {
138 return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
139 }
140
VisitExpr_(const FloatImmNode * op)141 spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
142 return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
143 }
144
// String literals have no representation in device code; always fatal.
spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) {
  LOG(FATAL) << "StringImm is not supported in Device code";
  // Unreachable; keeps the signature well-formed.
  return spirv::Value();
}
149
VisitExpr_(const CastNode * op)150 spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) {
151 return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value));
152 }
153
VisitExpr_(const AddNode * op)154 spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) {
155 return builder_->Add(MakeValue(op->a), MakeValue(op->b));
156 }
157
VisitExpr_(const SubNode * op)158 spirv::Value CodeGenSPIRV::VisitExpr_(const SubNode* op) {
159 return builder_->Sub(MakeValue(op->a), MakeValue(op->b));
160 }
161
VisitExpr_(const MulNode * op)162 spirv::Value CodeGenSPIRV::VisitExpr_(const MulNode* op) {
163 return builder_->Mul(MakeValue(op->a), MakeValue(op->b));
164 }
165
VisitExpr_(const DivNode * op)166 spirv::Value CodeGenSPIRV::VisitExpr_(const DivNode* op) {
167 return builder_->Div(MakeValue(op->a), MakeValue(op->b));
168 }
169
VisitExpr_(const ModNode * op)170 spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) {
171 return builder_->Mod(MakeValue(op->a), MakeValue(op->b));
172 }
173
VisitExpr_(const MinNode * op)174 spirv::Value CodeGenSPIRV::VisitExpr_(const MinNode* op) {
175 spirv::Value a = MakeValue(op->a);
176 spirv::Value b = MakeValue(op->b);
177 return builder_->Select(builder_->LT(a, b), a, b);
178 }
179
VisitExpr_(const MaxNode * op)180 spirv::Value CodeGenSPIRV::VisitExpr_(const MaxNode* op) {
181 spirv::Value a = MakeValue(op->a);
182 spirv::Value b = MakeValue(op->b);
183 return builder_->Select(builder_->GT(a, b), a, b);
184 }
185
VisitExpr_(const LTNode * op)186 spirv::Value CodeGenSPIRV::VisitExpr_(const LTNode* op) {
187 return builder_->LT(MakeValue(op->a), MakeValue(op->b));
188 }
189
VisitExpr_(const LENode * op)190 spirv::Value CodeGenSPIRV::VisitExpr_(const LENode* op) {
191 return builder_->LE(MakeValue(op->a), MakeValue(op->b));
192 }
193
VisitExpr_(const GTNode * op)194 spirv::Value CodeGenSPIRV::VisitExpr_(const GTNode* op) {
195 return builder_->GT(MakeValue(op->a), MakeValue(op->b));
196 }
197
VisitExpr_(const GENode * op)198 spirv::Value CodeGenSPIRV::VisitExpr_(const GENode* op) {
199 return builder_->GE(MakeValue(op->a), MakeValue(op->b));
200 }
201
VisitExpr_(const EQNode * op)202 spirv::Value CodeGenSPIRV::VisitExpr_(const EQNode* op) {
203 return builder_->EQ(MakeValue(op->a), MakeValue(op->b));
204 }
205
VisitExpr_(const NENode * op)206 spirv::Value CodeGenSPIRV::VisitExpr_(const NENode* op) {
207 return builder_->NE(MakeValue(op->a), MakeValue(op->b));
208 }
209
VisitExpr_(const AndNode * op)210 spirv::Value CodeGenSPIRV::VisitExpr_(const AndNode* op) {
211 spirv::Value a = MakeValue(op->a);
212 spirv::Value b = MakeValue(op->b);
213 return builder_->MakeValue(spv::OpLogicalAnd, a.stype, a, b);
214 }
215
VisitExpr_(const OrNode * op)216 spirv::Value CodeGenSPIRV::VisitExpr_(const OrNode* op) {
217 spirv::Value a = MakeValue(op->a);
218 spirv::Value b = MakeValue(op->b);
219 return builder_->MakeValue(spv::OpLogicalOr, a.stype, a, b);
220 }
221
VisitExpr_(const NotNode * op)222 spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) {
223 spirv::Value a = MakeValue(op->a);
224 return builder_->MakeValue(spv::OpLogicalNot, a.stype, a);
225 }
226
VisitExpr_(const SelectNode * op)227 spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
228 return builder_->Select(MakeValue(op->condition), MakeValue(op->true_value),
229 MakeValue(op->false_value));
230 }
231
// Lower a let-expression. A Let var may be encountered multiple times
// (e.g. via common subexpressions); it must always bind to a structurally
// equal value.
spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
  auto it = let_binding_.find(op->var);
  if (it != let_binding_.end()) {
    CHECK(deep_equal_(it->second->value, op->value))
        << "Let cannot bind the same var to two different values";
  } else {
    let_binding_[op->var] = op;
  }
  // Lower the bound value first so the body can look it up via var_map_;
  // also inform the analyzer before lowering the body.
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  return MakeValue(op->body);
}
244
// Dispatch on the builtin being called and emit the matching SPIR-V ops.
spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::call_spirv_pure_glsl450())) {
    // args[0] is the GLSL.std.450 extended-instruction id; the remaining
    // args are its operands.
    CHECK_GE(op->args.size(), 2U);
    uint32_t inst_id = static_cast<uint32_t>(op->args[0].as<IntImmNode>()->value);
    std::vector<spirv::Value> values;
    for (size_t i = 1; i < op->args.size(); ++i) {
      values.push_back(MakeValue(op->args[i]));
    }
    return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values);
  } else if (op->op.same_as(builtin::bitwise_and())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_xor())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_or())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_not())) {
    CHECK_EQ(op->args.size(), 1U);
    spirv::Value a = MakeValue(op->args[0]);
    return builder_->MakeValue(spv::OpNot, a.stype, a);
  } else if (op->op.same_as(builtin::shift_left())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
  } else if (op->op.same_as(builtin::shift_right())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    // Arithmetic (sign-extending) shift for signed operands, logical
    // (zero-filling) shift otherwise.
    if (op->args[0].dtype().is_int()) {
      return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b);
    } else {
      return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
    }
  } else if (op->op.same_as(builtin::reinterpret())) {
    // Bit-level reinterpretation between same-sized types.
    return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
                               MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::large_uint_imm())) {
    // A 64-bit unsigned constant split into (low32, high32) halves.
    CHECK_EQ(op->args.size(), 2U);
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    return builder_->UIntImm(builder_->GetSType(op->dtype), val);
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    return this->CreateStorageSync(op);
  } else if (op->op.same_as(builtin::if_then_else())) {
    // Lazily-evaluated conditional: emit real branches and merge the two
    // results with a phi node (unlike Select, which evaluates both sides).
    CHECK_EQ(op->args.size(), 3U);
    spirv::Value cond = MakeValue(op->args[0]);
    spirv::Label then_label = builder_->NewLabel();
    spirv::Label else_label = builder_->NewLabel();
    spirv::Label merge_label = builder_->NewLabel();
    builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
    // then block, must get label after we see the value
    builder_->StartLabel(then_label);
    spirv::Value then_value = MakeValue(op->args[1]);
    spirv::Label then_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // else block
    builder_->StartLabel(else_label);
    spirv::Value else_value = MakeValue(op->args[2]);
    spirv::Label else_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // merge block
    builder_->StartLabel(merge_label);
    spirv::PhiValue phi = builder_->MakePhi(then_value.stype, 2);
    phi.SetIncoming(0, then_value, then_value_label);
    phi.SetIncoming(1, else_value, else_value_label);
    return phi;
  } else if (op->op.same_as(builtin::popcount())) {
    return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype),
                               MakeValue(op->args[0]));
  } else {
    LOG(FATAL) << "Unresolved call " << op->op;
    return spirv::Value();
  }
}
330
VisitExpr_(const RampNode * op)331 spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
332 std::vector<spirv::Value> values;
333 spirv::Value base = MakeValue(op->base);
334 for (int i = 0; i < op->lanes; ++i) {
335 spirv::Value v = base;
336 if (i != 0) {
337 spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride);
338 v = builder_->Add(v, offset);
339 }
340 values.push_back(v);
341 }
342 return builder_->Concat(values);
343 }
344
VisitExpr_(const BroadcastNode * op)345 spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
346 std::vector<spirv::Value> values;
347 spirv::Value v = MakeValue(op->value);
348 for (int i = 0; i < op->lanes; i++) {
349 values.push_back(v);
350 }
351 return builder_->Concat(values);
352 }
353
// Lower a buffer load. Scalar loads index directly; vector loads are either
// scalarized (element-typed buffer) or emitted as one vector load when the
// index is a provably aligned unit-stride ramp.
spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
  // Predicated loads are not supported.
  CHECK(is_one(op->predicate));
  auto it = storage_info_.find(op->buffer_var.get());
  CHECK(it != storage_info_.end());
  StorageInfo& info = it->second;
  // First access to a buffer whose element type was not fixed at
  // declaration time pins the content type.
  if (!info.content_fixed) {
    info.UpdateContentType(op->dtype);
  }

  spirv::SType content_type = builder_->GetSType(info.content_type);
  spirv::Value buffer = MakeValue(op->buffer_var);
  spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);

  uint32_t mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    mask |= spv::MemoryAccessVolatileMask;
  }
  if (op->dtype.lanes() == 1) {
    // Scalar load: single element access at the computed index.
    CHECK_EQ(info.content_type, op->dtype)
        << "Vulkan only allow one type access to the same buffer";
    spirv::Value index = MakeValue(op->index);
    spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
    return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
  } else {
    if (op->dtype.element_of() == info.content_type) {
      // because content type is element type, we can only do scalarize load.
      std::vector<spirv::Value> values;
      auto f = [&](int i, spirv::Value index) {
        spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
        values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
      };
      this->Scalarize(op->index, f);
      return builder_->Concat(values);
    } else {
      // Vector-typed buffer: only a unit-stride ramp whose base is provably
      // a multiple of the lane count can be loaded as one vector.
      if (const RampNode* ramp = op->index.as<RampNode>()) {
        if (is_one(ramp->stride)) {
          CHECK_EQ(ramp->lanes, op->dtype.lanes());
          arith::ModularSet me = analyzer_->modular_set(ramp->base);
          CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
              << "Only aligned vector access is allowed in SPIRV";
          // Re-index in units of whole vectors.
          PrimExpr vec_index =
              analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
          spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
          return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
        }
      }
    }
    LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
  }
  LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
  // Unreachable; keeps the signature well-formed.
  return spirv::Value();
}
406
Scalarize(const PrimExpr & e,std::function<void (int i,spirv::Value v)> f)407 void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f) {
408 if (const RampNode* ramp = e.as<RampNode>()) {
409 for (int i = 0; i < ramp->dtype.lanes(); ++i) {
410 PrimExpr offset = ramp->base + ramp->stride * i;
411 f(i, MakeValue(offset));
412 }
413 } else {
414 spirv::SType etype = builder_->GetSType(e.dtype().element_of());
415 spirv::Value value = MakeValue(e);
416 for (int i = 0; i < e.dtype().lanes(); ++i) {
417 f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i));
418 }
419 }
420 }
421
// Lower a buffer store; mirrors the load paths: direct scalar store,
// scalarized vector store, or one vector store for aligned ramps.
void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
  // Predicated stores are not supported.
  CHECK(is_one(op->predicate));
  auto it = storage_info_.find(op->buffer_var.get());
  CHECK(it != storage_info_.end());
  StorageInfo& info = it->second;

  // First access to a buffer whose element type was not fixed at
  // declaration time pins the content type.
  if (!info.content_fixed) {
    info.UpdateContentType(op->value.dtype());
  }

  spirv::SType content_type = builder_->GetSType(info.content_type);
  spirv::Value buffer = MakeValue(op->buffer_var);
  spirv::Value value = MakeValue(op->value);
  spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);

  uint32_t mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    mask |= spv::MemoryAccessVolatileMask;
  }

  if (op->value.dtype().lanes() == 1) {
    // Scalar store: single element access at the computed index.
    CHECK_EQ(info.content_type, op->value.dtype())
        << "Vulkan only allow one type access to the same buffer";
    spirv::Value index = MakeValue(op->index);
    spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
    builder_->MakeInst(spv::OpStore, ptr, value, mask);
  } else {
    if (op->value.dtype().element_of() == info.content_type) {
      // because content type is element type, we can only do scalarize load.
      auto f = [&](int i, spirv::Value index) {
        spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
        spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
        builder_->MakeInst(spv::OpStore, ptr, elem, mask);
      };
      this->Scalarize(op->index, f);
    } else {
      // Vector-typed buffer: only a unit-stride ramp whose base is provably
      // a multiple of the lane count can be stored as one vector.
      if (const RampNode* ramp = op->index.as<RampNode>()) {
        if (is_one(ramp->stride)) {
          CHECK_EQ(ramp->lanes, op->value.dtype().lanes());
          arith::ModularSet me = analyzer_->modular_set(ramp->base);
          CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
              << "Only aligned vector access is allowed in SPIRV";
          // Re-index in units of whole vectors.
          PrimExpr vec_index =
              analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
          spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
          builder_->MakeInst(spv::OpStore, ptr, value, mask);
          return;
        }
      }
      LOG(FATAL) << "Only aligned continuous vector access is allowed in SPIRV";
    }
  }
}
475
// Lower a serial for-loop into SPIR-V structured control flow:
// head (phi + condition) -> body -> continue (increment) -> head, with
// OpLoopMerge declaring the merge and continue targets.
void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
  // Lowering normalizes loops so they start at zero.
  CHECK(is_zero(op->min));
  analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
  spirv::Value init_value = MakeValue(op->min);
  spirv::Value extent_value = MakeValue(op->extent);
  // Must get init label after making value(to make sure they are correct)
  spirv::Label init_label = builder_->CurrentLabel();
  spirv::Label head_label = builder_->NewLabel();
  spirv::Label body_label = builder_->NewLabel();
  spirv::Label continue_label = builder_->NewLabel();
  spirv::Label merge_label = builder_->NewLabel();
  builder_->MakeInst(spv::OpBranch, head_label);

  // Loop head: phi merges the entry value with the incremented value, then
  // the loop condition loop_var < extent is tested.
  builder_->StartLabel(head_label);
  spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
  loop_var.SetIncoming(0, init_value, init_label);
  spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
  // Request unrolling from the driver when the TIR loop was marked unrolled.
  uint32_t control =
      (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
  builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
  builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
                     weight_likely_branch_, 1);

  // loop body
  builder_->StartLabel(body_label);
  var_map_[op->loop_var.get()] = spirv::Value(loop_var);
  this->VisitStmt(op->body);
  builder_->MakeInst(spv::OpBranch, continue_label);

  // loop continue: increment by one (typed to match the loop var) and feed
  // the result back into the head phi.
  builder_->StartLabel(continue_label);
  spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
                                                   : builder_->UIntImm(loop_var.stype, 1);
  spirv::Value next_value = builder_->Add(loop_var, one);
  loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
  builder_->MakeInst(spv::OpBranch, head_label);
  // loop merge
  builder_->StartLabel(merge_label);
}
516
// Lower an if/else statement as a structured selection: OpSelectionMerge
// followed by a conditional branch into then/else (or straight to merge
// when there is no else branch).
void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
  spirv::Value cond = MakeValue(op->condition);
  spirv::Label then_label = builder_->NewLabel();
  spirv::Label merge_label = builder_->NewLabel();
  if (op->else_case.defined()) {
    spirv::Label else_label = builder_->NewLabel();
    builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
    // then block
    builder_->StartLabel(then_label);
    this->VisitStmt(op->then_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
    // else block
    builder_->StartLabel(else_label);
    this->VisitStmt(op->else_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
  } else {
    // No else branch: the false edge goes straight to the merge block and
    // the true edge is weighted as the likely one.
    builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(spv::OpBranchConditional, cond, then_label, merge_label,
                       weight_likely_branch_, 1);
    // then block
    builder_->StartLabel(then_label);
    this->VisitStmt(op->then_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
  }
  // start merge label;
  builder_->StartLabel(merge_label);
}
545
VisitStmt_(const AllocateNode * op)546 void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
547 CHECK(!is_zero(op->condition));
548 CHECK(!op->dtype.is_handle());
549 int32_t constant_size = op->constant_allocation_size();
550 CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
551 spirv::Value buf;
552 StorageInfo& info = storage_info_[op->buffer_var.get()];
553 spirv::SType etype = builder_->GetSType(op->dtype);
554 if (info.scope.rank == runtime::StorageRank::kLocal) {
555 buf =
556 builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
557 } else {
558 // shared memory
559 CHECK(info.scope.rank == runtime::StorageRank::kShared)
560 << "Can only allocate shared or local memory inside kernel";
561 // Shared memory
562 buf =
563 builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassWorkgroup);
564 }
565 CHECK(!info.content_fixed);
566 info.UpdateContentType(op->dtype);
567 CHECK(!var_map_.count(op->buffer_var.get()));
568 var_map_[op->buffer_var.get()] = buf;
569 this->VisitStmt(op->body);
570 }
571
VisitStmt_(const AttrStmtNode * op)572 void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
573 if (op->attr_key == tir::attr::thread_extent) {
574 IterVar iv = Downcast<IterVar>(op->node);
575 if (iv->thread_tag.length() != 0) {
576 if (!var_map_.count(iv->var.get())) {
577 var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
578 analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
579 }
580 }
581 } else if (op->attr_key == tir::attr::storage_scope) {
582 const VarNode* v = op->node.as<VarNode>();
583 CHECK(v);
584 storage_info_[v].scope = runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
585 } else if (op->attr_key == tir::attr::volatile_scope) {
586 const VarNode* v = op->node.as<VarNode>();
587 CHECK(v);
588 storage_info_[v].is_volatile = true;
589 }
590 this->VisitStmt(op->body);
591 }
592
// Device code cannot raise; the asserted condition is instead registered
// as an analyzer constraint for the duration of the body lowering.
void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}
597
// Lower a let-statement: bind the value for codegen lookup and for the
// arithmetic analyzer before lowering the body.
void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
  // The let var must be fresh and must not be a buffer handle.
  CHECK(!var_map_.count(op->var.get()));
  CHECK(!op->var.dtype().is_handle());
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  this->VisitStmt(op->body);
}
605
VisitStmt_(const SeqStmtNode * op)606 void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {
607 for (Stmt stmt : op->seq) {
608 this->VisitStmt(stmt);
609 }
610 }
611
VisitStmt_(const EvaluateNode * op)612 void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }
613
614 } // namespace codegen
615 } // namespace tvm
616