/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_spirv.cc
 * \brief Generate SPIRV block
 */
#include "codegen_spirv.h"

#include <tvm/runtime/container.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <string>

namespace tvm {
namespace codegen {

std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
  this->InitFuncState();
  CHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model";
  std::vector<Var> pod_args;
  uint32_t num_buffer = 0;
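  // Bind each handle parameter to a storage buffer binding; scalar (POD)
  // parameters are collected and passed through the push constant block below.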
  for (Var arg : f->params) {
    DataType t = arg.dtype();
    if (t.is_handle()) {
      if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
        auto* prim = ptr->element_type.as<PrimTypeNode>();
        CHECK(prim);
        DataType value_type = prim->dtype;
        spirv::Value arg_value =
            builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer);
        storage_info_[arg.get()].UpdateContentType(value_type);
        var_map_[arg.get()] = arg_value;
      } else {
        LOG(FATAL) << "require all handles to be typed";
      }
      ++num_buffer;
    } else {
      pod_args.push_back(arg);
    }
  }
  spirv::Value func_ptr = builder_->NewFunction();
  builder_->StartFunction(func_ptr);

  // All the POD arguments are passed in through PushConstant
  if (pod_args.size() != 0) {
    std::vector<spirv::SType> value_types;
    for (size_t i = 0; i < pod_args.size(); ++i) {
      value_types.push_back(builder_->GetSType(pod_args[i].dtype()));
    }
    spirv::Value ptr = builder_->DeclarePushConstant(value_types);
    for (size_t i = 0; i < pod_args.size(); ++i) {
      spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast<uint32_t>(i));
      var_map_[pod_args[i].get()] = value;
    }
  }
  this->VisitStmt(f->body);
  builder_->SetLocalSize(func_ptr, workgroup_size_);
  builder_->MakeInst(spv::OpReturn);
  builder_->MakeInst(spv::OpFunctionEnd);

  builder_->CommitKernelFunction(func_ptr, name);

  return builder_->Finalize();
}

void CodeGenSPIRV::InitFuncState() {
  std::fill(workgroup_size_, workgroup_size_ + 3, 1);
  var_map_.clear();
  storage_info_.clear();
  analyzer_.reset(new arith::Analyzer());
  builder_.reset(new spirv::IRBuilder());
  builder_->InitHeader();
}

spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
  runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
  spirv::Value v;
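  // Rank-1 threads (threadIdx.*) map to the local invocation id and fix the
  // workgroup size; other ranks (blockIdx.*) map to the workgroup id.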
  if (ts.rank == 1) {
    v = builder_->GetLocalID(ts.dim_index);
    auto* sizeptr = extent.as<tir::IntImmNode>();
    CHECK(sizeptr) << "SPIRV only allows constant thread group size, but got " << extent;
    CHECK_LT(ts.dim_index, 3);
    workgroup_size_[ts.dim_index] = static_cast<uint32_t>(sizeptr->value);
  } else {
    v = builder_->GetWorkgroupID(ts.dim_index);
  }
  return builder_->Cast(builder_->GetSType(iv->var.dtype()), v);
}

spirv::Value CodeGenSPIRV::CreateStorageSync(const CallNode* op) {
  const std::string& sync = op->args[0].as<StringImmNode>()->value;
  spirv::Value value;
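  // "warp" sync is a no-op here; "shared" lowers to an OpControlBarrier at
  // workgroup scope with sequentially-consistent workgroup memory semantics.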
  if (sync == "warp") {
    return value;
  } else if (sync == "shared") {
    auto type_int = builder_->GetSType(DataType::Int(32));
    builder_->MakeInst(
        spv::OpControlBarrier,
        builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
        builder_->IntImm(type_int, static_cast<int64_t>(spv::ScopeWorkgroup)),
        builder_->IntImm(type_int,
                         static_cast<int64_t>(spv::MemorySemanticsSequentiallyConsistentMask |
                                              spv::MemorySemanticsWorkgroupMemoryMask)));
  } else {
    LOG(FATAL) << "Unsupported sync type " << sync;
  }
  return value;
}

spirv::Value CodeGenSPIRV::VisitExpr_(const VarNode* op) {
  auto it = var_map_.find(op);
  CHECK(it != var_map_.end()) << "cannot find variable " << op->name_hint;
  return it->second;
}

spirv::Value CodeGenSPIRV::VisitExpr_(const IntImmNode* op) {
  return builder_->IntImm(builder_->GetSType(op->dtype), op->value);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const FloatImmNode* op) {
  return builder_->FloatImm(builder_->GetSType(op->dtype), op->value);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const StringImmNode* op) {
  LOG(FATAL) << "StringImm is not supported in Device code";
  return spirv::Value();
}

spirv::Value CodeGenSPIRV::VisitExpr_(const CastNode* op) {
  return builder_->Cast(builder_->GetSType(op->dtype), MakeValue(op->value));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const AddNode* op) {
  return builder_->Add(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const SubNode* op) {
  return builder_->Sub(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const MulNode* op) {
  return builder_->Mul(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const DivNode* op) {
  return builder_->Div(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const ModNode* op) {
  return builder_->Mod(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const MinNode* op) {
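  // Lower min to a compare-and-select (OpSelect on a < b).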
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Select(builder_->LT(a, b), a, b);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const MaxNode* op) {
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->Select(builder_->GT(a, b), a, b);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const LTNode* op) {
  return builder_->LT(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const LENode* op) {
  return builder_->LE(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const GTNode* op) {
  return builder_->GT(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const GENode* op) {
  return builder_->GE(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const EQNode* op) {
  return builder_->EQ(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const NENode* op) {
  return builder_->NE(MakeValue(op->a), MakeValue(op->b));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const AndNode* op) {
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->MakeValue(spv::OpLogicalAnd, a.stype, a, b);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const OrNode* op) {
  spirv::Value a = MakeValue(op->a);
  spirv::Value b = MakeValue(op->b);
  return builder_->MakeValue(spv::OpLogicalOr, a.stype, a, b);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const NotNode* op) {
  spirv::Value a = MakeValue(op->a);
  return builder_->MakeValue(spv::OpLogicalNot, a.stype, a);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const SelectNode* op) {
  return builder_->Select(MakeValue(op->condition), MakeValue(op->true_value),
                          MakeValue(op->false_value));
}

spirv::Value CodeGenSPIRV::VisitExpr_(const LetNode* op) {
  auto it = let_binding_.find(op->var);
  if (it != let_binding_.end()) {
    CHECK(deep_equal_(it->second->value, op->value))
        << "Let cannot bind the same var to two different values";
  } else {
    let_binding_[op->var] = op;
  }
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  return MakeValue(op->body);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::call_spirv_pure_glsl450())) {
    CHECK_GE(op->args.size(), 2U);
    uint32_t inst_id = static_cast<uint32_t>(op->args[0].as<IntImmNode>()->value);
    std::vector<spirv::Value> values;
    for (size_t i = 1; i < op->args.size(); ++i) {
      values.push_back(MakeValue(op->args[i]));
    }
    return builder_->CallGLSL450(builder_->GetSType(op->dtype), inst_id, values);
  } else if (op->op.same_as(builtin::bitwise_and())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseAnd, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_xor())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseXor, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_or())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpBitwiseOr, a.stype, a, b);
  } else if (op->op.same_as(builtin::bitwise_not())) {
    CHECK_EQ(op->args.size(), 1U);
    spirv::Value a = MakeValue(op->args[0]);
    return builder_->MakeValue(spv::OpNot, a.stype, a);
  } else if (op->op.same_as(builtin::shift_left())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
    return builder_->MakeValue(spv::OpShiftLeftLogical, a.stype, a, b);
  } else if (op->op.same_as(builtin::shift_right())) {
    CHECK_EQ(op->args.size(), 2U);
    spirv::Value a = MakeValue(op->args[0]);
    spirv::Value b = MakeValue(op->args[1]);
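    // Signed operands need an arithmetic shift to preserve the sign bit;
    // unsigned operands use a logical shift.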
    if (op->args[0].dtype().is_int()) {
      return builder_->MakeValue(spv::OpShiftRightArithmetic, a.stype, a, b);
    } else {
      return builder_->MakeValue(spv::OpShiftRightLogical, a.stype, a, b);
    }
  } else if (op->op.same_as(builtin::reinterpret())) {
    return builder_->MakeValue(spv::OpBitcast, builder_->GetSType(op->dtype),
                               MakeValue(op->args[0]));
  } else if (op->op.same_as(builtin::large_uint_imm())) {
    CHECK_EQ(op->args.size(), 2U);
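    // Reassemble the 64-bit immediate from the low and high 32-bit halves.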
    uint64_t low = static_cast<uint64_t>(Downcast<IntImm>(op->args[0])->value);
    uint64_t high = static_cast<uint64_t>(Downcast<IntImm>(op->args[1])->value);
    uint64_t val = (high << 32U) | low;
    return builder_->UIntImm(builder_->GetSType(op->dtype), val);
  } else if (op->op.same_as(builtin::tvm_storage_sync())) {
    return this->CreateStorageSync(op);
  } else if (op->op.same_as(builtin::if_then_else())) {
    CHECK_EQ(op->args.size(), 3U);
    spirv::Value cond = MakeValue(op->args[0]);
    spirv::Label then_label = builder_->NewLabel();
    spirv::Label else_label = builder_->NewLabel();
    spirv::Label merge_label = builder_->NewLabel();
    builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
    // then block, must get label after we see the value
    builder_->StartLabel(then_label);
    spirv::Value then_value = MakeValue(op->args[1]);
    spirv::Label then_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // else block
    builder_->StartLabel(else_label);
    spirv::Value else_value = MakeValue(op->args[2]);
    spirv::Label else_value_label = builder_->CurrentLabel();
    builder_->MakeInst(spv::OpBranch, merge_label);
    // merge block
    builder_->StartLabel(merge_label);
    spirv::PhiValue phi = builder_->MakePhi(then_value.stype, 2);
    phi.SetIncoming(0, then_value, then_value_label);
    phi.SetIncoming(1, else_value, else_value_label);
    return phi;
  } else if (op->op.same_as(builtin::popcount())) {
    return builder_->MakeValue(spv::OpBitCount, builder_->GetSType(op->dtype),
                               MakeValue(op->args[0]));
  } else {
    LOG(FATAL) << "Unresolved call " << op->op;
    return spirv::Value();
  }
}

spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
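  // Expand the ramp lane by lane (base, base + stride, ...) and concatenate
  // the scalars into a vector value.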
  std::vector<spirv::Value> values;
  spirv::Value base = MakeValue(op->base);
  for (int i = 0; i < op->lanes; ++i) {
    spirv::Value v = base;
    if (i != 0) {
      spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) * op->stride);
      v = builder_->Add(v, offset);
    }
    values.push_back(v);
  }
  return builder_->Concat(values);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
  std::vector<spirv::Value> values;
  spirv::Value v = MakeValue(op->value);
  for (int i = 0; i < op->lanes; i++) {
    values.push_back(v);
  }
  return builder_->Concat(values);
}

spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
  CHECK(is_one(op->predicate));
  auto it = storage_info_.find(op->buffer_var.get());
  CHECK(it != storage_info_.end());
  StorageInfo& info = it->second;
  if (!info.content_fixed) {
    info.UpdateContentType(op->dtype);
  }

  spirv::SType content_type = builder_->GetSType(info.content_type);
  spirv::Value buffer = MakeValue(op->buffer_var);
  spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);

  uint32_t mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    mask |= spv::MemoryAccessVolatileMask;
  }
  if (op->dtype.lanes() == 1) {
    CHECK_EQ(info.content_type, op->dtype)
        << "Vulkan only allows one type of access to the same buffer";
    spirv::Value index = MakeValue(op->index);
    spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
    return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
  } else {
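    // Vector load: if the buffer's element type matches the vector's element
    // type, load each lane separately; otherwise only a unit-stride, aligned
    // ramp index can be loaded as a whole vector.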
    if (op->dtype.element_of() == info.content_type) {
      // Because the content type is the element type, we can only do a scalarized load.
      std::vector<spirv::Value> values;
      auto f = [&](int i, spirv::Value index) {
        spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
        values.emplace_back(builder_->MakeValue(spv::OpLoad, content_type, ptr, mask));
      };
      this->Scalarize(op->index, f);
      return builder_->Concat(values);
    } else {
      if (const RampNode* ramp = op->index.as<RampNode>()) {
        if (is_one(ramp->stride)) {
          CHECK_EQ(ramp->lanes, op->dtype.lanes());
          arith::ModularSet me = analyzer_->modular_set(ramp->base);
          CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
              << "Only aligned vector access is allowed in SPIRV";
          PrimExpr vec_index =
              analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
          spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
          return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
        }
      }
    }
  }
  LOG(FATAL) << "Only aligned contiguous vector access is allowed in SPIRV";
  return spirv::Value();
}

void CodeGenSPIRV::Scalarize(const PrimExpr& e, std::function<void(int i, spirv::Value v)> f) {
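  // Invoke f once per lane: for a ramp index, emit each lane's index expression;
  // otherwise evaluate the vector once and extract each lane with OpCompositeExtract.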
  if (const RampNode* ramp = e.as<RampNode>()) {
    for (int i = 0; i < ramp->dtype.lanes(); ++i) {
      PrimExpr offset = ramp->base + ramp->stride * i;
      f(i, MakeValue(offset));
    }
  } else {
    spirv::SType etype = builder_->GetSType(e.dtype().element_of());
    spirv::Value value = MakeValue(e);
    for (int i = 0; i < e.dtype().lanes(); ++i) {
      f(i, builder_->MakeValue(spv::OpCompositeExtract, etype, value, i));
    }
  }
}

void CodeGenSPIRV::VisitStmt_(const StoreNode* op) {
  CHECK(is_one(op->predicate));
  auto it = storage_info_.find(op->buffer_var.get());
  CHECK(it != storage_info_.end());
  StorageInfo& info = it->second;

  if (!info.content_fixed) {
    info.UpdateContentType(op->value.dtype());
  }

  spirv::SType content_type = builder_->GetSType(info.content_type);
  spirv::Value buffer = MakeValue(op->buffer_var);
  spirv::Value value = MakeValue(op->value);
  spirv::SType ptr_type = builder_->GetPointerType(content_type, buffer.stype.storage_class);

  uint32_t mask = spv::MemoryAccessMaskNone;
  if (info.is_volatile) {
    mask |= spv::MemoryAccessVolatileMask;
  }
  if (op->value.dtype().lanes() == 1) {
    CHECK_EQ(info.content_type, op->value.dtype())
        << "Vulkan only allows one type of access to the same buffer";
    spirv::Value index = MakeValue(op->index);
    spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
    builder_->MakeInst(spv::OpStore, ptr, value, mask);
  } else {
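    // Vector store: mirror the load path above. Matching element types are
    // scalarized lane by lane; otherwise only a unit-stride, aligned ramp index
    // can be stored as a whole vector.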
    if (op->value.dtype().element_of() == info.content_type) {
      // Because the content type is the element type, we can only do a scalarized store.
      auto f = [&](int i, spirv::Value index) {
        spirv::Value elem = builder_->MakeValue(spv::OpCompositeExtract, content_type, value, i);
        spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
        builder_->MakeInst(spv::OpStore, ptr, elem, mask);
      };
      this->Scalarize(op->index, f);
    } else {
      if (const RampNode* ramp = op->index.as<RampNode>()) {
        if (is_one(ramp->stride)) {
          CHECK_EQ(ramp->lanes, op->value.dtype().lanes());
          arith::ModularSet me = analyzer_->modular_set(ramp->base);
          CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0)
              << "Only aligned vector access is allowed in SPIRV";
          PrimExpr vec_index =
              analyzer_->Simplify(ramp->base / make_const(ramp->base.dtype(), ramp->lanes));
          spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, MakeValue(vec_index));
          builder_->MakeInst(spv::OpStore, ptr, value, mask);
          return;
        }
      }
      LOG(FATAL) << "Only aligned contiguous vector access is allowed in SPIRV";
    }
  }
}

void CodeGenSPIRV::VisitStmt_(const ForNode* op) {
  CHECK(is_zero(op->min));
  analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
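  // Emit a structured loop: the head block holds the induction-variable phi and
  // the exit test, the continue block increments the variable, and merge_label
  // is the loop exit.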
  spirv::Value init_value = MakeValue(op->min);
  spirv::Value extent_value = MakeValue(op->extent);
  // Must get the init label after making the values (to make sure they are correct)
  spirv::Label init_label = builder_->CurrentLabel();
  spirv::Label head_label = builder_->NewLabel();
  spirv::Label body_label = builder_->NewLabel();
  spirv::Label continue_label = builder_->NewLabel();
  spirv::Label merge_label = builder_->NewLabel();
  builder_->MakeInst(spv::OpBranch, head_label);

  // Loop head
  builder_->StartLabel(head_label);
  spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2);
  loop_var.SetIncoming(0, init_value, init_label);
  spirv::Value loop_cond = builder_->LT(loop_var, extent_value);
  uint32_t control =
      (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone);
  builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control);
  builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label,
                     weight_likely_branch_, 1);

  // loop body
  builder_->StartLabel(body_label);
  var_map_[op->loop_var.get()] = spirv::Value(loop_var);
  this->VisitStmt(op->body);
  builder_->MakeInst(spv::OpBranch, continue_label);

  // loop continue
  builder_->StartLabel(continue_label);
  spirv::Value one = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1)
                                                   : builder_->UIntImm(loop_var.stype, 1);
  spirv::Value next_value = builder_->Add(loop_var, one);
  loop_var.SetIncoming(1, next_value, builder_->CurrentLabel());
  builder_->MakeInst(spv::OpBranch, head_label);
  // loop merge
  builder_->StartLabel(merge_label);
}

void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
  spirv::Value cond = MakeValue(op->condition);
  spirv::Label then_label = builder_->NewLabel();
  spirv::Label merge_label = builder_->NewLabel();
  if (op->else_case.defined()) {
    spirv::Label else_label = builder_->NewLabel();
    builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
    // then block
    builder_->StartLabel(then_label);
    this->VisitStmt(op->then_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
    // else block
    builder_->StartLabel(else_label);
    this->VisitStmt(op->else_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
  } else {
    builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
    builder_->MakeInst(spv::OpBranchConditional, cond, then_label, merge_label,
                       weight_likely_branch_, 1);
    // then block
    builder_->StartLabel(then_label);
    this->VisitStmt(op->then_case);
    builder_->MakeInst(spv::OpBranch, merge_label);
  }
  // start merge label
  builder_->StartLabel(merge_label);
}

void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
  CHECK(!is_zero(op->condition));
  CHECK(!op->dtype.is_handle());
  int32_t constant_size = op->constant_allocation_size();
  CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
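  // Local-scope allocations live in the Function storage class; shared-scope
  // allocations live in the Workgroup storage class.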
  spirv::Value buf;
  StorageInfo& info = storage_info_[op->buffer_var.get()];
  spirv::SType etype = builder_->GetSType(op->dtype);
  if (info.scope.rank == runtime::StorageRank::kLocal) {
    buf =
        builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
  } else {
    // shared memory
    CHECK(info.scope.rank == runtime::StorageRank::kShared)
        << "Can only allocate shared or local memory inside kernel";
    buf =
        builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassWorkgroup);
  }
  CHECK(!info.content_fixed);
  info.UpdateContentType(op->dtype);
  CHECK(!var_map_.count(op->buffer_var.get()));
  var_map_[op->buffer_var.get()] = buf;
  this->VisitStmt(op->body);
}

void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
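  // thread_extent binds the thread/block index variable; storage_scope and
  // volatile_scope record per-buffer metadata consulted by loads and stores.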
  if (op->attr_key == tir::attr::thread_extent) {
    IterVar iv = Downcast<IterVar>(op->node);
    if (iv->thread_tag.length() != 0) {
      if (!var_map_.count(iv->var.get())) {
        var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
        analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
      }
    }
  } else if (op->attr_key == tir::attr::storage_scope) {
    const VarNode* v = op->node.as<VarNode>();
    CHECK(v);
    storage_info_[v].scope = runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
  } else if (op->attr_key == tir::attr::volatile_scope) {
    const VarNode* v = op->node.as<VarNode>();
    CHECK(v);
    storage_info_[v].is_volatile = true;
  }
  this->VisitStmt(op->body);
}

void CodeGenSPIRV::VisitStmt_(const AssertStmtNode* op) {
  With<arith::ConstraintContext> cctx(analyzer_.get(), op->condition);
  this->VisitStmt(op->body);
}

void CodeGenSPIRV::VisitStmt_(const LetStmtNode* op) {
  CHECK(!var_map_.count(op->var.get()));
  CHECK(!op->var.dtype().is_handle());
  var_map_[op->var.get()] = MakeValue(op->value);
  analyzer_->Bind(op->var, op->value);
  this->VisitStmt(op->body);
}

void CodeGenSPIRV::VisitStmt_(const SeqStmtNode* op) {
  for (Stmt stmt : op->seq) {
    this->VisitStmt(stmt);
  }
}

void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); }

}  // namespace codegen
}  // namespace tvm