/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_cuda.cc
 * \brief CUDA code generator.
 */

#include "codegen_cuda.h"

#include <tvm/runtime/registry.h>

#include <cmath>
#include <string>
#include <utility>
#include <vector>

#include "literal/cuda_half_t.h"

namespace tvm {
namespace codegen {

CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }

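// Reserve the names used by the software global barrier up front. The CHECK
// ensures the barrier state keeps its canonical symbol name, since the
// runtime locates that variable by this exact name.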
void CodeGenCUDA::Init(bool output_ssa) {
  CodeGenC::Init(output_ssa);
  vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state);
  vid_global_barrier_expect_ = GetUniqueName("__barrier_expect");
  CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}

void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }

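// Assemble the final kernel source. Feature flags toggled while visiting the
// IR decide which headers and helper definitions are prepended to the
// declaration stream. For example, a kernel using fp16 arithmetic and tensor
// core intrinsics gets a preamble of roughly
//   #include <cuda_fp16.h>   (guarded by __CUDA_ARCH__ >= 530, with a
//                             software half fallback otherwise)
//   #include <mma.h>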
std::string CodeGenCUDA::Finish() {
  if (enable_fp16_) {
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
    decl_stream << "#include <cuda_fp16.h>\n";
    decl_stream << "__device__ half max"
                << "(half a, half b)\n"
                << "{\n  return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
    decl_stream << "__device__ half min(half a, half b)\n"
                << "{\n  return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
    decl_stream << "#else\n";
    decl_stream << _cuda_half_t_def;
    decl_stream << "#endif\n\n";
    decl_stream << _cuda_half_util;
  }

  if (enable_warp_shuffle_) {
    decl_stream << _cuda_warp_intrinsic_util;
  }

  if (enable_int8_) {
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
    decl_stream << "#include <sm_61_intrinsics.h>\n";
    decl_stream << "#endif\n";
  }

  if (need_math_constants_h_) {
    decl_stream << "#include <math_constants.h>\n";
  }

  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }

  return CodeGenC::Finish();
}

void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
  CHECK(is_const_int(op->min, 0));
  if (op->for_type == tir::ForType::Unrolled) {
    PrintIndent();
    stream << "#pragma unroll\n";
  }
  CodeGenC::VisitStmt_(op);
}

void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
  CHECK(!var_idmap_.count(iv->var.get()));
  var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
}

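// Map a TVM DataType to the CUDA type name used in the generated source,
// for example:
//   float32  -> float          float32x4 -> float4
//   float16  -> half           float16x4 -> uint2 (two packed half2)
//   uint32   -> unsigned int   boolx4    -> ushort4 (CUDA has no bool vectors)
//   int8x4   -> int            (packed; avoids char4 repacking instructions)
// Sub-byte vectors (int4/int1) are likewise packed into 32-bit storage units.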
void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    CHECK_EQ(lanes, 1) << "do not yet support vector types";
    os << "void*";
    return;
  }
  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
      case 16:
        enable_fp16_ = true;
        if (lanes == 1) {
          os << "half";
        } else if (lanes <= 8) {
          // Emit CUDA code to access fp16 vector elements.
          //
          // half4 is stored as uint2
          //
          // h4.x is emitted as *(half2*)(&(u2.x)).x
          // h4.y is emitted as *(half2*)(&(u2.x)).y
          // h4.z is emitted as *(half2*)(&(u2.y)).x
          // h4.w is emitted as *(half2*)(&(u2.y)).y
          //
          CHECK_EQ(lanes % 2, 0) << "only support an even number of lanes for half type";
          os << "uint" << lanes / 2;
        } else {
          fail = true;
        }
        break;
      case 32:
        os << "float";
        break;
      case 64:
        os << "double";
        break;
      default:
        fail = true;
        break;
    }
    if (!fail && (lanes == 1 || t.bits() == 16)) return;
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent them instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      if (t.lanes() != 1) {
        os << "u";
      } else {
        os << "unsigned ";
      }
    }
    switch (t.bits()) {
      case 1: {
        if (t.lanes() == 1) {
          os << "int";
          return;
        } else if (t.lanes() == 8) {
          os << "int8_t";
          return;
        } else if (t.lanes() == 16) {
          os << "int16_t";
          return;
        } else if (t.lanes() == 32) {
          os << "int";
          return;
        } else {
          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
        }
      }
      case 4: {
        if (t.lanes() == 1) {
          os << "int";
          return;
        } else if (t.lanes() == 4) {
          os << "int16_t";
          return;
        } else if (t.lanes() == 8) {
          // Pack eight 4-bit ints into one 32-bit integer.
          os << "int";
          return;
        } else if (t.lanes() == 16) {
          os << "int2";
          return;
        } else if (t.lanes() == 32) {
          os << "int4";
          return;
        } else if (t.lanes() == 64) {
          os << "int8";
          return;
        } else {
          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
        }
      }
      case 8: {
        if (t.lanes() == 4) {
          // Pack four 8-bit ints into one 32-bit integer.
          enable_int8_ = true;

          // We use int for int8x4 instead of char4 because using char4 is
          // likely to produce extra instructions to pack four int8 elements
          // into 32-bit data.
          os << "int";
          return;
        } else if (t.lanes() == 8) {
          enable_int8_ = true;
          os << "int2";
          return;
        } else if (t.lanes() == 16) {
          enable_int8_ = true;
          os << "int4";
          return;
        } else if (!t.is_uint() && t.lanes() == 1) {
          os << "signed char";
          break;
        } else {
          os << "char";
          break;
        }
      }
      case 16:
        os << "short";
        break;
      case 32:
        os << "int";
        break;
      case 64: {
        if (sizeof(long) != 8) {  // NOLINT(*)
          if (t.lanes() == 1) {
            os << "long long";
            break;
          } else if (t.lanes() == 2) {
            os << "longlong";
            break;
          } else {
            // No longlong3, longlong4
            LOG(FATAL) << "Cannot convert type " << t
                       << " to CUDA type on a platform where long is not 64-bit";
            break;
          }
        } else {
          os << "long";
          break;
        }
      }
      default:
        fail = true;
        break;
    }
    if (!fail && lanes == 1) {
      return;
    }
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}

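// Vector binary ops have no direct CUDA expression, so they are scalarized:
// declare the result vector, then emit one scalar op per lane. For example,
// adding two float2 values a and b produces roughly
//   float2 _1;
//   _1.x = (a.x+b.x);
//   _1.y = (a.y+b.y);
// and the whole expression evaluates to _1.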
void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                   std::ostream& os) {  // NOLINT(*)
  // Declare the result.
  std::string sret = GetUniqueName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  {
    // Unpack into individual ops.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  os << sret;
}

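// Emit the expression that reads lane i of a vector value. Plain vectors use
// the .x/.y/.z/.w accessors. Packed storage needs special handling, e.g. for
// int8x4 (stored in an int) lane 2 reads as ((char)(v >> 16)), and for a
// half vector (stored as uintN) lane 1 reads as ((half2*)(&(v.x)))->y.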
void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
                                   std::ostream& os) {  // NOLINT(*)
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
  if ((t.is_int()) && t.bits() == 8) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      os << "((char)(" << vec << " >> " << i * 8 << "))";
    }
  } else if ((t.is_uint()) && t.bits() == 8) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
    }
  } else if (t.is_float16()) {
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
  } else {
    os << vec << "." << access[i];
  }
}

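// Counterpart of PrintVecElemLoad: emit a statement that writes lane i of a
// vector. Packed int8 vectors are updated with a mask-and-or so the other
// lanes are preserved; e.g. lane 1 of an int8x4 value v becomes roughly
//   v = v & ~(0x000000ff << 8) | (value << 8);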
void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
                                    const std::string& value) {
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
    } else {
      stream << vec << "=";
      // For lane 0 the vector is still undefined, so do not read the old value.
      if (i != 0) {
        stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
      }
      stream << "(" << value << " << " << i * 8 << ");\n";
    }
  } else if (t.is_float16()) {
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
           << value << ";\n";
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}

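// Lower the storage sync intrinsic. "warp" is a no-op, "shared" becomes
// __syncthreads(), and "global" expands to a software barrier: participating
// blocks atomically increment a global counter, then spin until it reaches
// the expected value before re-synchronizing the block.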
void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
  const std::string& sync = op->args[0].as<StringImmNode>()->value;
  if (sync == "warp") {
    // Do nothing.
  } else if (sync == "shared") {
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  } else if (sync == "global") {
    if (!need_global_barrier_) {
      need_global_barrier_ = true;
      this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_
                        << ";\n";
    }
    // global synchronizer
    std::string is_load = PrintExpr(op->args[1]);
    std::string num_blocks = PrintExpr(op->args[2]);
    this->PrintIndent();
    // In theory only a threadfence is needed here, but we observed
    // problems when using a threadfence alone.
    this->stream << "__threadfence_system();\n";
    this->PrintIndent();
    this->stream << "if (" << is_load << ") {\n";
    int wb = this->BeginScope();
    this->PrintIndent();
    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
    this->PrintIndent();
    std::string ptr = GetUniqueName("pf");
    this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n";
    this->PrintIndent();
    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
    this->PrintIndent();
    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n";
    this->EndScope(wb);
    this->PrintIndent();
    this->stream << "}\n";
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  }
}

void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) {  // NOLINT(*)
  CHECK_NE(scope, "global");
  if (scope == "shared") {
    os << "__shared__ ";
  }
}

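// Scalar casts are handled by CodeGenC; vector casts are expanded lane by
// lane. Casting a float2 value v to int2, for instance, emits roughly
//   int2 _1;
//   _1.x = (int)(v.x);
//   _1.y = (int)(v.y);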
void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  CHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
  if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);

  // We could emit make_float4-like calls, but the emitted code would be
  // too compact to read. Emit this as per-lane unary casts instead.
  std::string sret = GetUniqueName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
  {
    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
      std::ostringstream val;
      val << "(";
      PrintType(target_ty.element_of(), val);
      val << ")(";
      PrintVecElemLoad(src, from_ty, i, val);
      val << ")";
      PrintVecElemStore(sret, target_ty, i, val.str());
    }
  }
  os << sret;
}

void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
                                  bool skip_first_arg, std::ostream& os) {  // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (ret_dtype.is_vector()) {
    //
    // Emit a call with vector arguments, which CUDA does not support
    // natively. For example,
    //
    // v = intrin_f((float4*)A[0], (float4*)B[0])
    //
    // is emitted as
    //
    // float4 __ret;
    // {
    //   float4 __arg0 = ((float4*)A)[0];
    //   float4 __arg1 = ((float4*)B)[0];
    //   __ret.x = intrin_f(__arg0.x, __arg1.x);
    //   __ret.y = intrin_f(__arg0.y, __arg1.y);
    //   __ret.z = intrin_f(__arg0.z, __arg1.z);
    //   __ret.w = intrin_f(__arg0.w, __arg1.w);
    // }
    // v = __ret;
    //
    // Declare the result vector.
    std::string sret = GetUniqueName("_");
    this->PrintIndent();
    this->PrintType(ret_dtype, stream);
    stream << ' ' << sret << ";\n";
    {
      // Load arguments.
      std::vector<std::string> sargs;
      size_t arg_begin = static_cast<size_t>(skip_first_arg);
      for (size_t i = arg_begin; i < args.size(); ++i) {
        std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
        sargs.push_back(std::move(val));
      }

      // Emit a scalar call for each lane.
      for (int i = 0; i < ret_dtype.lanes(); ++i) {
        std::ostringstream scall;
        scall << global_symbol << "(";
        for (size_t j = 0; j < sargs.size(); ++j) {
          if (j > 0) scall << ", ";
          PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
        }
        scall << ")";
        PrintVecElemStore(sret, ret_dtype, i, scall.str());
      }
    }
    os << sret;
  } else {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
  }
}

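// Lower builtin calls. Warp-shuffle ops set a flag so Finish() emits the
// compatibility macros, and the tensor core builtins map onto nvcuda::wmma
// API calls, e.g. tvm_fill_fragment(buf, m, n, k, index, value) becomes
//   nvcuda::wmma::fill_fragment(buf[index], value)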
void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
  if (auto* ptr_op = op->op.as<OpNode>()) {
    Op call_op = GetRef<Op>(ptr_op);
    // This is only for backward compatibility with __shfl_{up/down}.
    // A macro will be used to replace *_sync calls with the legacy ones.
    if (op_need_warp_shuffle_.get(call_op, false)) {
      enable_warp_shuffle_ = true;
    }
  }

  if (op->op.same_as(builtin::tvm_fill_fragment())) {
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
    if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}

void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
  if (op->attr_key == tir::attr::fragment_shape) {
    const VarNode* buffer = op->node.as<VarNode>();
    const StringImmNode* shape_str = op->value.as<StringImmNode>();
    fragment_shapes[buffer] = shape_str->value;
  } else if (op->attr_key == tir::attr::fragment_layout) {
    const VarNode* buffer = op->node.as<VarNode>();
    const StringImmNode* layout_str = op->value.as<StringImmNode>();
    fragment_layouts[buffer] = layout_str->value;
  }
  CodeGenC::VisitStmt_(op);
}

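// Emit a statically sized allocation. Buffers in wmma.* scopes become arrays
// of wmma fragments, with the element count rescaled by the fragment tile
// size. Sub-byte types (int4/uint4/int1) in shared memory are packed into
// 32-bit words, so e.g. 1024 int4 elements are stored as int[128].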
void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
  CHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());

  this->PrintIndent();
  int32_t constant_size = op->constant_allocation_size();
  CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
  const VarNode* buffer = op->buffer_var.as<VarNode>();
  std::string scope = alloc_storage_scope_.at(buffer);
  if (scope.find("wmma.") == 0) {
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
      CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
            op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
            op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1))
          << "matrix_a and matrix_b only support half, char, unsigned char, "
          << "int4, uint4, and int1 types for now";
    } else {
      CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
            op->dtype == DataType::Int(32))
          << "Accumulator only supports half, float, and int types for now";
    }
    constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    PrintWmmaScope(scope, op->dtype, buffer, stream);
  } else {
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }
  if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
       op->dtype == DataType::Int(1)) &&
      scope == "shared") {
    constant_size = constant_size / (32 / op->dtype.bits());
  }
  stream << ' ' << vid << '[' << constant_size << "];\n";

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}

void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) return;
  const CallNode* call = op->value.as<CallNode>();
  if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
    PrintIndent();
    stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
    PrintIndent();
    stream << "if (threadIdx.x == 0) {\n";
    PrintIndent();
    stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
    PrintIndent();
    stream << "}\n";
  } else {
    CodeGenC::VisitStmt_(op);
  }
}

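// A ramp base + i * stride is materialized through an intN constructor, e.g.
// for lanes == 2 this prints ((make_int2)((base)+(stride*0), (base)+(stride*1))).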
void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
  os << "((make_int" << op->lanes << ")(";
  for (int i = 0; i < op->lanes; i++) {
    os << "(" << PrintExpr(op->base) << ")"
       << "+(" << PrintExpr(op->stride) << "*" << i << ")";
    if (i != op->lanes - 1) os << ", ";
  }
  os << "))";
}

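// Broadcast a scalar into a vector. Constant int8x4 broadcasts fold into a
// single packed 32-bit immediate; half vectors are built from __pack_half2
// pairs; everything else uses the make_<type>(v, ..., v) constructor.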
void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) {
    // make_int8x4
    const int64_t* p = as_const_int(op->value);
    CHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < op->lanes / 2; ++i) {
      if (i != 0) os << ", ";
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < op->lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
  }
  os << ')';
}

void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
  std::vector<std::string> to_shuffle(op->vectors.size());
  for (int i = 0, e = op->vectors.size(); i < e; ++i) {
    CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
    to_shuffle[i] = PrintExpr(op->vectors[i]);
  }
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0, e = op->indices.size(); i < e; ++i) {
    const int64_t* val = as_const_int(op->indices[i]);
    CHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
    if (i != 0) os << ", ";
    os << to_shuffle[*val];
  }
  os << ')';
}

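// Vector select has no single CUDA expression, so it is serialized into one
// ternary per lane. The condition vector is reloaded through the ushort
// representation that PrintType assigns to bool vectors.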
void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
  // Non-vector cases.
  if (!op->dtype.is_vector()) {
    CodeGenC::VisitExpr_(op, os);
    return;
  }

  // Codegen vector condition case by serializing the select op.
  CHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype &&
        op->dtype.lanes() == op->condition.dtype().lanes());

  std::string r_var = GetUniqueName("_");
  this->PrintIndent();
  this->PrintType(op->dtype, stream);
  stream << ' ' << r_var << ";\n";
  {
    std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
    std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
    std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);

    // The condition is stored as a ushort vector.
    int lanes = op->dtype.lanes();
    DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);

    for (int i = 0; i < lanes; ++i) {
      std::ostringstream item;
      item << "(bool(";
      PrintVecElemLoad(c_var, memory_ty, i, item);
      item << ")?";
      PrintVecElemLoad(t_var, op->dtype, i, item);
      item << ':';
      PrintVecElemLoad(f_var, op->dtype, i, item);
      item << ')';
      PrintVecElemStore(r_var, op->dtype, i, item.str());
    }
  }
  os << r_var;
}

inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) {  // NOLINT(*)
  switch (op->dtype.bits()) {
    case 64:
    case 32: {
      std::ostringstream temp;
      if (std::isinf(op->value)) {
        if (op->value < 0) {
          temp << "-";
        }
        temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
        p->need_math_constants_h_ = true;
      } else if (std::isnan(op->value)) {
        temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
        p->need_math_constants_h_ = true;
      } else {
        temp << std::scientific << op->value;
        if (op->dtype.bits() == 32) temp << 'f';
      }
      p->MarkConst(temp.str());
      os << temp.str();
      break;
    }
    case 16: {
      os << "__float2half_rn";
      os << '(' << std::scientific << op->value << 'f' << ')';
      break;
    }
    default:
      LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}

void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
  PrintConst(op, os, this);
}

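// Print the nvcuda::wmma::fragment type for a buffer in a wmma.* scope. For
// example, a half matrix_a buffer with fragment shape "16, 16, 16" and
// row_major layout prints as
//   nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half,
//                          nvcuda::wmma::row_major>
// Sub-byte integer types map to the experimental precision tags (s4/u4/b1).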
void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
                                 std::ostream& os) {
  std::stringstream type;
  PrintType(t, type);
  std::string shape_str = fragment_shapes[variable];
  if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
    type.str(std::string());
    if (t.is_int()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::s4";
      } else if (t.bits() == 1) {
        type << "nvcuda::wmma::experimental::precision::b1";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    } else if (t.is_uint()) {
      if (t.bits() == 4) {
        type << "nvcuda::wmma::experimental::precision::u4";
      } else {
        LOG(FATAL) << "Unhandled integer type for wmma fragment!";
      }
    }
  }
  if (scope == "wmma.matrix_a") {
    need_mma_h_ = true;
    std::string layout_str = fragment_layouts[variable];
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
       << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.matrix_b") {
    need_mma_h_ = true;
    std::string layout_str = fragment_layouts[variable];
    os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
       << ", nvcuda::wmma::" << layout_str << ">";
  } else if (scope == "wmma.accumulator") {
    need_mma_h_ = true;
    os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str()
       << ">";
  }
}

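// Convert a flat element count into a fragment count by parsing the "m, n, k"
// shape string and dividing by the tile footprint of the given scope. For
// example, a 16x16x16 matrix_a buffer holding 4096 elements yields
// 4096 / 16 / 16 = 16 fragments.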
int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
                                         int32_t size) {
  std::string shape_str = fragment_shapes[variable];
  size_t m, n, k;
  size_t last_pos = 0, pos = 0;
  pos = shape_str.find(", ", last_pos);
  m = std::stoi(shape_str.substr(last_pos, pos - last_pos));
  last_pos = pos + 2;
  pos = shape_str.find(", ", last_pos);
  n = std::stoi(shape_str.substr(last_pos, pos - last_pos));
  last_pos = pos + 2;
  k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
  if (scope == "wmma.matrix_a") {
    return size / m / k;
  } else if (scope == "wmma.matrix_b") {
    return size / n / k;
  } else if (scope == "wmma.accumulator") {
    return size / m / n;
  }
  return 0;
}

void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op,
                                      std::ostream& os) {
  // Cast away the volatile qualifier for fp16 types: only the loads and
  // stores themselves are volatile; the loaded values are not.
  if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
    os << "(";
    PrintType(op->dtype, os);
    os << ")(" << value << ")";
  } else {
    os << value;
  }
}

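// Build a vector value lane by lane as a single expression. Lane 0 opens a
// make_<type>( constructor and the last lane closes it. Packed int8 vectors
// are instead assembled by OR-ing shifted bytes, and half vectors pair up
// adjacent lanes with __pack_half2.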
void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
                                       std::ostream& os) {
  CHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      if (i != 0) {
        os << "|";
      }
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
      return;
    }
  }

  if (t.is_float16()) {
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_half2(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << "(";
  }
  os << value;
  if (i != t.lanes() - 1) {
    os << ",";
  } else {
    os << ")";
  }
  return;
}

}  // namespace codegen
}  // namespace tvm