1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file codegen_cuda.cc
22 */
23
24 #include "codegen_cuda.h"
25
26 #include <tvm/runtime/registry.h>
27
28 #include <cmath>
29 #include <string>
30 #include <utility>
31 #include <vector>
32
33 #include "literal/cuda_half_t.h"
34
35 namespace tvm {
36 namespace codegen {
37
CodeGenCUDA()38 CodeGenCUDA::CodeGenCUDA() { restrict_keyword_ = "__restrict__"; }
39
void CodeGenCUDA::Init(bool output_ssa) {
  // Initialize the base C generator, then reserve the identifiers used by the
  // global-barrier lowering (see PrintStorageSync and the
  // tvm_global_barrier_kinit handling in VisitStmt_(EvaluateNode)).
  CodeGenC::Init(output_ssa);
  vid_global_barrier_state_ = GetUniqueName(runtime::symbol::tvm_global_barrier_state);
  vid_global_barrier_expect_ = GetUniqueName("__barrier_expect");
  // The barrier-state symbol is declared extern "C" and looked up by name, so
  // uniquification must not have renamed it (i.e. it must be first to claim it).
  CHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
}
46
PrintFuncPrefix()47 void CodeGenCUDA::PrintFuncPrefix() { stream << "extern \"C\" __global__ void"; }
48
std::string CodeGenCUDA::Finish() {
  // Emit any support declarations that code generation flagged as needed.
  // These go into decl_stream so they precede the kernel definitions.
  if (enable_fp16_) {
    // sm_53+ has native half support; provide half max/min on top of the
    // comparison intrinsics. Older architectures get the software half
    // emulation from cuda_half_t.h instead.
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)\n";
    decl_stream << "#include <cuda_fp16.h>\n";
    decl_stream << "__device__ half max"
                << "(half a, half b)\n"
                << "{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
    decl_stream << "__device__ half min(half a, half b)\n"
                << "{\n return __hlt(__half(a), __half(b)) ? a : b;\n}\n";
    decl_stream << "#else\n";
    decl_stream << _cuda_half_t_def;
    decl_stream << "#endif\n\n";
    decl_stream << _cuda_half_util;
  }

  // Compatibility macros for legacy __shfl_{up,down} usage.
  if (enable_warp_shuffle_) {
    decl_stream << _cuda_warp_intrinsic_util;
  }

  // __dp4a and friends live in sm_61_intrinsics.h and only exist on sm_61+.
  if (enable_int8_) {
    decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
    decl_stream << "#include <sm_61_intrinsics.h>\n";
    decl_stream << "#endif\n";
  }

  // CUDART_INF / CUDART_NAN literals (see PrintConst).
  if (need_math_constants_h_) {
    decl_stream << "#include <math_constants.h>\n";
  }

  // nvcuda::wmma tensor-core API.
  if (need_mma_h_) {
    decl_stream << "#include <mma.h>\n";
  }

  return CodeGenC::Finish();
}
84
VisitStmt_(const tir::ForNode * op)85 void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) {
86 CHECK(is_const_int(op->min, 0));
87 if (op->for_type == tir::ForType::Unrolled) {
88 PrintIndent();
89 stream << "#pragma unroll\n";
90 }
91 CodeGenC::VisitStmt_(op);
92 }
93
BindThreadIndex(const IterVar & iv)94 void CodeGenCUDA::BindThreadIndex(const IterVar& iv) {
95 CHECK(!var_idmap_.count(iv->var.get()));
96 var_idmap_[iv->var.get()] = CastFromTo(iv->thread_tag, DataType::UInt(32), iv->var.dtype());
97 }
98
// Print the CUDA spelling of a TVM DataType (scalar or vector) into os.
// Side effects: sets enable_fp16_ / enable_int8_ when such types are seen so
// Finish() can emit the matching support headers. Fatals on unsupported types.
void CodeGenCUDA::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
  int lanes = t.lanes();
  if (t.is_handle()) {
    CHECK_EQ(lanes, 1) << "do not yet support vector types";
    os << "void*";
    return;
  }
  bool fail = false;
  if (t.is_float()) {
    switch (t.bits()) {
      case 16:
        enable_fp16_ = true;
        if (lanes == 1) {
          os << "half";
        } else if (lanes <= 8) {
          // Emit CUDA code to access fp16 vector elements.
          //
          // half4 is stored as uint2
          //
          // h4.x is emitted as *(half2*)(&(u2.x)).x
          // h4.y is emitted as *(half2*)(&(u2.x)).y
          // h4.z is emitted as *(half2*)(&(u2.y)).x
          // h4.w is emitted as *(half2*)(&(u2.y)).y
          //
          CHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
          os << "uint" << lanes / 2;
        } else {
          fail = true;
        }
        break;
      case 32:
        os << "float";
        break;
      case 64:
        os << "double";
        break;
      default:
        fail = true;
        break;
    }
    // fp16 vectors were already fully spelled above (uintN); float/double
    // vectors of 2-4 lanes get the lane count appended (float4, double2, ...).
    if (!fail && (lanes == 1 || t.bits() == 16)) return;
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  } else if (t == DataType::Bool()) {
    os << "bool";
    return;
  } else if (t.is_vector_bool()) {
    // CUDA does not support bool vectors.
    // Use ushort vectors to represent instead.
    int n = t.lanes();
    if (n <= 4) {
      os << "ushort" << n;
      return;
    }
  } else if (t.is_uint() || t.is_int()) {
    if (t.is_uint()) {
      // Vector uints use the "u" prefixed CUDA vector types (uint4 etc.);
      // scalars use the plain C "unsigned" keyword spelling.
      if (t.lanes() != 1) {
        os << "u";
      } else {
        os << "unsigned ";
      }
    }
    switch (t.bits()) {
      case 1: {
        // Sub-byte 1-bit ints are packed into wider integer storage.
        if (t.lanes() == 1) {
          os << "int";
          return;
        } else if (t.lanes() == 8) {
          os << "int8_t";
          return;
        } else if (t.lanes() == 16) {
          os << "int16_t";
          return;
        } else if (t.lanes() == 32) {
          os << "int";
          return;
        } else {
          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
        }
      }
      case 4: {
        // 4-bit ints likewise pack into 32-bit words / int vectors.
        if (t.lanes() == 1) {
          os << "int";
          return;
        } else if (t.lanes() == 4) {
          os << "int16_t";
          return;
        } else if (t.lanes() == 8) {
          // directly 8 4-bit int in integer.
          os << "int";
          return;
        } else if (t.lanes() == 16) {
          os << "int2";
          return;
        } else if (t.lanes() == 32) {
          os << "int4";
          return;
        } else if (t.lanes() == 64) {
          os << "int8";
          return;
        } else {
          LOG(FATAL) << "Cannot convert type " << t << " to CUDA type!";
        }
      }
      case 8: {
        if (t.lanes() == 4) {
          // directly 4 8 bit int in integer.
          enable_int8_ = true;

          // We use int for int8x4 instead of char4 because using char4 is
          // likely to produce extra instructions to pack four int8 elements
          // into 32-bit data.
          os << "int";
          return;
        } else if (t.lanes() == 8) {
          enable_int8_ = true;
          os << "int2";
          return;
        } else if (t.lanes() == 16) {
          enable_int8_ = true;
          os << "int4";
          return;
        } else if (!t.is_uint() && t.lanes() == 1) {
          os << "signed char";
          break;
        } else {
          os << "char";
          break;
        }
      }
      case 16:
        os << "short";
        break;
      case 32:
        os << "int";
        break;
      case 64: {
        // On LP32-ish hosts "long" is not 8 bytes, so spell it "long long";
        // CUDA only provides longlong1/longlong2 vector types.
        if (sizeof(long) != 8) {  // NOLINT(*)
          if (t.lanes() == 1) {
            os << "long long";
            break;
          } else if (t.lanes() == 2) {
            os << "longlong";
            break;
          } else {
            // No longlong3, longlong4
            LOG(FATAL) << "Cannot convert type " << t << " to CUDA type on a L32 platform";
            break;
          }
        } else {
          os << "long";
          break;
        }
      }
      default:
        fail = true;
        break;
    }
    if (!fail && lanes == 1) {
      return;
    }
    // Append the lane count for 2-4 lane integer vectors (int2, uint4, ...).
    if (!fail && (lanes >= 2 && lanes <= 4)) {
      os << lanes;
      return;
    }
  }
  LOG(FATAL) << "Cannot convert type " << t << " to CUDA type";
}
269
// Emit a vector binary op by declaring a result vector and computing it one
// lane at a time. "op" is either an operator token ("+", "*", ...) or a
// function name ("max", ...) — distinguished by whether it starts with a letter.
void CodeGenCUDA::PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
                                   std::ostream& os) {  // NOLINT(*)
  // Declare the result.
  std::string sret = GetUniqueName("_");
  this->PrintIndent();
  this->PrintType(t, stream);
  stream << ' ' << sret << ";\n";
  {
    // Unpack into individual ops. SSAGetID ensures each operand expression is
    // evaluated once and bound to a name before the per-lane loads.
    std::string vlhs = SSAGetID(PrintExpr(lhs), lhs.dtype());
    std::string vrhs = SSAGetID(PrintExpr(rhs), rhs.dtype());

    for (int i = 0, lanes = t.lanes(); i < lanes; ++i) {
      std::ostringstream value_temp;
      if (isalpha(op[0])) {
        // Function-style: op(lhs_i, rhs_i)
        value_temp << op << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << ", ";
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      } else {
        // Infix style: (lhs_i OP rhs_i)
        value_temp << "(";
        PrintVecElemLoad(vlhs, lhs.dtype(), i, value_temp);
        value_temp << op;
        PrintVecElemLoad(vrhs, rhs.dtype(), i, value_temp);
        value_temp << ")";
      }
      PrintVecElemStore(sret, t, i, value_temp.str());
    }
  }
  os << sret;
}
302
// Emit an expression that reads lane i out of vector value `vec` of type t.
// Mirrors the storage choices made in PrintType: int8x4 lives in an int
// (extracted by shifting), fp16 vectors live in uintN (reinterpreted as half2).
void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
                                   std::ostream& os) {  // NOLINT(*)
  if (t.is_scalar()) {
    os << vec;
    return;
  }

  static const char access[] = {'x', 'y', 'z', 'w'};
  // fp16 vectors may have up to 8 lanes (stored as uint4); others at most 4.
  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
  if ((t.is_int()) && t.bits() == 8) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      // char2/char3 storage keeps real .x/.y/.z members.
      os << vec << "." << access[i % t.lanes()];
    } else {
      // int8x4 is packed into a 32-bit int; shift the lane down and truncate.
      os << "((char)(" << vec << " >> " << i * 8 << "))";
    }
  } else if ((t.is_uint()) && t.bits() == 8) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      os << vec << "." << access[i % t.lanes()];
    } else {
      os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
    }
  } else if (t.is_float16()) {
    // Reinterpret the backing uint member as half2 and pick the half in it.
    os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
  } else {
    os << vec << "." << access[i];
  }
}
330
// Emit a statement that stores `value` into lane i of vector variable `vec`.
// Inverse of PrintVecElemLoad; uses the same packed storage conventions.
void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
                                    const std::string& value) {
  this->PrintIndent();
  static const char access[] = {'x', 'y', 'z', 'w'};
  CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (t.lanes() == 2 || t.lanes() == 3) {
      stream << vec << '.' << access[i % t.lanes()] << "="
             << "(" << value << ");\n";
    } else {
      // Packed int8x4 in an int: mask out the target byte, then OR in the
      // shifted value. Lanes are assumed to be written in order starting at 0.
      stream << vec << "=";
      // Do not read the first undef lane.
      if (i != 0) {
        stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
      }
      stream << "(" << value << " << " << i * 8 << ");\n";
    }
  } else if (t.is_float16()) {
    // Write through a half2 view of the backing uint member.
    stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
           << value << ";\n";
  } else {
    stream << vec << "." << access[i] << " = " << value << ";\n";
  }
}
355
// Emit the synchronization code for tvm_storage_sync("warp"|"shared"|"global").
// args[0] is the scope string; for "global", args[1] selects the participating
// condition and args[2] is the number of blocks to wait for.
void CodeGenCUDA::PrintStorageSync(const CallNode* op) {
  const std::string& sync = op->args[0].as<StringImmNode>()->value;
  if (sync == "warp") {
    // DO nothing.
  } else if (sync == "shared") {
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  } else if (sync == "global") {
    if (!need_global_barrier_) {
      need_global_barrier_ = true;
      // The barrier counter is a device global, declared once and shared with
      // the host runtime via its well-known symbol name (see Init).
      this->decl_stream << "extern \"C\" __device__ unsigned " << vid_global_barrier_state_
                        << ";\n";
    }
    // global synchronizer
    std::string is_load = PrintExpr(op->args[1]);
    std::string num_blocks = PrintExpr(op->args[2]);
    this->PrintIndent();
    // In theory only threadfence is needed
    // but we observed problems with only threadfence
    this->stream << "__threadfence_system();\n";
    this->PrintIndent();
    this->stream << "if (" << is_load << ") {\n";
    int wb = this->BeginScope();
    this->PrintIndent();
    // Each participating thread bumps the global counter, then spins on a
    // volatile view of it until all expected blocks have arrived.
    this->stream << "atomicAdd(&" << vid_global_barrier_state_ << ", 1);\n";
    this->PrintIndent();
    std::string ptr = GetUniqueName("pf");
    this->stream << "volatile unsigned* " << ptr << " = &" << vid_global_barrier_state_ << ";\n";
    this->PrintIndent();
    this->stream << vid_global_barrier_expect_ << " += " << num_blocks << ";\n";
    this->PrintIndent();
    this->stream << "while (" << ptr << "[0] < " << vid_global_barrier_expect_ << ");\n";
    this->EndScope(wb);
    this->PrintIndent();
    this->stream << "}\n";
    this->PrintIndent();
    this->stream << "__syncthreads();\n";
  }
}
395
PrintStorageScope(const std::string & scope,std::ostream & os)396 void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) { // NOLINT(*)
397 CHECK_NE(scope, "global");
398 if (scope == "shared") {
399 os << "__shared__ ";
400 }
401 }
402
// Emit a cast. Scalars use the plain C-style conversion from CodeGenC; vector
// casts are serialized into per-lane scalar conversions into a temp vector.
void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
  DataType from_ty = op->value.dtype();
  DataType target_ty = op->dtype;
  CHECK_EQ(target_ty.lanes(), from_ty.lanes());

  // Emit simple C-style type conversion.
  if (from_ty.is_scalar()) return CodeGenC::VisitExpr_(op, os);

  // We could emit make_float4 like calls, but the emitted code looks
  // too compact to read. Emit this as vectorized unary ops.
  std::string sret = GetUniqueName("_");
  this->PrintIndent();
  this->PrintType(target_ty, stream);
  stream << ' ' << sret << ";\n";
  {
    // Bind the source expression once, then convert lane by lane.
    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
      std::ostringstream val;
      val << "(";
      PrintType(target_ty.element_of(), val);
      val << ")(";
      PrintVecElemLoad(src, from_ty, i, val);
      val << ")";
      PrintVecElemStore(sret, target_ty, i, val.str());
    }
  }
  os << sret;
}
431
// Emit a call to an extern function. Scalar-returning calls defer to CodeGenC;
// vector-returning calls are scalarized lane by lane since CUDA math
// intrinsics have no vector overloads.
void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
                                  bool skip_first_arg, std::ostream& os) {  // NOLINT(*)
  DataType ret_dtype = GetRuntimeDataType(ret_type);
  if (ret_dtype.is_vector()) {
    //
    // Emit an unsupported vector call
    //
    // v = intrin_f((float4*)A[0], (float4*)B[0])
    //
    // as
    //
    // float4 __ret;
    // {
    //   float4 __arg0 = ((float4*)A)[0];
    //   float4 __arg1 = ((float4*)B)[0];
    //   __ret.x = intrin_f(__arg0.x, __arg1.x);
    //   __ret.y = intrin_f(__arg0.y, __arg1.y);
    //   __ret.z = intrin_f(__arg0.z, __arg1.z);
    //   __ret.w = intrin_f(__arg0.w, __arg1.w);
    // }
    // v = __ret;
    //
    // Declare the result vector.
    std::string sret = GetUniqueName("_");
    this->PrintIndent();
    this->PrintType(ret_dtype, stream);
    stream << ' ' << sret << ";\n";
    {
      // Load arguments. When skip_first_arg is set the first arg is the
      // callee symbol itself and is not a runtime argument.
      std::vector<std::string> sargs;
      size_t arg_begin = static_cast<size_t>(skip_first_arg);
      for (size_t i = arg_begin; i < args.size(); ++i) {
        std::string val = SSAGetID(PrintExpr(args[i]), args[i].dtype());
        sargs.push_back(std::move(val));
      }

      // Emit a scalar call for each lane.
      for (int i = 0; i < ret_dtype.lanes(); ++i) {
        std::ostringstream scall;
        scall << global_symbol << "(";
        for (size_t j = 0; j < sargs.size(); ++j) {
          if (j > 0) scall << ", ";
          PrintVecElemLoad(sargs[j], args[arg_begin + j].dtype(), i, scall);
        }
        scall << ")";
        PrintVecElemStore(sret, ret_dtype, i, scall.str());
      }
    }
    os << sret;
  } else {
    CodeGenC::PrintCallExtern(ret_type, global_symbol, args, skip_first_arg, os);
  }
}
485
// Emit intrinsic calls. The tensor-core (wmma) builtins are translated to the
// corresponding nvcuda::wmma:: API calls; everything else falls back to
// CodeGenC. All wmma builtins address a fragment as args[0][args[4]].
void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
  if (auto* ptr_op = op->op.as<OpNode>()) {
    Op call_op = GetRef<Op>(ptr_op);
    // This is only for backward compatibility with __shfl_{up/down}.
    // A macro will be used to replace *_sync calls to legacy ones.
    if (op_need_warp_shuffle_.get(call_op, false)) {
      enable_warp_shuffle_ = true;
    }
  }

  if (op->op.same_as(builtin::tvm_fill_fragment())) {
    // fill_fragment(fragment[index], value)
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 6U);
    os << "nvcuda::wmma::fill_fragment(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) {
    // load_matrix_sync(fragment[index], ptr, stride)
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::load_matrix_sync(";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[6], os);
    os << ")";
  } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) {
    // store_matrix_sync(ptr, fragment[index], stride, layout)
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::store_matrix_sync(";
    this->PrintExpr(op->args[5], os);
    os << ", ";
    this->PrintExpr(op->args[0], os);
    os << "[";
    this->PrintExpr(op->args[4], os);
    os << "], ";
    this->PrintExpr(op->args[6], os);
    // args[7] must be a compile-time layout string, mapped to the
    // nvcuda::wmma::mem_* enumerators.
    if (const StringImmNode* str = op->args[7].as<StringImmNode>()) {
      os << ", nvcuda::wmma::mem_" << str->value;
    } else {
      LOG(FATAL) << "Invalid parameters";
    }
    os << ")";
  } else if (op->op.same_as(builtin::tvm_mma_sync())) {
    // mma_sync(d[di], a[ai], b[bi], c[ci]) — args come as 4 (buffer, index) pairs.
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::mma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else if (op->op.same_as(builtin::tvm_bmma_sync())) {
    // bmma_sync: bit-level variant, same (buffer, index) pair layout as mma_sync.
    need_mma_h_ = true;
    CHECK_EQ(op->args.size(), 8U);
    os << "nvcuda::wmma::bmma_sync(";
    for (int i = 0; i < 4; ++i) {
      this->PrintExpr(op->args[i * 2], os);
      os << "[";
      this->PrintExpr(op->args[i * 2 + 1], os);
      os << "]" << ((i < 3) ? ", " : ")");
    }
  } else {
    CodeGenC::VisitExpr_(op, os);
  }
}
559
VisitStmt_(const AttrStmtNode * op)560 void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
561 if (op->attr_key == tir::attr::fragment_shape) {
562 const VarNode* buffer = op->node.as<VarNode>();
563 const StringImmNode* shape_str = op->value.as<StringImmNode>();
564 fragment_shapes[buffer] = shape_str->value;
565 } else if (op->attr_key == tir::attr::fragment_layout) {
566 const VarNode* buffer = op->node.as<VarNode>();
567 const StringImmNode* layout_str = op->value.as<StringImmNode>();
568 fragment_layouts[buffer] = layout_str->value;
569 }
570 CodeGenC::VisitStmt_(op);
571 }
572
// Emit a stack/shared/wmma-fragment allocation. Only constant-sized
// allocations are supported; wmma scopes become fragment arrays whose length
// is the element count divided by the fragment tile size.
void CodeGenCUDA::VisitStmt_(const AllocateNode* op) {
  CHECK(!is_zero(op->condition));
  std::string vid = AllocVarID(op->buffer_var.get());

  this->PrintIndent();
  int32_t constant_size = op->constant_allocation_size();
  CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now";
  const VarNode* buffer = op->buffer_var.as<VarNode>();
  std::string scope = alloc_storage_scope_.at(buffer);
  if (scope.find("wmma.") == 0) {
    // Tensor-core fragments: validate the dtype against what the wmma API
    // accepts for each fragment role, then print the fragment type.
    if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") {
      CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) ||
            op->dtype == DataType::UInt(8) || op->dtype == DataType::Int(4) ||
            op->dtype == DataType::UInt(4) || op->dtype == DataType::Int(1))
          << "Matrix_a and matrix_b only support half or char or unsigned char "
          << "or uint4 or int4 or int1 type for now";
    } else {
      CHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Float(32) ||
            op->dtype == DataType::Int(32))
          << "Accumulator only support half, float and int type for now";
    }
    // Convert element count into a count of whole fragments.
    constant_size = GetWmmaFragmentSize(scope, buffer, constant_size);
    PrintWmmaScope(scope, op->dtype, buffer, stream);
  } else {
    PrintStorageScope(scope, stream);
    PrintType(op->dtype, stream);
  }
  // Sub-byte shared allocations are packed 32 bits per word, so shrink the
  // declared array length accordingly.
  if ((op->dtype == DataType::Int(4) || op->dtype == DataType::UInt(4) ||
       op->dtype == DataType::Int(1)) &&
      scope == "shared") {
    constant_size = constant_size / (32 / op->dtype.bits());
  }
  stream << ' ' << vid << '[' << constant_size << "];\n";

  RegisterHandleType(op->buffer_var.get(), op->dtype);
  this->PrintStmt(op->body);
}
610
// Emit an evaluated expression statement. The global-barrier kernel init
// intrinsic expands to a shared counter zeroed by thread 0; constant
// expressions are dropped entirely.
void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
  if (is_const_int(op->value)) return;
  const CallNode* call = op->value.as<CallNode>();
  if (call && call->op.same_as(builtin::tvm_global_barrier_kinit())) {
    // Declare the per-block expected-arrival counter used by
    // PrintStorageSync's "global" barrier, initialized once per block.
    PrintIndent();
    stream << "__shared__ unsigned " << vid_global_barrier_expect_ << ";\n";
    PrintIndent();
    stream << "if (threadIdx.x == 0) {\n";
    PrintIndent();
    stream << "  " << vid_global_barrier_expect_ << " = 0;\n";
    PrintIndent();
    stream << "}\n";
  } else {
    CodeGenC::VisitStmt_(op);
  }
}
627
VisitExpr_(const RampNode * op,std::ostream & os)628 void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
629 os << "((make_int" << op->lanes << ")(";
630 for (int i = 0; i < op->lanes; i++) {
631 os << "(" << PrintExpr(op->base) << ")"
632 << "+(" << PrintExpr(op->stride) << "*" << i << ")";
633 if (i != op->lanes - 1) os << ", ";
634 }
635 os << "))";
636 }
637
// Emit a Broadcast (same scalar replicated across lanes). int8x4 broadcasts
// of constants are packed into a single 32-bit literal; fp16 vectors are
// built from __pack_half2 pairs; everything else uses make_<type>(v, v, ...).
void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) {  // NOLINT(*)
  if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 && op->lanes == 4) {
    // make_int8x4
    // Only constant values can be packed this way (replicate the low byte
    // into all four byte positions of one integer literal).
    const int64_t* p = as_const_int(op->value);
    CHECK(p);
    int64_t v = *p & 0xFF;
    v = (v << 24) | (v << 16) | (v << 8) | v;
    if (op->dtype.is_uint()) {
      os << "(uint)" << v;
    } else {
      os << "(int)" << v;
    }
    return;
  }

  if (op->dtype.is_float16()) {
    // fp16 vectors are stored as uintN (see PrintType); each uint lane holds
    // a __pack_half2 of two copies of the value.
    std::string v = PrintExpr(op->value);
    os << "make_";
    PrintType(op->dtype, os);
    os << '(';
    for (int i = 0; i < op->lanes / 2; ++i) {
      if (i != 0) os << ", ";
      os << "__pack_half2(" << v << ", " << v << ")";
    }
    os << ')';
    return;
  }

  std::string v = PrintExpr(op->value);
  os << "make_";
  PrintType(op->dtype, os);
  os << '(';
  for (int i = 0; i < op->lanes; ++i) {
    if (i != 0) os << ", ";
    os << v;
  }
  os << ')';
}
676
VisitExpr_(const ShuffleNode * op,std::ostream & os)677 void CodeGenCUDA::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
678 std::vector<std::string> to_shuffle(op->vectors.size());
679 for (int i = 0, e = op->vectors.size(); i < e; ++i) {
680 CHECK(op->vectors[i].dtype().lanes() == 1) << "Only scalars can be shuffled in CUDA!";
681 to_shuffle[i] = PrintExpr(op->vectors[i]);
682 }
683 os << "make_";
684 PrintType(op->dtype, os);
685 os << '(';
686 for (int i = 0, e = op->indices.size(); i < e; ++i) {
687 const int64_t* val = as_const_int(op->indices[i]);
688 CHECK(val && *val >= 0 && (int)*val < (int)to_shuffle.size());
689 if (i != 0) os << ", ";
690 os << to_shuffle[*val];
691 }
692 os << ')';
693 }
694
// Emit a Select. Scalar selects fall back to CodeGenC's ternary; vector
// selects are serialized lane by lane into a temporary result vector.
void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
  // Non-vector cases.
  if (!op->dtype.is_vector()) {
    CodeGenC::VisitExpr_(op, os);
    return;
  }

  // Codegen vector condition case by serializing the select op.
  CHECK(op->false_value->dtype == op->dtype && op->true_value->dtype == op->dtype &&
        op->dtype.lanes() == op->condition.dtype().lanes());

  std::string r_var = GetUniqueName("_");
  this->PrintIndent();
  this->PrintType(op->dtype, stream);
  stream << ' ' << r_var << ";\n";
  {
    // Bind each operand expression once before the per-lane loop.
    std::string c_var = SSAGetID(PrintExpr(op->condition), op->dtype);
    std::string t_var = SSAGetID(PrintExpr(op->true_value), op->dtype);
    std::string f_var = SSAGetID(PrintExpr(op->false_value), op->dtype);

    // The condition is stored as an ushort vector.
    int lanes = op->dtype.lanes();
    DataType memory_ty(DataType::TypeCode::kUInt, 16, lanes);

    for (int i = 0; i < lanes; ++i) {
      // Emit: r.i = bool(cond.i) ? true_val.i : false_val.i
      std::ostringstream item;
      item << "(bool(";
      PrintVecElemLoad(c_var, memory_ty, i, item);
      item << ")?";
      PrintVecElemLoad(t_var, op->dtype, i, item);
      item << ':';
      PrintVecElemLoad(f_var, op->dtype, i, item);
      item << ')';
      PrintVecElemStore(r_var, op->dtype, i, item.str());
    }
  }
  os << r_var;
}
733
// Print a float immediate. Handles inf/nan via CUDA's math_constants.h
// macros (flagging the include as needed), fp16 via __float2half_rn, and
// ordinary values in scientific notation with an 'f' suffix for float32.
inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) {  // NOLINT(*)
  switch (op->dtype.bits()) {
    case 64:
    case 32: {
      std::ostringstream temp;
      if (std::isinf(op->value)) {
        if (op->value < 0) {
          temp << "-";
        }
        temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
        p->need_math_constants_h_ = true;
      } else if (std::isnan(op->value)) {
        temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
        p->need_math_constants_h_ = true;
      } else {
        temp << std::scientific << op->value;
        if (op->dtype.bits() == 32) temp << 'f';
      }
      p->MarkConst(temp.str());
      os << temp.str();
      break;
    }
    case 16: {
      // fp16 literals are built at runtime from a float literal.
      os << "__float2half_rn";
      os << '(' << std::scientific << op->value << 'f' << ')';
      break;
    }
    default:
      LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
  }
}
765
// Delegate float immediates to the file-local PrintConst helper above.
void CodeGenCUDA::VisitExpr_(const FloatImmNode* op, std::ostream& os) {  // NOLINT(*)
  PrintConst(op, os, this);
}
769
PrintWmmaScope(const std::string & scope,DataType t,const VarNode * variable,std::ostream & os)770 void CodeGenCUDA::PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
771 std::ostream& os) {
772 std::stringstream type;
773 PrintType(t, type);
774 std::string shape_str = fragment_shapes[variable];
775 if ((t.is_int() || t.is_uint()) && t.bits() < 8 && t.lanes() == 1) {
776 type.str(std::string());
777 if (t.is_int()) {
778 if (t.bits() == 4) {
779 type << "nvcuda::wmma::experimental::precision::s4";
780 } else if (t.bits() == 1) {
781 type << "nvcuda::wmma::experimental::precision::b1";
782 } else {
783 LOG(FATAL) << "Unhandled interger type for wmma fragment!";
784 }
785 } else if (t.is_uint()) {
786 if (t.bits() == 4) {
787 type << "nvcuda::wmma::experimental::precision::u4";
788 } else {
789 LOG(FATAL) << "Unhandled interger type for wmma fragment!";
790 }
791 }
792 }
793 if (scope == "wmma.matrix_a") {
794 need_mma_h_ = true;
795 std::string layout_str = fragment_layouts[variable];
796 os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, " << shape_str << ", " << type.str()
797 << ", nvcuda::wmma::" << layout_str << ">";
798 } else if (scope == "wmma.matrix_b") {
799 need_mma_h_ = true;
800 std::string layout_str = fragment_layouts[variable];
801 os << "nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, " << shape_str << ", " << type.str()
802 << ", nvcuda::wmma::" << layout_str << ">";
803 } else if (scope == "wmma.accumulator") {
804 need_mma_h_ = true;
805 os << "nvcuda::wmma::fragment<nvcuda::wmma::accumulator, " << shape_str << ", " << type.str()
806 << ">";
807 }
808 }
809
// Convert an element count into a number of wmma fragments by dividing out
// the fragment tile area implied by the scope (m*k, n*k, or m*n).
// The shape string is expected to be of the form "m, n, k" as recorded by
// the fragment_shape attribute.
// NOTE(review): assumes the shape string is well-formed — std::string::find
// returning npos or a non-numeric field would misbehave here; verify the
// producer always emits "m, n, k".
int32_t CodeGenCUDA::GetWmmaFragmentSize(const std::string& scope, const VarNode* variable,
                                         int32_t size) {
  std::string shape_str = fragment_shapes[variable];
  size_t m, n, k;
  size_t last_pos = 0, pos = 0;
  pos = shape_str.find(", ", last_pos);
  m = std::stoi(shape_str.substr(last_pos, pos - last_pos));
  last_pos = pos + 2;
  pos = shape_str.find(", ", last_pos);
  n = std::stoi(shape_str.substr(last_pos, pos - last_pos));
  last_pos = pos + 2;
  k = std::stoi(shape_str.substr(last_pos, shape_str.length() - last_pos));
  if (scope == "wmma.matrix_a") {
    return size / m / k;
  } else if (scope == "wmma.matrix_b") {
    return size / n / k;
  } else if (scope == "wmma.accumulator") {
    return size / m / n;
  }
  // Unknown scope: callers only pass wmma.* scopes; 0 signals "no fragments".
  return 0;
}
831
HandleVolatileLoads(const std::string & value,const LoadNode * op,std::ostream & os)832 void CodeGenCUDA::HandleVolatileLoads(const std::string& value, const LoadNode* op,
833 std::ostream& os) {
834 // Cast away volatile qualifier for fp16 types. That is, only loads and
835 // stores are volatile. The loaded objects are not marked as volatile.
836 //
837 if (op->dtype.is_float16() && IsVolatile(op->buffer_var.get())) {
838 os << "(";
839 PrintType(op->dtype, os);
840 os << ")(" << value << ")";
841 } else {
842 os << value;
843 }
844 }
845
// Incrementally build a vector value from per-lane scalar expressions.
// Called once per lane (i = 0 .. lanes-1) and emits the piece of the
// aggregate expression for that lane: OR-shift packing for int8 vectors,
// __pack_half2 pairs inside make_<type>(...) for fp16, and plain
// make_<type>(v0, v1, ...) otherwise.
void CodeGenCUDA::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
                                       std::ostream& os) {
  CHECK_GT(t.lanes(), 1);
  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
    if (!(t.lanes() == 2 || t.lanes() == 3)) {
      // Packed int8xN in an integer: mask the byte and shift it into place,
      // ORed with the previously emitted lanes.
      if (i != 0) {
        os << "|";
      }
      os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
      return;
    }
  }

  if (t.is_float16()) {
    // fp16 vectors: open make_<type>( on lane 0, then emit one
    // __pack_half2(a, b) per pair of lanes, closing on the last lane.
    if (i == 0) {
      os << "make_";
      PrintType(t, os);
      os << '(';
    }
    if (i % 2 == 0) {
      os << "__pack_half2(" << value;
    } else {
      os << "," << value << ")";
      if (i != t.lanes() - 1) {
        os << ",";
      } else {
        os << ")";
      }
    }
    return;
  }

  // Default: comma-separated lanes inside a make_<type>(...) constructor.
  if (i == 0) {
    os << "make_";
    PrintType(t, os);
    os << "(";
  }
  os << value;
  if (i != t.lanes() - 1) {
    os << ",";
  } else {
    os << ")";
  }
  return;
}
891
892 } // namespace codegen
893 } // namespace tvm
894