1 #include <algorithm>
2 #include <iomanip>
3 #include <sstream>
4 #include <utility>
5 
6 #include "CodeGen_D3D12Compute_Dev.h"
7 #include "CodeGen_Internal.h"
8 #include "Debug.h"
9 #include "DeviceArgument.h"
10 #include "IRMutator.h"
11 #include "IROperator.h"
12 #include "Simplify.h"
13 
14 #define DEBUG_TYPES (0)
15 
16 namespace Halide {
17 namespace Internal {
18 
19 using std::ostringstream;
20 using std::sort;
21 using std::string;
22 using std::vector;
23 
24 static ostringstream nil;
25 
// Construct the D3D12 Compute (HLSL) device code generator. All source
// emission is delegated to the inner C-like printer (d3d12compute_c),
// which writes into this object's src_stream.
CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_Dev(Target t)
    : d3d12compute_c(src_stream, t) {
}
29 
// Map a Halide Type to its HLSL (Shader Model 5.1) spelling. 'storage'
// selects the representation used for buffer/array declarations (currently
// identical to the value representation -- see the note below); 'space'
// optionally appends a trailing blank so the result can prefix a declarator.
// Unsupported types (double, 64-bit ints, wide vectors) raise user_error.
string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_type_maybe_storage(Type type, bool storage, AppendSpaceIfNeeded space) {
    ostringstream oss;

    // Storage uses packed vector types.
    if (storage && type.lanes() != 1) {
        // NOTE(marcos): HLSL has 'packoffset' but it only applies to constant
        // buffer fields (constant registers (c)); because constant arguments
        // are always converted to 32bit values by the runtime prior to kernel
        // dispatch, there is no need to complicate things with packoffset.
    }

    if (type.is_float()) {
        switch (type.bits()) {
        case 16:
            // 16-bit floating point value. This data type is provided only for language compatibility.
            // Direct3D 10 shader targets map all half data types to float data types.
            // A half data type cannot be used on a uniform global variable (use the /Gec flag if this functionality is desired).
            oss << "half";
            break;
        case 32:
            oss << "float";
            break;
        case 64:
            // "64-bit floating point value. You cannot use double precision values as inputs and outputs for a stream.
            //  To pass double precision values between shaders, declare each double as a pair of uint data types.
            //  Then, use the asdouble function to pack each double into the pair of uints and the asuint function to
            //  unpack the pair of uints back into the double."
            user_error << "HLSL (SM 5.1) does not have transparent support for 'double' types.\n";
            oss << "double";
            break;
        default:
            user_error << "Can't represent a float with this many bits in HLSL (SM 5.1): " << type << "\n";
        }
    } else {
        switch (type.bits()) {
        case 1:
            oss << "bool";
            break;
        case 8:
        case 16:
        case 32:
            // All narrow integers are widened to (u)int, the smallest integer
            // granularity HLSL SM 5.1 provides:
            if (type.is_uint()) {
                oss << "u";
            }
            oss << "int";
#if DEBUG_TYPES
            oss << type.bits();
#endif
            break;
        case 64:
            user_error << "HLSL (SM 5.1) does not support 64-bit integers.\n";
            break;
        default:
            user_error << "Can't represent an integer with this many bits in HLSL (SM 5.1): " << type << "\n";
        }
    }
    // Append the HLSL vector suffix (e.g. float4) for 2/3/4-wide types:
    switch (type.lanes()) {
    case 1:
        break;
    case 2:
    case 3:
    case 4:
#if DEBUG_TYPES
        oss << "_(";
#endif
        oss << type.lanes();
#if DEBUG_TYPES
        oss << ")";
#endif
        break;
    case 8:
    case 16:
        // TODO(marcos): are there 8-wide and 16-wide types in HLSL?
        // (CodeGen_GLSLBase seems to happily generate invalid vector types)
    default:
        user_error << "Unsupported vector width in HLSL (SM 5.1): " << type << "\n";
    }

    if (space == AppendSpace) {
        oss << " ";
    }

    return oss.str();
}
114 
print_type(Type type,AppendSpaceIfNeeded space)115 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_type(Type type, AppendSpaceIfNeeded space) {
116     return print_type_maybe_storage(type, false, space);
117 }
118 
print_storage_type(Type type)119 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_storage_type(Type type) {
120     return print_type_maybe_storage(type, true, DoNotAppendSpace);
121 }
122 
print_reinterpret(Type type,const Expr & e)123 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_reinterpret(Type type, const Expr &e) {
124     if (type == e.type()) {
125         return print_expr(e);
126     } else {
127         return print_reinterpret_cast(type, print_expr(e));
128     }
129 }
130 
131 namespace {
simt_intrinsic(const string & name)132 string simt_intrinsic(const string &name) {
133     if (ends_with(name, ".__thread_id_x")) {
134         return "tid_in_tgroup.x";
135     } else if (ends_with(name, ".__thread_id_y")) {
136         return "tid_in_tgroup.y";
137     } else if (ends_with(name, ".__thread_id_z")) {
138         return "tid_in_tgroup.z";
139     } else if (ends_with(name, ".__thread_id_w")) {
140         user_error << "HLSL (SM5.1) does not support more than three dimensions for compute kernel threads.\n";
141     } else if (ends_with(name, ".__block_id_x")) {
142         return "tgroup_index.x";
143     } else if (ends_with(name, ".__block_id_y")) {
144         return "tgroup_index.y";
145     } else if (ends_with(name, ".__block_id_z")) {
146         return "tgroup_index.z";
147     } else if (ends_with(name, ".__block_id_w")) {
148         user_error << "HLSL (SM5.1) does not support more than three dimensions for compute dispatch groups.\n";
149     }
150     internal_error << "simt_intrinsic called on bad variable name: " << name << "\n";
151     return "";
152 }
153 }  // namespace
154 
visit(const Evaluate * op)155 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Evaluate *op) {
156     if (is_const(op->value)) return;
157     print_expr(op->value);
158 }
159 
print_extern_call(const Call * op)160 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_extern_call(const Call *op) {
161     internal_assert(!function_takes_user_context(op->name));
162 
163     vector<string> args(op->args.size());
164     for (size_t i = 0; i < op->args.size(); i++) {
165         args[i] = print_expr(op->args[i]);
166     }
167     ostringstream rhs;
168     rhs << op->name << "(" << with_commas(args) << ")";
169     return rhs.str();
170 }
171 
visit(const Max * op)172 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Max *op) {
173     print_expr(Call::make(op->type, "max", {op->a, op->b}, Call::Extern));
174 }
175 
visit(const Min * op)176 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Min *op) {
177     print_expr(Call::make(op->type, "min", {op->a, op->b}, Call::Extern));
178 }
179 
visit(const Div * op)180 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Div *op) {
181     int bits;
182     if (is_const_power_of_two_integer(op->b, &bits)) {
183         ostringstream oss;
184         oss << print_expr(op->a) << " >> " << bits;
185         print_assignment(op->type, oss.str());
186     } else if (op->type.is_int()) {
187         print_expr(lower_euclidean_div(op->a, op->b));
188     } else {
189         visit_binop(op->type, op->a, op->b, "/");
190     }
191 }
192 
visit(const Mod * op)193 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Mod *op) {
194     int bits;
195     if (is_const_power_of_two_integer(op->b, &bits)) {
196         ostringstream oss;
197         oss << print_expr(op->a) << " & " << ((1 << bits) - 1);
198         print_assignment(op->type, oss.str());
199     } else if (op->type.is_int()) {
200         print_expr(lower_euclidean_mod(op->a, op->b));
201     } else {
202         visit_binop(op->type, op->a, op->b, "%");
203     }
204 }
205 
// Emit a loop. Ordinary serial loops fall through to the base C codegen;
// GPU block/thread loops emit no loop at all -- the hardware iterates --
// so the loop variable is simply bound to the matching dispatch intrinsic.
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const For *loop) {
    user_assert(loop->for_type != ForType::GPULane)
        << "The D3D12Compute backend does not support the gpu_lanes() scheduling directive.";

    if (!is_gpu_var(loop->name)) {
        user_assert(loop->for_type != ForType::Parallel) << "Cannot use parallel loops inside D3D12Compute kernel\n";
        CodeGen_C::visit(loop);
        return;
    }

    internal_assert((loop->for_type == ForType::GPUBlock) ||
                    (loop->for_type == ForType::GPUThread))
        << "kernel loop must be either gpu block or gpu thread\n";
    // GPU loops are normalized to a zero minimum before reaching codegen.
    internal_assert(is_zero(loop->min));

    // e.g. "int _x = tgroup_index.x;"
    stream << get_indent() << print_type(Int(32)) << " " << print_name(loop->name)
           << " = " << simt_intrinsic(loop->name) << ";\n";

    loop->body.accept(this);
}
226 
visit(const Ramp * op)227 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Ramp *op) {
228     ostringstream rhs;
229     rhs << print_expr(op->base)
230         << " + "
231         << print_expr(op->stride)
232         << " * "
233         << print_type(op->type.with_lanes(op->lanes))
234         << "(";
235 
236     rhs << "0";
237     for (int i = 1; i < op->lanes; ++i) {
238         rhs << ", " << i;
239     }
240 
241     rhs << ")";
242     print_assignment(op->type.with_lanes(op->lanes), rhs.str());
243 }
244 
visit(const Broadcast * op)245 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Broadcast *op) {
246     string id_value = print_expr(op->value);
247     user_assert(op->value.type().lanes() == 1) << "Broadcast source must be 1-wide.\n";
248 
249     ostringstream rhs;
250     rhs << print_type(op->type.with_lanes(op->lanes))
251         << "(";
252     for (int i = 0; i < op->lanes; ++i) {
253         rhs << id_value;
254         if (i < op->lanes - 1)
255             rhs << ", ";
256     }
257     rhs << ")";
258 
259     print_assignment(op->type.with_lanes(op->lanes), rhs.str());
260 }
261 
visit(const Call * op)262 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
263     if (op->is_intrinsic(Call::gpu_thread_barrier)) {
264         internal_assert(op->args.size() == 1) << "gpu_thread_barrier() intrinsic must specify memory fence type.\n";
265 
266         auto fence_type_ptr = as_const_int(op->args[0]);
267         internal_assert(fence_type_ptr) << "gpu_thread_barrier() parameter is not a constant integer.\n";
268         auto fence_type = *fence_type_ptr;
269 
270         // By default, we'll just issue a Group (aka Shared) memory barrier,
271         // since it seems there isn't a sync without a memory barrier
272         // available
273         // NOTE(shoaibkamil): should we replace with AllMemoryBarrierWithGroupSync()
274         // if both are required?
275         if (fence_type & CodeGen_GPU_Dev::MemoryFenceType::Device &&
276             !(fence_type & CodeGen_GPU_Dev::MemoryFenceType::Shared)) {
277             stream << get_indent() << "DeviceMemoryBarrierWithGroupSync();\n";
278         } else if (fence_type & CodeGen_GPU_Dev::MemoryFenceType::Device) {
279             stream << get_indent() << "DeviceMemoryBarrier();\n";
280         }
281         stream << get_indent() << "GroupMemoryBarrierWithGroupSync();\n";
282         print_assignment(op->type, "0");
283     } else if (op->name == "pow_f32" && can_prove(op->args[0] > 0)) {
284         // If we know pow(x, y) is called with x > 0, we can use HLSL's pow
285         // directly.
286         stream << "pow(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")";
287     } else {
288         CodeGen_C::visit(op);
289     }
290 }
291 
292 namespace {
293 
294 // If e is a ramp expression with stride 1, return the base, otherwise undefined.
// If e is a ramp expression with stride 1, return the base, otherwise undefined.
Expr is_ramp_one(const Expr &e) {
    if (const Ramp *ramp = e.as<Ramp>()) {
        if (is_one(ramp->stride)) {
            return ramp->base;
        }
    }
    return Expr();
}
307 
// Render 'value' as an uppercase, zero-padded, 8-digit hex literal
// (e.g. 255 -> "0x000000FF") suitable for splicing into generated HLSL.
template<typename T>
string hex_literal(T value) {
    ostringstream out;
    out << "0x"
        << std::hex << std::uppercase
        << std::setw(8) << std::setfill('0')
        << value;
    return out.str();
}
315 
316 }  // namespace
317 
// Helpers for packing sub-32-bit elements into the 32-bit words that HLSL
// SM 5.1 groupshared storage requires. pack_storage() only *computes* a
// packing factor (its results are currently discarded -- diagnostic only);
// pack_store()/unpack_load() build the HLSL source text for a masked
// read-modify-write / masked extract of a narrow element within a word.
struct StoragePackUnpack {
    using CodeGen = CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C;

    // Shader Model 5.1: threadgroup shared memory is limited 32KB
    static const size_t ThreadGroupSharedStorageLimit = 32 * 1024;

    // Halve the element/byte counts (rounding odd counts up first) until the
    // allocation fits in the shared-storage limit, tracking how many narrow
    // elements share one 32-bit word. NOTE: computed values are not returned
    // or stored -- only the final internal_assert acts on them.
    void pack_storage(const Allocate *op, size_t elements, size_t size_in_bytes) {
        // we could try to compact things for smaller types:
        size_t packing_factor = 1;
        while (size_in_bytes > ThreadGroupSharedStorageLimit) {
            // must pack/unpack elements to/from shared memory...
            elements = (elements + 1) & ~size_t(1);
            elements /= 2;
            size_in_bytes = (size_in_bytes + 1) & ~size_t(1);
            size_in_bytes /= 2;
            packing_factor *= 2;
        }
        // smallest possible pack type is a byte (no nibbles)
        internal_assert(packing_factor <= 4);
    }

    // Build the HLSL text for storing op->value into packed shared storage.
    std::ostringstream pack_store(CodeGen &cg, const Store *op) {
        std::ostringstream lhs;
        // NOTE(marcos): 8bit and 16bit word packing -- the smallest integer
        // type granularity available in HLSL SM 5.1 is 32bit (int/uint):
        Type value_type = op->value.type();
        if (value_type.bits() == 32) {
            // storing a 32bit word? great! just reinterpret value to uint32:
            lhs << cg.print_name(op->name)
                << "[" << cg.print_expr(op->index) << "]"
                << " = "
                << cg.print_reinterpret(UInt(32), op->value)
                << ";\n";
        } else {
            // nightmare:
            internal_assert(value_type.bits() < 32);
            // must map element index to uint word array index:
            // word index = element index / (elements per 32-bit word)
            ostringstream index;
            auto bits = value_type.bits();
            auto divisor = (32 / bits);
            auto i = cg.print_expr(op->index);
            index << i << " / " << divisor;
            ostringstream word;
            word << cg.print_name(op->name)
                 << "[" + index.str() + "]";
            // now mask the appropriate bits:
            // ((1<<bits)-1) shifted to the element's position within the word
            ostringstream mask;
            mask << "("
                 << hex_literal((1 << bits) - 1)
                 << " << "
                 << "(" << bits << "*(" << i << " % " << divisor << " ))"
                 << ")";
            // apply the mask to the rhs value:
            ostringstream value;
            value << "("
                  << mask.str()
                  << " & "
                  << "(" << cg.print_expr(op->value) << ")"
                  << ")";

            // the performance impact of atomic operations on shared memory is
            // not well documented... here is something:
            // https://stackoverflow.com/a/19548723
            // Two-step atomic RMW: And clears the element's bit-slot, then
            // Xor deposits the (pre-masked) new value into the cleared slot.
            lhs << cg.get_indent() << "InterlockedAnd(" << word.str() << ", "
                << "~" << mask.str() << ");\n";
            lhs << cg.get_indent() << "InterlockedXor(" << word.str() << ", " << value.str() << ");\n";
        }
        return lhs;
    }

    // Build the HLSL text for loading a (possibly packed) element from
    // shared storage, sign-extending signed narrow types.
    std::ostringstream unpack_load(CodeGen &cg, const Load *op) {
        std::ostringstream rhs;
        // NOTE(marcos): let's keep this block of code here (disabled) in case
        // we need to "emulate" byte/short packing in shared memory (recall that
        // the smallest type granularity in HLSL SM 5.1 allows is 32bit types):
        if (op->type.bits() == 32) {
            // loading a 32bit word? great! just reinterpret as float/int/uint
            rhs << "as" << cg.print_type(op->type.element_of())
                << "("
                << cg.print_name(op->name)
                << "[" << cg.print_expr(op->index) << "]"
                << ")";
        } else {
            // not a 32bit word? hell ensues:
            internal_assert(op->type.bits() < 32);
            // must map element index to uint word array index:
            ostringstream index;
            auto bits = op->type.bits();
            auto divisor = (32 / bits);
            auto i = cg.print_expr(op->index);
            index << i << " / " << divisor;
            ostringstream word;
            word << cg.print_name(op->name)
                 << "[" + index.str() + "]";
            // now mask the appropriate bits:
            ostringstream mask;
            mask << "("
                 << hex_literal((1 << bits) - 1)
                 << " << "
                 << "(" << bits << "*(" << i << " % " << divisor << " ))"
                 << ")";
            // extract the correct bits from the word
            ostringstream cut;
            cut << "("
                << word.str()
                << " & "
                << mask.str()
                << ")"
                << " >> "
                << "(" << bits << "*(" << i << " % " << divisor << " ))";
            // if it is signed, need to propagate sign... shift-up then down:
            if (op->type.is_int()) {
                rhs << "asint"
                    << "("
                    << cut.str()
                    << " << "
                    << (32 - bits)
                    << ")"
                    << " >> "
                    << (32 - bits);
            } else {
                rhs << cut.str();
            }
        }
        return rhs;
    }
};
445 
// Emit a load. Four cases:
//   1. scalar loads from groupshared memory (elements are always 32-bit,
//      so narrower results are reinterpreted via HLSL's as* intrinsics);
//   2. dense vector loads (stride-1 ramp index) "unrolled" into a vector
//      constructor of element loads;
//   3. gathers (vector index) emitted one lane at a time;
//   4. plain scalar loads, with a cast if the allocation's declared type
//      differs from the load's type.
// Results are cached by rhs text so repeated identical loads reuse one id.
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
    user_assert(is_one(op->predicate)) << "Predicated load is not supported inside D3D12Compute kernel.\n";

    // elements in a threadgroup shared buffer are always 32bits:
    // must reinterpret (and maybe unpack) bits...
    if (groupshared_allocations.contains(op->name)) {
        ostringstream rhs;
        internal_assert(allocations.contains(op->name));
        // no ramps when accessing shared memory...
        Expr ramp_base = is_ramp_one(op->index);
        internal_assert(!ramp_base.defined());
        internal_assert(op->type.lanes() == 1);

        string id_index = print_expr(op->index);
        internal_assert(op->type.bits() <= 32);
        Type promoted = op->type.with_bits(32);
        if (promoted == op->type) {
            rhs << print_name(op->name)
                << "[" << id_index << "]";
        } else {
            // e.g. "asint(_shared[_i])" / "asfloat(_shared[_i])"
            rhs << "as" << print_type(promoted)
                << "("
                << print_name(op->name)
                << "[" << id_index << "]"
                << ")";
        }

        // NOTE(marcos): might need to resort to StoragePackUnpack::unpack_load() here...
        print_assignment(op->type, rhs.str());
        return;
    }

    // If we're loading a contiguous ramp, "unroll" the ramp into loads:
    Expr ramp_base = is_ramp_one(op->index);
    if (ramp_base.defined()) {
        internal_assert(op->type.is_vector());

        // e.g. "float4(_buf[_b+0], _buf[_b+1], _buf[_b+2], _buf[_b+3])"
        // (print_expr(ramp_base) is cached, so the base id is computed once)
        ostringstream rhs;
        rhs << print_type(op->type)
            << "(";
        const int lanes = op->type.lanes();
        for (int i = 0; i < lanes; ++i) {
            rhs << print_name(op->name)
                << "["
                << print_expr(ramp_base)
                << "+" << i
                << "]";
            if (i < lanes - 1)
                rhs << ", ";
        }
        rhs << ")";

        print_assignment(op->type, rhs.str());

        return;
    }

    string id_index = print_expr(op->index);

    // Get the rhs just for the cache.
    // A cast is needed unless the load type matches the allocation's
    // declared element type exactly:
    bool type_cast_needed = !(allocations.contains(op->name) &&
                              allocations.get(op->name).type == op->type);

    ostringstream rhs;
    if (type_cast_needed) {
        ostringstream element;
        element << print_name(op->name)
                << "[" << id_index << "]";
        Type target_type = op->type;
        Type source_type = allocations.get(op->name).type;
        rhs << print_cast(target_type, source_type, element.str());
    } else {
        rhs << print_name(op->name);
        rhs << "[" << id_index << "]";
    }

    // Reuse a previously-assigned id for an identical rhs, if available:
    std::map<string, string>::iterator cached = cache.find(rhs.str());
    if (cached != cache.end()) {
        id = cached->second;
        return;
    }

    if (op->index.type().is_vector()) {
        // If index is a vector, gather vector elements.
        internal_assert(op->type.is_vector());

        // This has to be underscore as print_name prepends an underscore to
        // names without one and that results in a name mismatch if a Load
        // appears as the value of a Let
        id = unique_name('_');
        cache[rhs.str()] = id;

        // Declare the result vector, then fill it lane by lane:
        stream << get_indent() << print_type(op->type)
               << " " << id << ";\n";

        for (int i = 0; i < op->type.lanes(); ++i) {
            stream << get_indent() << id << "[" << i << "] = "
                   << print_type(op->type.element_of())
                   << "("
                   << print_name(op->name)
                   << "[" << id_index << "[" << i << "]]"
                   << ");\n";
        }
    } else {
        print_assignment(op->type, rhs.str());
    }
}
553 
// Emit a store. Four cases:
//   1. stores to groupshared memory -- the value is reinterpreted to the
//      allocation's declared (32-bit) element type;
//   2. dense vector stores (stride-1 ramp index) "unrolled" into per-lane
//      element stores;
//   3. scatters (vector index) emitted one lane at a time;
//   4. plain scalar stores.
// Any store invalidates the load cache (the memory may have changed).
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Store *op) {
    user_assert(is_one(op->predicate)) << "Predicated store is not supported inside D3D12Compute kernel.\n";

    Type value_type = op->value.type();

    // elements in a threadgroup shared buffer are always 32bits:
    // must reinterpret (and maybe pack) bits...
    if (groupshared_allocations.contains(op->name)) {
        internal_assert(value_type.bits() <= 32);
        ostringstream rhs;
        rhs << print_name(op->name)
            << "[" << print_expr(op->index) << "]"
            << " = "
            << print_reinterpret(allocations.get(op->name).type, op->value)
            << ";\n";
        // NOTE(marcos): might need to resort to StoragePackUnpack::pack_store() here...
        stream << get_indent() << rhs.str();
        return;
    }

    // If we're writing a contiguous ramp, "unroll" the ramp into stores:
    Expr ramp_base = is_ramp_one(op->index);
    if (ramp_base.defined()) {
        internal_assert(value_type.is_vector());

        // print_expr() caches, so the base/value ids are computed once even
        // though they are referenced inside the loop:
        int lanes = value_type.lanes();
        for (int i = 0; i < lanes; ++i) {
            // TODO(marcos): this indentation looks funny in the generated code
            ostringstream rhs;
            rhs << print_name(op->name)
                << "["
                << print_expr(ramp_base)
                << " + " << i
                << "]"
                << " = "
                << print_expr(op->value)
                << "["
                << i
                << "]"
                << ";\n";
            stream << get_indent() << rhs.str();
        }
    } else if (op->index.type().is_vector()) {
        // If index is a vector, scatter vector elements.
        internal_assert(value_type.is_vector());

        for (int i = 0; i < value_type.lanes(); ++i) {
            ostringstream rhs;
            rhs << print_name(op->name)
                << "["
                << print_expr(op->index) << "[" << i << "]"
                << "]"
                << " = "
                << print_expr(op->value) << "[" << i << "];\n";
            stream << get_indent() << rhs.str();
        }
    } else {
        ostringstream rhs;
        rhs << print_name(op->name)
            << "[" << print_expr(op->index) << "]"
            << " = "
            << print_expr(op->value)
            << ";\n";
        stream << get_indent() << rhs.str();
    }

    // The store may have clobbered memory any cached load refers to:
    cache.clear();
}
622 
visit(const Select * op)623 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Select *op) {
624     ostringstream rhs;
625     string true_val = print_expr(op->true_value);
626     string false_val = print_expr(op->false_value);
627     string cond = print_expr(op->condition);
628     rhs << print_type(op->type)
629         << "(" << cond
630         << " ? " << true_val
631         << " : " << false_val
632         << ")";
633     print_assignment(op->type, rhs.str());
634 }
635 
// True iff this allocation targets GPU threadgroup-shared ("groupshared") memory.
static bool is_shared_allocation(const Allocate *op) {
    return op->memory_type == MemoryType::GPUShared;
}
639 
// Emit an allocation. Groupshared allocations are only *tracked* here (their
// declarations are emitted elsewhere, before the kernel body); everything
// else becomes a fixed-size local array declaration in a fresh scope.
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Allocate *op) {

    if (is_shared_allocation(op)) {
        // Already handled
        internal_assert(!groupshared_allocations.contains(op->name));
        groupshared_allocations.push(op->name);
        op->body.accept(this);
    } else {
        open_scope();

        debug(2) << "Allocate " << op->name << " on device\n";

        debug(3) << "Pushing allocation called " << op->name << " onto the symbol table\n";

        // Allocation is not a shared memory allocation, just make a local declaration.
        // It must have a constant size.
        int32_t size = op->constant_allocation_size();
        // NOTE(review): this message lacks a trailing '\n', unlike the other
        // user_assert messages in this file -- confirm whether intentional.
        user_assert(size > 0)
            << "Allocation " << op->name << " has a dynamic size. "
            << "Only fixed-size allocations are supported on the gpu. "
            << "Try storing into shared memory instead.";

        stream << get_indent() << print_storage_type(op->type) << " "
               << print_name(op->name) << "[" << size << "];\n";
        stream << get_indent();

        // Record the allocation's element type so later loads/stores can
        // decide whether a cast is needed:
        Allocation alloc;
        alloc.type = op->type;
        allocations.push(op->name, alloc);

        op->body.accept(this);

        // Should have been freed internally
        internal_assert(!allocations.contains(op->name));

        close_scope("alloc " + print_name(op->name));
    }
}
678 
visit(const Free * op)679 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Free *op) {
680     if (groupshared_allocations.contains(op->name)) {
681         groupshared_allocations.pop(op->name);
682         return;
683     } else {
684         // Should have been freed internally
685         internal_assert(allocations.contains(op->name));
686         allocations.pop(op->name);
687         stream << get_indent();
688     }
689 }
690 
print_assignment(Type type,const string & rhs)691 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_assignment(Type type, const string &rhs) {
692     string rhs_modified = print_reinforced_cast(type, rhs);
693     return CodeGen_C::print_assignment(type, rhs_modified);
694 }
695 
print_vanilla_cast(Type type,const string & value_expr)696 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_vanilla_cast(Type type, const string &value_expr) {
697     ostringstream ss;
698     ss << print_type(type) << "(" << value_expr << ")";
699     return ss.str();
700 }
701 
print_reinforced_cast(Type type,const string & value_expr)702 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_reinforced_cast(Type type, const string &value_expr) {
703     if (type.is_float() || type.is_bool() || type.bits() == 32) {
704         return value_expr;
705     }
706 
707     // HLSL SM 5.1 only supports 32bit integer types; smaller integer types have
708     // to be placed in 32bit integers, with special attention to signed integers
709     // that require propagation of the sign bit (MSB):
710     // a) for signed types: shift-up then shift-down
711     // b) for unsigned types: simply mask the LSB (but shift-up and down also works)
712     ostringstream sl;
713     sl << "(" << value_expr << ")"
714        << " << "
715        << "(" << (32 - type.bits()) << ")";  // 1. shift-up to MSB
716     ostringstream rsr;
717     rsr << print_reinterpret_cast(type, sl.str())  // 2. reinterpret bits
718         << " >> " << (32 - type.bits());           // 3. shift-down to LSB
719     return rsr.str();
720 }
721 
print_reinterpret_cast(Type type,const string & value_expr)722 string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_reinterpret_cast(Type type, const string &value_expr) {
723     type = type.element_of();
724 
725     string cast_expr;
726     cast_expr += "as";
727     switch (type.code()) {
728     case halide_type_uint:
729         cast_expr += "uint";
730         break;
731     case halide_type_int:
732         cast_expr += "int";
733         break;
734     case halide_type_float:
735         cast_expr += "float";
736         break;
737     case halide_type_handle:
738     default:
739         cast_expr = "BADCAST";
740         user_error << "Invalid reinterpret cast.\n";
741         break;
742     }
743     cast_expr += "(" + value_expr + ")";
744     return cast_expr;
745 }
746 
// Emit the HLSL text that converts 'value_expr' from source_type to
// target_type. Float/bool conversions use a plain language cast; integer
// conversions between the emulated narrow (8/16-bit) types need explicit
// masking/shifting because HLSL SM 5.1 only has 32-bit integer registers.
string CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::print_cast(Type target_type, Type source_type, const string &value_expr) {
    // casting to or from a float type? just use the language cast:
    if (target_type.is_float() || source_type.is_float()) {
        return print_vanilla_cast(target_type, value_expr);
    }

    // casting to or from a bool type? just use the language cast:
    if (target_type.is_bool() || source_type.is_bool()) {
        return print_vanilla_cast(target_type, value_expr);
    }

    // let the integer cast zoo begin...
    internal_assert(!target_type.is_float());
    internal_assert(!source_type.is_float());

    // HLSL (SM 5.1) only supports 32bit integers (signed and unsigned)...
    // integer downcasting-to (or upcasting-from) lower bit integers require
    // some emulation in code...
    internal_assert(target_type.bits() >= 8);
    internal_assert(source_type.bits() >= 8);
    internal_assert(target_type.bits() <= 32);
    internal_assert(source_type.bits() <= 32);
    internal_assert(target_type.bits() % 8 == 0);
    internal_assert(source_type.bits() % 8 == 0);

    // Case 1: source and target both have the same "signess"
    bool same_signess = false;
    same_signess |= target_type.is_int() && source_type.is_int();
    same_signess |= target_type.is_uint() && source_type.is_uint();
    if (same_signess) {
        ostringstream ss;
        if (target_type.bits() >= source_type.bits()) {
            // target has enough bits to fully accommodate the source:
            // it's a no-op, but we print a vanilla cast for clarity:
            ss << "(" << print_vanilla_cast(target_type, value_expr) << ")";
        } else {
            // for signed types: shift-up then shift-down
            // for unsigned types: mask the target LSB (but shift-up and down also works)
            ss << "("
               << "(" << value_expr << ")"
               << " << "
               << "(" << (32 - target_type.bits()) << ")"  // 1. shift-up to MSB
               << ")"
               << " >> " << (32 - target_type.bits());  // 2. shift-down to LSB
        }
        return ss.str();
    }

    // Case 2: casting from a signed source to an unsigned target
    if (source_type.is_int() && target_type.is_uint()) {
        // reinterpret resulting bits as uint(32):
        string masked = value_expr;
        if (target_type.bits() < 32) {
            masked = "(" + value_expr + ")" + " & " + hex_literal((1 << target_type.bits()) - 1);  // mask target LSB
        }
        return print_reinterpret_cast(target_type, masked);
    }

    // Case 3: casting from an unsigned source to a signed target
    internal_assert(source_type.is_uint());
    internal_assert(target_type.is_int());
    ostringstream ss;
    if (target_type.bits() > source_type.bits()) {
        // target has enough bits to fully accommodate the source:
        // it's a no-op, but we print a vanilla cast for clarity:
        ss << "(" << print_vanilla_cast(target_type, value_expr) << ")";
    } else {
        // shift-up, reinterpret as int (target_type), then shift-down
        ss << print_reinforced_cast(target_type, value_expr);
    }
    return ss.str();
}
819 
visit(const Cast * op)820 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Cast *op) {
821     Type target_type = op->type;
822     Type source_type = op->value.type();
823     string value_expr = print_expr(op->value);
824 
825     string cast_expr = print_cast(target_type, source_type, print_expr(op->value));
826 
827     print_assignment(target_type, cast_expr);
828 }
829 
void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Atomic *op) {
    // TODO: atomics
    // Atomic update nodes cannot be lowered to HLSL by this backend yet, so
    // reject them with a user-facing error at code-generation time.
    user_assert(false) << "Atomics operations are not supported inside D3D12Compute kernel.\n";
}
834 
visit(const FloatImm * op)835 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const FloatImm *op) {
836     // TODO(marcos): just a pass-through for now, but we might consider doing
837     // something different, such as adding the suffic 'u' to the integer that
838     // gets passed to float_from_bits() to eliminate HLSL shader warnings; we
839     // have seen division-by-zero shader warnings, and we postulated that it
840     // could be indirectly related to compiler assumptions on signed integer
841     // overflow when float_from_bits() is called, but we don't know for sure
842     return CodeGen_C::visit(op);
843 }
844 
add_kernel(Stmt s,const string & name,const vector<DeviceArgument> & args)845 void CodeGen_D3D12Compute_Dev::add_kernel(Stmt s,
846                                           const string &name,
847                                           const vector<DeviceArgument> &args) {
848     debug(2) << "CodeGen_D3D12Compute_Dev::compile " << name << "\n";
849 
850     // TODO: do we have to uniquify these names, or can we trust that they are safe?
851     cur_kernel_name = name;
852     d3d12compute_c.add_kernel(s, name, args);
853 }
854 
namespace {
// Associates a buffer name with a size in bytes; used below to rank which
// buffer arguments are the best candidates for constant-buffer placement.
struct BufferSize {
    std::string name;
    size_t size = 0;

    BufferSize() = default;
    BufferSize(std::string n, size_t s)
        : name(std::move(n)), size(s) {
    }

    // Order by size (ascending), so the smallest buffers sort first.
    bool operator<(const BufferSize &r) const {
        return size < r.size;
    }
};
}  // namespace
872 
add_kernel(Stmt s,const string & name,const vector<DeviceArgument> & args)873 void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::add_kernel(Stmt s,
874                                                                   const string &name,
875                                                                   const vector<DeviceArgument> &args) {
876 
877     debug(2) << "Adding D3D12Compute kernel " << name << "\n";
878 
879     // Figure out which arguments should be passed in constant.
880     // Such arguments should be:
881     // - not written to,
882     // - loads are block-uniform,
883     // - constant size,
884     // - and all allocations together should be less than the max constant
885     //   buffer size given by D3D11_REQ_CONSTANT_BUFFER_ELEMENT_COUNT*(4*sizeof(float))
886     // The last condition is handled via the preprocessor in the kernel
887     // declaration.
888     vector<BufferSize> constants;
889     auto isConstantBuffer = [&s](const DeviceArgument &arg) {
890         return arg.is_buffer && CodeGen_GPU_Dev::is_buffer_constant(s, arg.name) && arg.size > 0;
891     };
892     for (auto &arg : args) {
893         if (isConstantBuffer(arg)) {
894             constants.emplace_back(arg.name, arg.size);
895         }
896     }
897 
898     // Sort the constant candidates from smallest to largest. This will put
899     // as many of the constant allocations in constant as possible.
900     // Ideally, we would prioritize constant buffers by how frequently they
901     // are accessed.
902     sort(constants.begin(), constants.end());
903 
904     // Compute the cumulative sum of the constants.
905     for (size_t i = 1; i < constants.size(); i++) {
906         constants[i].size += constants[i - 1].size;
907     }
908 
909     // Find all the shared allocations, uniquify their names,
910     // and declare them at global scope.
911     class FindSharedAllocationsAndUniquify : public IRMutator {
912         using IRMutator::visit;
913         Stmt visit(const Allocate *op) override {
914             if (is_shared_allocation(op)) {
915                 // Because these will go in global scope,
916                 // we need to ensure they have unique names.
917                 std::string new_name = unique_name(op->name);
918                 replacements[op->name] = new_name;
919 
920                 std::vector<Expr> new_extents;
921                 for (size_t i = 0; i < op->extents.size(); i++) {
922                     new_extents.push_back(mutate(op->extents[i]));
923                 }
924                 Stmt new_body = mutate(op->body);
925                 Expr new_condition = mutate(op->condition);
926                 Expr new_new_expr;
927                 if (op->new_expr.defined()) {
928                     new_new_expr = mutate(op->new_expr);
929                 }
930 
931                 Stmt new_alloc = Allocate::make(new_name, op->type, op->memory_type, new_extents,
932                                                 std::move(new_condition), std::move(new_body),
933                                                 std::move(new_new_expr), op->free_function);
934 
935                 allocs.push_back(new_alloc);
936                 replacements.erase(op->name);
937                 return new_alloc;
938             } else {
939                 return IRMutator::visit(op);
940             }
941         }
942 
943         Stmt visit(const Free *op) override {
944             auto it = replacements.find(op->name);
945             if (it != replacements.end()) {
946                 return Free::make(it->second);
947             } else {
948                 return IRMutator::visit(op);
949             }
950         }
951 
952         Expr visit(const Load *op) override {
953             auto it = replacements.find(op->name);
954             if (it != replacements.end()) {
955                 return Load::make(op->type, it->second,
956                                   mutate(op->index), op->image, op->param,
957                                   mutate(op->predicate), op->alignment);
958             } else {
959                 return IRMutator::visit(op);
960             }
961         }
962 
963         Stmt visit(const Store *op) override {
964             auto it = replacements.find(op->name);
965             if (it != replacements.end()) {
966                 return Store::make(it->second, mutate(op->value),
967                                    mutate(op->index), op->param,
968                                    mutate(op->predicate), op->alignment);
969             } else {
970                 return IRMutator::visit(op);
971             }
972         }
973 
974         std::map<string, string> replacements;
975         friend class CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C;
976         vector<Stmt> allocs;
977     };
978 
979     FindSharedAllocationsAndUniquify fsa;
980     s = fsa.mutate(s);
981 
982     uint32_t total_shared_bytes = 0;
983     for (Stmt sop : fsa.allocs) {
984         const Allocate *op = sop.as<Allocate>();
985         internal_assert(op->extents.size() == 1);
986         internal_assert(op->type.lanes() == 1);
987         // In D3D12/HLSL, only 32bit types (int/uint/float) are suppoerted (even
988         // though things are changing with newer shader models). Since there is
989         // no uint8 type, we'll have to emulate it with 32bit types...
990         // This will also require pack/unpack logic with bit-masking and aliased
991         // type reinterpretation via asfloat()/asuint() in the shader code... :(
992         Type smem_type = op->type;
993         smem_type.with_bits(32);
994         stream << "groupshared"
995                << " " << print_type(smem_type)
996                << " " << print_name(op->name);
997         if (is_const(op->extents[0])) {
998             std::stringstream ss;
999             ss << op->extents[0];
1000             size_t elements = 0;
1001             ss >> elements;
1002             size_t bytesize = elements * sizeof(uint32_t);
1003             // NOTE(marcos): might need to resort to StoragePackUnpack::pack_storage() here...
1004             internal_assert(bytesize <= StoragePackUnpack::ThreadGroupSharedStorageLimit);
1005             total_shared_bytes += bytesize;
1006             stream << " [" << elements << "];\n";
1007         } else {
1008             // fill-in __GROUPSHARED_SIZE_IN_BYTES later on when D3DCompile()
1009             // is invoked in halide_d3d12compute_run(); must get divided by 4
1010             // since groupshared memory elements have 32bit granularity:
1011             stream << " [ ( __GROUPSHARED_SIZE_IN_BYTES + 3 ) / 4 ];\n";
1012         }
1013         if (total_shared_bytes > StoragePackUnpack::ThreadGroupSharedStorageLimit) {
1014             debug(1) << "D3D12 CodeGen ERROR: Total thread group shared memory required for kernel '" << name
1015                      << "' exceeds the SM 5.1 limit of 32KB: " << total_shared_bytes << " bytes required.\n";
1016         }
1017         Allocation alloc;
1018         alloc.type = op->type;
1019         allocations.push(op->name, alloc);
1020     }
1021 
1022     // Emit the kernel function preamble (numtreads):
1023     // must first figure out the thread group size by traversing the stmt:
1024     struct FindThreadGroupSize : public IRVisitor {
1025         using IRVisitor::visit;
1026         void visit(const For *loop) override {
1027             if (!is_gpu_var(loop->name))
1028                 return loop->body.accept(this);
1029             if (loop->for_type != ForType::GPUThread)
1030                 return loop->body.accept(this);
1031             internal_assert(is_zero(loop->min));
1032             int index = thread_loop_workgroup_index(loop->name);
1033             user_assert(index >= 0) << "Invalid 'numthreads' index for loop variable '" << loop->name << "'.\n";
1034             // if 'numthreads' for a given dimension can't be determined at code
1035             // generation time, emit code such that it can be patched some point
1036             // later when calling D3DCompile() / halide_d3d12compute_run()
1037             numthreads[index] = 0;  // <-- 0 indicates 'undetermined'
1038             const IntImm *int_limit = loop->extent.as<IntImm>();
1039             if (nullptr != int_limit) {
1040                 numthreads[index] = int_limit->value;
1041                 user_assert(numthreads[index] > 0) << "For D3D12Compute, 'numthreads[" << index << "]' values must be greater than zero.\n";
1042             }
1043             debug(4) << "Thread group size for index " << index << " is " << numthreads[index] << "\n";
1044             loop->body.accept(this);
1045         }
1046         int thread_loop_workgroup_index(const string &name) {
1047             string ids[] = {".__thread_id_x",
1048                             ".__thread_id_y",
1049                             ".__thread_id_z",
1050                             ".__thread_id_w"};
1051             for (auto &id : ids) {
1052                 if (ends_with(name, id)) {
1053                     return (&id - ids);
1054                 }
1055             }
1056             return -1;
1057         }
1058         int numthreads[3] = {1, 1, 1};
1059     };
1060     FindThreadGroupSize ftg;
1061     s.accept(&ftg);
1062     // for undetermined 'numthreads' dimensions, insert placeholders to the code
1063     // such as '__NUM_TREADS_X' that will later be patched when D3DCompile() is
1064     // invoked in halide_d3d12compute_run()
1065     stream << "[ numthreads(";
1066     (ftg.numthreads[0] > 0) ? (stream << " " << ftg.numthreads[0]) : (stream << " __NUM_TREADS_X ");
1067     (ftg.numthreads[1] > 0) ? (stream << ", " << ftg.numthreads[1]) : (stream << ", __NUM_TREADS_Y ");
1068     (ftg.numthreads[2] > 0) ? (stream << ", " << ftg.numthreads[2]) : (stream << ", __NUM_TREADS_Z ");
1069     stream << ") ]\n";
1070 
1071     // Emit the kernel function prototype:
1072 
1073     stream << "void " << name << "(\n";
1074     stream << " "
1075            << "uint3 tgroup_index  : SV_GroupID,\n"
1076            << " "
1077            << "uint3 tid_in_tgroup : SV_GroupThreadID";
1078     for (auto &arg : args) {
1079         stream << ",\n";
1080         stream << " ";
1081         if (arg.is_buffer) {
1082             // NOTE(marcos): Passing all buffers as RWBuffers in order to bind
1083             // all buffers as UAVs since there is no way the runtime can know
1084             // if a given halide_buffer_t is read-only (SRV) or read-write...
1085             stream << "RW"
1086                    << "Buffer"
1087                    << "<" << print_type(arg.type) << ">"
1088                    << " " << print_name(arg.name);
1089             Allocation alloc;
1090             alloc.type = arg.type;
1091             allocations.push(arg.name, alloc);
1092         } else {
1093             stream << "uniform"
1094                    << " " << print_type(arg.type)
1095                    << " " << print_name(arg.name);
1096         }
1097     }
1098     stream << ")\n";
1099 
1100     // Emit the kernel code:
1101 
1102     open_scope();
1103     print(s);
1104     close_scope("kernel " + name);
1105 
1106     for (size_t i = 0; i < args.size(); i++) {
1107         // Remove buffer arguments from allocation scope
1108         if (args[i].is_buffer) {
1109             allocations.pop(args[i].name);
1110         }
1111     }
1112 
1113     stream << "\n";
1114 }
1115 
// Resets the kernel source stream and emits the HLSL preamble shared by all
// kernels: warning-control pragmas, math-function wrappers/macros, and the
// common Halide macros. Every byte emitted here becomes shader source text.
void CodeGen_D3D12Compute_Dev::init_module() {
    debug(2) << "D3D12Compute device codegen init_module\n";

    // wipe the internal kernel source
    src_stream.str("");
    src_stream.clear();

    // compiler control pragmas
    src_stream
        // Disable innocent warnings:
        // warning X3078 : loop control variable conflicts with a previous declaration in the outer scope; most recent declaration will be used
        << "#pragma warning( disable : 3078 )"
           "\n"
        // warning X3557: loop only executes for 1 iteration(s), forcing loop to unroll
        << "#pragma warning( disable : 3557 )"
           "\n"
        // Disable more serious warnings:
        // TODO(marcos): should revisit the warnings below, as they are likely to impact performance (and possibly correctness too)
        // warning X3556 : integer modulus may be much slower, try using uints if possible
        // TODO(marcos): can we interchangeably replace ints by uints when we have modulo operations in the generated code?
        << "#pragma warning( disable : 3556 )"
           "\n"
        // warning X3571 : pow(f, e) will not work for negative f, use abs(f) or conditionally handle negative values if you expect them
        << "#pragma warning( disable : 3571 )"
           "\n"
        // warning X4714 : sum of temp registers and indexable temp registers times 256 threads exceeds the recommended total 16384.  Performance may be reduced
        << "#pragma warning( disable : 4714 )"
           "\n"
        << "\n";

    src_stream << "#define halide_unused(x) (void)(x)\n";

    // Write out the Halide math functions.
    src_stream
    //<< "namespace {\n"   // HLSL does not support unnamed namespaces...
#if DEBUG_TYPES
        << "#define  int8   int\n"
        << "#define  int16  int\n"
        << "#define  int32  int\n"
        << "#define uint8  uint\n"
        << "#define uint16 uint\n"
        << "#define uint32 uint\n"
        << "\n"
        << "#define  bool_(x)   bool##x\n"
        << "#define  int8_(x)   int##x\n"
        << "#define  int16_(x)  int##x\n"
        << "#define  int32_(x)  int##x\n"
        << "#define uint8_(x)  uint##x\n"
        << "#define uint16_(x) uint##x\n"
        << "#define uint32_(x) uint##x\n"
        << "\n"
        << "#define asint32  asint\n"
        << "#define asuint32 asuint\n"
        << "\n"
#endif
        << "float nan_f32()     { return  1.#IND; } \n"  // Quiet NaN with minimum fractional value.
        << "float neg_inf_f32() { return -1.#INF; } \n"
        << "float inf_f32()     { return +1.#INF; } \n"
        << "#define is_inf_f32     isinf    \n"
        << "#define is_finite_f32  isfinite \n"
        << "#define is_nan_f32     isnan    \n"
        << "#define float_from_bits asfloat \n"
        << "#define sqrt_f32    sqrt   \n"
        << "#define sin_f32     sin    \n"
        << "#define cos_f32     cos    \n"
        << "#define exp_f32     exp    \n"
        << "#define log_f32     log    \n"
        << "#define abs_f32     abs    \n"
        << "#define floor_f32   floor  \n"
        << "#define ceil_f32    ceil   \n"
        << "#define round_f32   round  \n"
        << "#define trunc_f32   trunc  \n"
        // pow() in HLSL has the same semantics as C if
        // x > 0.  Otherwise, we need to emulate C
        // behavior.
        // TODO(shoaibkamil): Can we simplify this?
        << "float pow_f32(float x, float y) { \n"
        << "  if (x > 0.0) {                  \n"
        << "    return pow(x, y);             \n"
        << "  } else if (y == 0.0) {          \n"
        << "    return 1.0f;                  \n"
        << "  } else if (trunc(y) == y) {     \n"
        << "    if (fmod(y, 2) == 0) {        \n"
        << "      return pow(abs(x), y);      \n"
        << "    } else {                      \n"
        << "      return -pow(abs(x), y);     \n"
        << "    }                             \n"
        << "  } else {                        \n"
        << "    return nan_f32();             \n"
        << "  }                               \n"
        << "}                                 \n"
        << "#define asin_f32    asin   \n"
        << "#define acos_f32    acos   \n"
        << "#define tan_f32     tan    \n"
        << "#define atan_f32    atan   \n"
        << "#define atan2_f32   atan2  \n"
        << "#define sinh_f32    sinh   \n"
        << "#define cosh_f32    cosh   \n"
        << "#define tanh_f32    tanh   \n"
        << "#define asinh_f32(x) (log_f32(x + sqrt_f32(x*x + 1))) \n"
        << "#define acosh_f32(x) (log_f32(x + sqrt_f32(x*x - 1))) \n"
        << "#define atanh_f32(x) (log_f32((1+x)/(1-x))/2) \n"
        << "#define fast_inverse_f32      rcp   \n"
        << "#define fast_inverse_sqrt_f32 rsqrt \n"
        << "\n";
    //<< "}\n"; // close namespace

    src_stream << "\n";

    d3d12compute_c.add_common_macros(src_stream);

    // No kernel has been emitted yet after a reset:
    cur_kernel_name = "";
}
1229 
compile_to_src()1230 vector<char> CodeGen_D3D12Compute_Dev::compile_to_src() {
1231     string str = src_stream.str();
1232     debug(1) << "D3D12Compute kernel:\n"
1233              << str << "\n";
1234     vector<char> buffer(str.begin(), str.end());
1235     buffer.push_back(0);
1236     return buffer;
1237 }
1238 
get_current_kernel_name()1239 string CodeGen_D3D12Compute_Dev::get_current_kernel_name() {
1240     return cur_kernel_name;
1241 }
1242 
dump()1243 void CodeGen_D3D12Compute_Dev::dump() {
1244     std::cerr << src_stream.str() << "\n";
1245 }
1246 
std::string CodeGen_D3D12Compute_Dev::print_gpu_name(const std::string &name) {
    // No mangling is applied for D3D12; kernels are referenced by the same
    // name they were added with.
    return name;
}
1250 
1251 }  // namespace Internal
1252 }  // namespace Halide
1253