1 /*
2  * Copyright 2016 Advanced Micro Devices, Inc.
3  * All Rights Reserved.
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * on the rights to use, copy, modify, merge, publish, distribute, sub
9  * license, and/or sell copies of the Software, and to permit persons to whom
10  * the Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
19  * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
20  * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
21  * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
22  * USE OR OTHER DEALINGS IN THE SOFTWARE.
23  */
24 
25 #include "ac_nir_to_llvm.h"
26 #include "ac_rtld.h"
27 #include "si_pipe.h"
28 #include "si_shader_internal.h"
29 #include "sid.h"
30 #include "tgsi/tgsi_from_mesa.h"
31 #include "util/u_memory.h"
32 
/* State shared between si_compile_llvm() and si_diagnostic_handler() for the
 * duration of one LLVM compilation.
 *
 * Keep "debug" as the first member: si_compile_llvm() initializes this struct
 * positionally with {debug}.
 */
struct si_llvm_diagnostics {
   struct pipe_debug_callback *debug; /* target of pipe_debug_message() reports */
   unsigned retval;                   /* becomes 1 once an error has been seen */
};
37 
/**
 * LLVM diagnostic callback installed by si_compile_llvm().
 *
 * Errors and warnings are forwarded to the pipe debug callback; remarks,
 * notes and any other severity are ignored. Errors additionally set
 * retval to 1 and are printed to stderr.
 *
 * \param di       the diagnostic being reported
 * \param context  a struct si_llvm_diagnostics
 */
static void si_diagnostic_handler(LLVMDiagnosticInfoRef di, void *context)
{
   struct si_llvm_diagnostics *state = context;
   LLVMDiagnosticSeverity sev = LLVMGetDiagInfoSeverity(di);
   const char *label;

   if (sev == LLVMDSError)
      label = "error";
   else if (sev == LLVMDSWarning)
      label = "warning";
   else
      return; /* remarks, notes and anything else are not reported */

   char *text = LLVMGetDiagInfoDescription(di);

   pipe_debug_message(state->debug, SHADER_INFO, "LLVM diagnostic (%s): %s", label, text);

   if (sev == LLVMDSError) {
      state->retval = 1;
      fprintf(stderr, "LLVM triggered Diagnostic Handler: %s\n", text);
   }

   /* The description string is owned by the caller and must be released. */
   LLVMDisposeMessage(text);
}
69 
si_compile_llvm(struct si_screen * sscreen,struct si_shader_binary * binary,struct ac_shader_config * conf,struct ac_llvm_compiler * compiler,struct ac_llvm_context * ac,struct pipe_debug_callback * debug,enum pipe_shader_type shader_type,const char * name,bool less_optimized)70 bool si_compile_llvm(struct si_screen *sscreen, struct si_shader_binary *binary,
71                      struct ac_shader_config *conf, struct ac_llvm_compiler *compiler,
72                      struct ac_llvm_context *ac, struct pipe_debug_callback *debug,
73                      enum pipe_shader_type shader_type, const char *name, bool less_optimized)
74 {
75    unsigned count = p_atomic_inc_return(&sscreen->num_compilations);
76 
77    if (si_can_dump_shader(sscreen, shader_type)) {
78       fprintf(stderr, "radeonsi: Compiling shader %d\n", count);
79 
80       if (!(sscreen->debug_flags & (DBG(NO_IR) | DBG(PREOPT_IR)))) {
81          fprintf(stderr, "%s LLVM IR:\n\n", name);
82          ac_dump_module(ac->module);
83          fprintf(stderr, "\n");
84       }
85    }
86 
87    if (sscreen->record_llvm_ir) {
88       char *ir = LLVMPrintModuleToString(ac->module);
89       binary->llvm_ir_string = strdup(ir);
90       LLVMDisposeMessage(ir);
91    }
92 
93    if (!si_replace_shader(count, binary)) {
94       struct ac_compiler_passes *passes = compiler->passes;
95 
96       if (ac->wave_size == 32)
97          passes = compiler->passes_wave32;
98       else if (less_optimized && compiler->low_opt_passes)
99          passes = compiler->low_opt_passes;
100 
101       struct si_llvm_diagnostics diag = {debug};
102       LLVMContextSetDiagnosticHandler(ac->context, si_diagnostic_handler, &diag);
103 
104       if (!ac_compile_module_to_elf(passes, ac->module, (char **)&binary->elf_buffer,
105                                     &binary->elf_size))
106          diag.retval = 1;
107 
108       if (diag.retval != 0) {
109          pipe_debug_message(debug, SHADER_INFO, "LLVM compilation failed");
110          return false;
111       }
112    }
113 
114    struct ac_rtld_binary rtld;
115    if (!ac_rtld_open(&rtld, (struct ac_rtld_open_info){
116                                .info = &sscreen->info,
117                                .shader_type = tgsi_processor_to_shader_stage(shader_type),
118                                .wave_size = ac->wave_size,
119                                .num_parts = 1,
120                                .elf_ptrs = &binary->elf_buffer,
121                                .elf_sizes = &binary->elf_size}))
122       return false;
123 
124    bool ok = ac_rtld_read_config(&sscreen->info, &rtld, conf);
125    ac_rtld_close(&rtld);
126    return ok;
127 }
128 
si_llvm_context_init(struct si_shader_context * ctx,struct si_screen * sscreen,struct ac_llvm_compiler * compiler,unsigned wave_size)129 void si_llvm_context_init(struct si_shader_context *ctx, struct si_screen *sscreen,
130                           struct ac_llvm_compiler *compiler, unsigned wave_size)
131 {
132    memset(ctx, 0, sizeof(*ctx));
133    ctx->screen = sscreen;
134    ctx->compiler = compiler;
135 
136    ac_llvm_context_init(&ctx->ac, compiler, sscreen->info.chip_class, sscreen->info.family,
137                         AC_FLOAT_MODE_DEFAULT_OPENGL, wave_size, 64);
138 }
139 
/**
 * Create the shader's main LLVM function and store it in ctx->main_fn.
 *
 * \param name                name of the function
 * \param return_types        element types of the packed struct return type
 * \param num_return_elems    number of entries in \p return_types; 0 means void
 * \param max_workgroup_size  forwarded to ac_llvm_set_workgroup_size()
 *
 * The AMDGPU calling convention is chosen from the stage the shader actually
 * runs as: on GFX9+, an LS is merged into HS (TCS) and an ES/NGG shader into GS.
 * ctx->return_type and ctx->return_value (undef) are set as well.
 */
void si_llvm_create_func(struct si_shader_context *ctx, const char *name, LLVMTypeRef *return_types,
                         unsigned num_return_elems, unsigned max_workgroup_size)
{
   LLVMTypeRef ret_type = ctx->ac.voidt;
   enum pipe_shader_type hw_stage = ctx->type;
   enum ac_llvm_calling_convention call_conv;

   if (num_return_elems != 0)
      ret_type = LLVMStructTypeInContext(ctx->ac.context, return_types, num_return_elems, true);

   if (ctx->screen->info.chip_class >= GFX9) {
      if (ctx->shader->key.as_ls)
         hw_stage = PIPE_SHADER_TESS_CTRL;
      else if (ctx->shader->key.as_es || ctx->shader->key.as_ngg)
         hw_stage = PIPE_SHADER_GEOMETRY;
   }

   switch (hw_stage) {
   case PIPE_SHADER_FRAGMENT:
      call_conv = AC_LLVM_AMDGPU_PS;
      break;
   case PIPE_SHADER_COMPUTE:
      call_conv = AC_LLVM_AMDGPU_CS;
      break;
   case PIPE_SHADER_GEOMETRY:
      call_conv = AC_LLVM_AMDGPU_GS;
      break;
   case PIPE_SHADER_TESS_CTRL:
      call_conv = AC_LLVM_AMDGPU_HS;
      break;
   case PIPE_SHADER_VERTEX:
   case PIPE_SHADER_TESS_EVAL:
      call_conv = AC_LLVM_AMDGPU_VS;
      break;
   default:
      unreachable("Unhandle shader type");
   }

   ctx->return_type = ret_type;
   ctx->main_fn = ac_build_main(&ctx->args, &ctx->ac, call_conv, name, ret_type, ctx->ac.module);
   ctx->return_value = LLVMGetUndef(ret_type);

   /* Function attributes. */
   if (ctx->screen->info.address32_hi)
      ac_llvm_add_target_dep_function_attr(ctx->main_fn, "amdgpu-32bit-address-high-bits",
                                           ctx->screen->info.address32_hi);

   LLVMAddTargetDependentFunctionAttr(ctx->main_fn, "no-signed-zeros-fp-math", "true");
   ac_llvm_set_workgroup_size(ctx->main_fn, max_workgroup_size);
}
197 
si_llvm_optimize_module(struct si_shader_context * ctx)198 void si_llvm_optimize_module(struct si_shader_context *ctx)
199 {
200    /* Dump LLVM IR before any optimization passes */
201    if (ctx->screen->debug_flags & DBG(PREOPT_IR) && si_can_dump_shader(ctx->screen, ctx->type))
202       LLVMDumpModule(ctx->ac.module);
203 
204    /* Run the pass */
205    LLVMRunPassManager(ctx->compiler->passmgr, ctx->ac.module);
206    LLVMDisposeBuilder(ctx->ac.builder);
207 }
208 
si_llvm_dispose(struct si_shader_context * ctx)209 void si_llvm_dispose(struct si_shader_context *ctx)
210 {
211    LLVMDisposeModule(ctx->ac.module);
212    LLVMContextDispose(ctx->ac.context);
213    ac_llvm_context_dispose(&ctx->ac);
214 }
215 
216 /**
217  * Load a dword from a constant buffer.
218  */
si_buffer_load_const(struct si_shader_context * ctx,LLVMValueRef resource,LLVMValueRef offset)219 LLVMValueRef si_buffer_load_const(struct si_shader_context *ctx, LLVMValueRef resource,
220                                   LLVMValueRef offset)
221 {
222    return ac_build_buffer_load(&ctx->ac, resource, 1, NULL, offset, NULL, 0, 0, true, true);
223 }
224 
/**
 * Emit the function's return instruction: "ret void" when \p ret has void
 * type, otherwise "ret \p ret".
 */
void si_llvm_build_ret(struct si_shader_context *ctx, LLVMValueRef ret)
{
   LLVMBuilderRef b = ctx->ac.builder;
   bool returns_nothing = LLVMGetTypeKind(LLVMTypeOf(ret)) == LLVMVoidTypeKind;

   if (!returns_nothing) {
      LLVMBuildRet(b, ret);
      return;
   }

   LLVMBuildRetVoid(b);
}
232 
/**
 * Insert the value of argument \p param into the aggregate \p ret at
 * element \p return_index and return the updated aggregate.
 */
LLVMValueRef si_insert_input_ret(struct si_shader_context *ctx, LLVMValueRef ret,
                                 struct ac_arg param, unsigned return_index)
{
   LLVMValueRef arg_value = ac_get_arg(&ctx->ac, param);

   return LLVMBuildInsertValue(ctx->ac.builder, ret, arg_value, return_index, "");
}
238 
/**
 * Like si_insert_input_ret(), but the argument is converted with
 * ac_to_float() before being inserted at \p return_index.
 */
LLVMValueRef si_insert_input_ret_float(struct si_shader_context *ctx, LLVMValueRef ret,
                                       struct ac_arg param, unsigned return_index)
{
   LLVMValueRef as_float = ac_to_float(&ctx->ac, ac_get_arg(&ctx->ac, param));

   return LLVMBuildInsertValue(ctx->ac.builder, ret, as_float, return_index, "");
}
247 
/**
 * Insert a pointer argument into \p ret at \p return_index after converting
 * it to an i32 with ptrtoint.
 */
LLVMValueRef si_insert_input_ptr(struct si_shader_context *ctx, LLVMValueRef ret,
                                 struct ac_arg param, unsigned return_index)
{
   LLVMValueRef as_int =
      LLVMBuildPtrToInt(ctx->ac.builder, ac_get_arg(&ctx->ac, param), ctx->ac.i32, "");

   return LLVMBuildInsertValue(ctx->ac.builder, ret, as_int, return_index, "");
}
256 
si_prolog_get_rw_buffers(struct si_shader_context * ctx)257 LLVMValueRef si_prolog_get_rw_buffers(struct si_shader_context *ctx)
258 {
259    LLVMValueRef ptr[2], list;
260    bool merged_shader = si_is_merged_shader(ctx->shader);
261 
262    ptr[0] = LLVMGetParam(ctx->main_fn, (merged_shader ? 8 : 0) + SI_SGPR_RW_BUFFERS);
263    list =
264       LLVMBuildIntToPtr(ctx->ac.builder, ptr[0], ac_array_in_const32_addr_space(ctx->ac.v4i32), "");
265    return list;
266 }
267 
/**
 * Combine two values into a 2-element integer vector (\p val1 in element 0,
 * \p val2 in element 1) and bitcast the result to \p type.
 */
LLVMValueRef si_build_gather_64bit(struct si_shader_context *ctx, LLVMTypeRef type,
                                   LLVMValueRef val1, LLVMValueRef val2)
{
   LLVMValueRef first = ac_to_integer(&ctx->ac, val1);
   LLVMValueRef second = ac_to_integer(&ctx->ac, val2);
   LLVMValueRef pair[2] = {first, second};
   LLVMValueRef vec = ac_build_gather_values(&ctx->ac, pair, 2);

   return LLVMBuildBitCast(ctx->ac.builder, vec, type, "");
}
278 
si_llvm_emit_barrier(struct si_shader_context * ctx)279 void si_llvm_emit_barrier(struct si_shader_context *ctx)
280 {
281    /* GFX6 only (thanks to a hw bug workaround):
282     * The real barrier instruction isn’t needed, because an entire patch
283     * always fits into a single wave.
284     */
285    if (ctx->screen->info.chip_class == GFX6 && ctx->type == PIPE_SHADER_TESS_CTRL) {
286       ac_build_waitcnt(&ctx->ac, AC_WAIT_LGKM | AC_WAIT_VLOAD | AC_WAIT_VSTORE);
287       return;
288    }
289 
290    ac_build_s_barrier(&ctx->ac);
291 }
292 
293 /* Ensure that the esgs ring is declared.
294  *
295  * We declare it with 64KB alignment as a hint that the
296  * pointer value will always be 0.
297  */
si_llvm_declare_esgs_ring(struct si_shader_context * ctx)298 void si_llvm_declare_esgs_ring(struct si_shader_context *ctx)
299 {
300    if (ctx->esgs_ring)
301       return;
302 
303    assert(!LLVMGetNamedGlobal(ctx->ac.module, "esgs_ring"));
304 
305    ctx->esgs_ring = LLVMAddGlobalInAddressSpace(ctx->ac.module, LLVMArrayType(ctx->ac.i32, 0),
306                                                 "esgs_ring", AC_ADDR_SPACE_LDS);
307    LLVMSetLinkage(ctx->esgs_ring, LLVMExternalLinkage);
308    LLVMSetAlignment(ctx->esgs_ring, 64 * 1024);
309 }
310 
si_init_exec_from_input(struct si_shader_context * ctx,struct ac_arg param,unsigned bitoffset)311 void si_init_exec_from_input(struct si_shader_context *ctx, struct ac_arg param, unsigned bitoffset)
312 {
313    LLVMValueRef args[] = {
314       ac_get_arg(&ctx->ac, param),
315       LLVMConstInt(ctx->ac.i32, bitoffset, 0),
316    };
317    ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.init.exec.from.input", ctx->ac.voidt, args, 2,
318                       AC_FUNC_ATTR_CONVERGENT);
319 }
320 
321 /**
322  * Get the value of a shader input parameter and extract a bitfield.
323  */
unpack_llvm_param(struct si_shader_context * ctx,LLVMValueRef value,unsigned rshift,unsigned bitwidth)324 static LLVMValueRef unpack_llvm_param(struct si_shader_context *ctx, LLVMValueRef value,
325                                       unsigned rshift, unsigned bitwidth)
326 {
327    if (LLVMGetTypeKind(LLVMTypeOf(value)) == LLVMFloatTypeKind)
328       value = ac_to_integer(&ctx->ac, value);
329 
330    if (rshift)
331       value = LLVMBuildLShr(ctx->ac.builder, value, LLVMConstInt(ctx->ac.i32, rshift, 0), "");
332 
333    if (rshift + bitwidth < 32) {
334       unsigned mask = (1 << bitwidth) - 1;
335       value = LLVMBuildAnd(ctx->ac.builder, value, LLVMConstInt(ctx->ac.i32, mask, 0), "");
336    }
337 
338    return value;
339 }
340 
si_unpack_param(struct si_shader_context * ctx,struct ac_arg param,unsigned rshift,unsigned bitwidth)341 LLVMValueRef si_unpack_param(struct si_shader_context *ctx, struct ac_arg param, unsigned rshift,
342                              unsigned bitwidth)
343 {
344    LLVMValueRef value = ac_get_arg(&ctx->ac, param);
345 
346    return unpack_llvm_param(ctx, value, rshift, bitwidth);
347 }
348 
si_get_primitive_id(struct si_shader_context * ctx,unsigned swizzle)349 LLVMValueRef si_get_primitive_id(struct si_shader_context *ctx, unsigned swizzle)
350 {
351    if (swizzle > 0)
352       return ctx->ac.i32_0;
353 
354    switch (ctx->type) {
355    case PIPE_SHADER_VERTEX:
356       return ac_get_arg(&ctx->ac, ctx->vs_prim_id);
357    case PIPE_SHADER_TESS_CTRL:
358       return ac_get_arg(&ctx->ac, ctx->args.tcs_patch_id);
359    case PIPE_SHADER_TESS_EVAL:
360       return ac_get_arg(&ctx->ac, ctx->args.tes_patch_id);
361    case PIPE_SHADER_GEOMETRY:
362       return ac_get_arg(&ctx->ac, ctx->args.gs_prim_id);
363    default:
364       assert(0);
365       return ctx->ac.i32_0;
366    }
367 }
368 
si_llvm_get_block_size(struct ac_shader_abi * abi)369 LLVMValueRef si_llvm_get_block_size(struct ac_shader_abi *abi)
370 {
371    struct si_shader_context *ctx = si_shader_context_from_abi(abi);
372 
373    LLVMValueRef values[3];
374    LLVMValueRef result;
375    unsigned i;
376    unsigned *properties = ctx->shader->selector->info.properties;
377 
378    if (properties[TGSI_PROPERTY_CS_FIXED_BLOCK_WIDTH] != 0) {
379       unsigned sizes[3] = {properties[TGSI_PROPERTY_CS_FIXED_BLOCK_WIDTH],
380                            properties[TGSI_PROPERTY_CS_FIXED_BLOCK_HEIGHT],
381                            properties[TGSI_PROPERTY_CS_FIXED_BLOCK_DEPTH]};
382 
383       for (i = 0; i < 3; ++i)
384          values[i] = LLVMConstInt(ctx->ac.i32, sizes[i], 0);
385 
386       result = ac_build_gather_values(&ctx->ac, values, 3);
387    } else {
388       result = ac_get_arg(&ctx->ac, ctx->block_size);
389    }
390 
391    return result;
392 }
393 
si_llvm_declare_compute_memory(struct si_shader_context * ctx)394 void si_llvm_declare_compute_memory(struct si_shader_context *ctx)
395 {
396    struct si_shader_selector *sel = ctx->shader->selector;
397    unsigned lds_size = sel->info.properties[TGSI_PROPERTY_CS_LOCAL_SIZE];
398 
399    LLVMTypeRef i8p = LLVMPointerType(ctx->ac.i8, AC_ADDR_SPACE_LDS);
400    LLVMValueRef var;
401 
402    assert(!ctx->ac.lds);
403 
404    var = LLVMAddGlobalInAddressSpace(ctx->ac.module, LLVMArrayType(ctx->ac.i8, lds_size),
405                                      "compute_lds", AC_ADDR_SPACE_LDS);
406    LLVMSetAlignment(var, 64 * 1024);
407 
408    ctx->ac.lds = LLVMBuildBitCast(ctx->ac.builder, var, i8p, "");
409 }
410 
si_nir_build_llvm(struct si_shader_context * ctx,struct nir_shader * nir)411 bool si_nir_build_llvm(struct si_shader_context *ctx, struct nir_shader *nir)
412 {
413    if (nir->info.stage == MESA_SHADER_VERTEX) {
414       si_llvm_load_vs_inputs(ctx, nir);
415    } else if (nir->info.stage == MESA_SHADER_FRAGMENT) {
416       unsigned colors_read = ctx->shader->selector->info.colors_read;
417       LLVMValueRef main_fn = ctx->main_fn;
418 
419       LLVMValueRef undef = LLVMGetUndef(ctx->ac.f32);
420 
421       unsigned offset = SI_PARAM_POS_FIXED_PT + 1;
422 
423       if (colors_read & 0x0f) {
424          unsigned mask = colors_read & 0x0f;
425          LLVMValueRef values[4];
426          values[0] = mask & 0x1 ? LLVMGetParam(main_fn, offset++) : undef;
427          values[1] = mask & 0x2 ? LLVMGetParam(main_fn, offset++) : undef;
428          values[2] = mask & 0x4 ? LLVMGetParam(main_fn, offset++) : undef;
429          values[3] = mask & 0x8 ? LLVMGetParam(main_fn, offset++) : undef;
430          ctx->abi.color0 = ac_to_integer(&ctx->ac, ac_build_gather_values(&ctx->ac, values, 4));
431       }
432       if (colors_read & 0xf0) {
433          unsigned mask = (colors_read & 0xf0) >> 4;
434          LLVMValueRef values[4];
435          values[0] = mask & 0x1 ? LLVMGetParam(main_fn, offset++) : undef;
436          values[1] = mask & 0x2 ? LLVMGetParam(main_fn, offset++) : undef;
437          values[2] = mask & 0x4 ? LLVMGetParam(main_fn, offset++) : undef;
438          values[3] = mask & 0x8 ? LLVMGetParam(main_fn, offset++) : undef;
439          ctx->abi.color1 = ac_to_integer(&ctx->ac, ac_build_gather_values(&ctx->ac, values, 4));
440       }
441 
442       ctx->abi.interp_at_sample_force_center =
443          ctx->shader->key.mono.u.ps.interpolate_at_sample_force_center;
444 
445       ctx->abi.kill_ps_if_inf_interp =
446          (ctx->screen->debug_flags & DBG(KILL_PS_INF_INTERP)) &&
447          (ctx->shader->selector->info.uses_persp_center ||
448           ctx->shader->selector->info.uses_persp_centroid ||
449           ctx->shader->selector->info.uses_persp_sample);
450 
451    } else if (nir->info.stage == MESA_SHADER_COMPUTE) {
452       if (nir->info.cs.user_data_components_amd) {
453          ctx->abi.user_data = ac_get_arg(&ctx->ac, ctx->cs_user_data);
454          ctx->abi.user_data = ac_build_expand_to_vec4(&ctx->ac, ctx->abi.user_data,
455                                                       nir->info.cs.user_data_components_amd);
456       }
457    }
458 
459    ctx->abi.inputs = &ctx->inputs[0];
460    ctx->abi.clamp_shadow_reference = true;
461    ctx->abi.robust_buffer_access = true;
462    ctx->abi.convert_undef_to_zero = true;
463    ctx->abi.clamp_div_by_zero = ctx->screen->options.clamp_div_by_zero;
464 
465    if (ctx->shader->selector->info.properties[TGSI_PROPERTY_CS_LOCAL_SIZE]) {
466       assert(gl_shader_stage_is_compute(nir->info.stage));
467       si_llvm_declare_compute_memory(ctx);
468    }
469    ac_nir_translate(&ctx->ac, &ctx->abi, &ctx->args, nir);
470 
471    return true;
472 }
473 
/**
 * Given a list of shader part functions, build a wrapper function that
 * runs them in sequence to form a monolithic shader.
 *
 * \param ctx                     shader context; ctx->args and ctx->main_fn
 *                                are (re)defined for the wrapper
 * \param parts                   part functions in execution order; each is
 *                                marked always-inline with private linkage
 * \param num_parts               number of entries in \p parts
 * \param main_part               index of the part whose parameter types are
 *                                used for the wrapper's arguments
 * \param next_shader_first_part  for multi-part (merged) shaders, the index of
 *                                the first part of the second shader; parts
 *                                before it execute inside a branch guarded by
 *                                the thread count
 *
 * The wrapper returns the last part's return type. Each part's arguments are
 * taken from the previous part's return values. The first part, and the first
 * part of the second merged shader, instead take the wrapper's own arguments.
 */
void si_build_wrapper_function(struct si_shader_context *ctx, LLVMValueRef *parts,
                               unsigned num_parts, unsigned main_part,
                               unsigned next_shader_first_part)
{
   LLVMBuilderRef builder = ctx->ac.builder;
   /* PS epilog has one arg per color component; gfx9 merged shader
    * prologs need to forward 40 SGPRs.
    */
   LLVMValueRef initial[AC_MAX_ARGS], out[AC_MAX_ARGS];
   LLVMTypeRef function_type;
   unsigned num_first_params;
   unsigned num_out, initial_num_out;
   ASSERTED unsigned num_out_sgpr;         /* used in debug checks */
   ASSERTED unsigned initial_num_out_sgpr; /* used in debug checks */
   unsigned num_sgprs, num_vgprs;
   unsigned gprs;

   /* The wrapper's argument list is rebuilt from scratch below. */
   memset(&ctx->args, 0, sizeof(ctx->args));

   /* Mark every part always-inline with private linkage. */
   for (unsigned i = 0; i < num_parts; ++i) {
      ac_add_function_attr(ctx->ac.context, parts[i], -1, AC_FUNC_ATTR_ALWAYSINLINE);
      LLVMSetLinkage(parts[i], LLVMPrivateLinkage);
   }

   /* The parameters of the wrapper function correspond to those of the
    * first part in terms of SGPRs and VGPRs, but we use the types of the
    * main part to get the right types. This is relevant for the
    * dereferenceable attribute on descriptor table pointers.
    */
   num_sgprs = 0;
   num_vgprs = 0;

   function_type = LLVMGetElementType(LLVMTypeOf(parts[0]));
   num_first_params = LLVMCountParamTypes(function_type);

   /* Count the SGPR and VGPR dwords taken by the first part's parameters.
    * All SGPR parameters must precede the VGPR ones (asserted). */
   for (unsigned i = 0; i < num_first_params; ++i) {
      LLVMValueRef param = LLVMGetParam(parts[0], i);

      if (ac_is_sgpr_param(param)) {
         assert(num_vgprs == 0);
         num_sgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
      } else {
         num_vgprs += ac_get_type_size(LLVMTypeOf(param)) / 4;
      }
   }

   /* Declare wrapper arguments by walking the main part's parameters
    * (indexed by ctx->args.arg_count) until the first part's SGPR + VGPR
    * dword count is covered. The asserts check that the main part's
    * SGPR/VGPR boundary lines up with the first part's. */
   gprs = 0;
   while (gprs < num_sgprs + num_vgprs) {
      LLVMValueRef param = LLVMGetParam(parts[main_part], ctx->args.arg_count);
      LLVMTypeRef type = LLVMTypeOf(param);
      unsigned size = ac_get_type_size(type) / 4;

      /* This is going to get casted anyways, so we don't have to
       * have the exact same type. But we do have to preserve the
       * pointer-ness so that LLVM knows about it.
       */
      enum ac_arg_type arg_type = AC_ARG_INT;
      if (LLVMGetTypeKind(type) == LLVMPointerTypeKind) {
         type = LLVMGetElementType(type);

         if (LLVMGetTypeKind(type) == LLVMVectorTypeKind) {
            if (LLVMGetVectorSize(type) == 4)
               arg_type = AC_ARG_CONST_DESC_PTR;
            else if (LLVMGetVectorSize(type) == 8)
               arg_type = AC_ARG_CONST_IMAGE_PTR;
            else
               assert(0);
         } else if (type == ctx->ac.f32) {
            arg_type = AC_ARG_CONST_FLOAT_PTR;
         } else {
            assert(0);
         }
      }

      ac_add_arg(&ctx->args, gprs < num_sgprs ? AC_ARG_SGPR : AC_ARG_VGPR, size, arg_type, NULL);

      assert(ac_is_sgpr_param(param) == (gprs < num_sgprs));
      assert(gprs + size <= num_sgprs + num_vgprs &&
             (gprs >= num_sgprs || gprs + size <= num_sgprs));

      gprs += size;
   }

   /* Prepare the return type: the wrapper returns what the last part
    * returns (a struct, whose element types are copied, or void). */
   unsigned num_returns = 0;
   LLVMTypeRef returns[AC_MAX_ARGS], last_func_type, return_type;

   last_func_type = LLVMGetElementType(LLVMTypeOf(parts[num_parts - 1]));
   return_type = LLVMGetReturnType(last_func_type);

   switch (LLVMGetTypeKind(return_type)) {
   case LLVMStructTypeKind:
      num_returns = LLVMCountStructElementTypes(return_type);
      assert(num_returns <= ARRAY_SIZE(returns));
      LLVMGetStructElementTypes(return_type, returns);
      break;
   case LLVMVoidTypeKind:
      break;
   default:
      unreachable("unexpected type");
   }

   si_llvm_create_func(ctx, "wrapper", returns, num_returns,
                       si_get_max_workgroup_size(ctx->shader));

   /* Merged shaders start by initializing EXEC via ac_init_exec_full_mask(). */
   if (si_is_merged_shader(ctx->shader))
      ac_init_exec_full_mask(&ctx->ac);

   /* Record the arguments of the function as if they were an output of
    * a previous part.
    *
    * Each argument is flattened into dwords: SGPR dwords as i32, VGPR
    * dwords as f32. Pointers are converted with ptrtoint first, and
    * multi-dword arguments are split into their vector elements.
    * num_out_sgpr ends up as the number of dwords coming from SGPR args.
    */
   num_out = 0;
   num_out_sgpr = 0;

   for (unsigned i = 0; i < ctx->args.arg_count; ++i) {
      LLVMValueRef param = LLVMGetParam(ctx->main_fn, i);
      LLVMTypeRef param_type = LLVMTypeOf(param);
      LLVMTypeRef out_type = ctx->args.args[i].file == AC_ARG_SGPR ? ctx->ac.i32 : ctx->ac.f32;
      unsigned size = ac_get_type_size(param_type) / 4;

      if (size == 1) {
         if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
            param = LLVMBuildPtrToInt(builder, param, ctx->ac.i32, "");
            param_type = ctx->ac.i32;
         }

         if (param_type != out_type)
            param = LLVMBuildBitCast(builder, param, out_type, "");
         out[num_out++] = param;
      } else {
         LLVMTypeRef vector_type = LLVMVectorType(out_type, size);

         if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
            param = LLVMBuildPtrToInt(builder, param, ctx->ac.i64, "");
            param_type = ctx->ac.i64;
         }

         if (param_type != vector_type)
            param = LLVMBuildBitCast(builder, param, vector_type, "");

         for (unsigned j = 0; j < size; ++j)
            out[num_out++] =
               LLVMBuildExtractElement(builder, param, LLVMConstInt(ctx->ac.i32, j, 0), "");
      }

      if (ctx->args.args[i].file == AC_ARG_SGPR)
         num_out_sgpr = num_out;
   }

   /* Keep a copy of the wrapper's own flattened inputs; the second shader
    * of a merged pair restarts from these. */
   memcpy(initial, out, sizeof(out));
   initial_num_out = num_out;
   initial_num_out_sgpr = num_out_sgpr;

   /* Now chain the parts. */
   LLVMValueRef ret = NULL;
   for (unsigned part = 0; part < num_parts; ++part) {
      LLVMValueRef in[AC_MAX_ARGS];
      LLVMTypeRef ret_type;
      unsigned out_idx = 0;
      unsigned num_params = LLVMCountParams(parts[part]);

      /* Merged shaders are executed conditionally depending
       * on the number of enabled threads passed in the input SGPRs.
       * The count is bits [6:0] of input dword 3; threads whose ID is below
       * it take the branch. 6506 is the label id matched by the
       * ac_build_endif() call below. */
      if (si_is_multi_part_shader(ctx->shader) && part == 0) {
         LLVMValueRef ena, count = initial[3];

         count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
         ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
         ac_build_ifcc(&ctx->ac, ena, 6506);
      }

      /* Derive arguments for the next part from outputs of the
       * previous one.
       */
      for (unsigned param_idx = 0; param_idx < num_params; ++param_idx) {
         LLVMValueRef param;
         LLVMTypeRef param_type;
         bool is_sgpr;
         unsigned param_size;
         LLVMValueRef arg = NULL;

         param = LLVMGetParam(parts[part], param_idx);
         param_type = LLVMTypeOf(param);
         param_size = ac_get_type_size(param_type) / 4;
         is_sgpr = ac_is_sgpr_param(param);

         if (is_sgpr) {
            ac_add_function_attr(ctx->ac.context, parts[part], param_idx + 1, AC_FUNC_ATTR_INREG);
         } else if (out_idx < num_out_sgpr) {
            /* Skip returned SGPRs the current part doesn't
             * declare on the input. */
            out_idx = num_out_sgpr;
         }

         assert(out_idx + param_size <= (is_sgpr ? num_out_sgpr : num_out));

         /* Multi-dword parameters are reassembled from consecutive dwords. */
         if (param_size == 1)
            arg = out[out_idx];
         else
            arg = ac_build_gather_values(&ctx->ac, &out[out_idx], param_size);

         /* Convert to the parameter's type. Pointers in the 32-bit constant
          * address space go through i32, other pointers through i64. */
         if (LLVMTypeOf(arg) != param_type) {
            if (LLVMGetTypeKind(param_type) == LLVMPointerTypeKind) {
               if (LLVMGetPointerAddressSpace(param_type) == AC_ADDR_SPACE_CONST_32BIT) {
                  arg = LLVMBuildBitCast(builder, arg, ctx->ac.i32, "");
                  arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
               } else {
                  arg = LLVMBuildBitCast(builder, arg, ctx->ac.i64, "");
                  arg = LLVMBuildIntToPtr(builder, arg, param_type, "");
               }
            } else {
               arg = LLVMBuildBitCast(builder, arg, param_type, "");
            }
         }

         in[param_idx] = arg;
         out_idx += param_size;
      }

      ret = ac_build_call(&ctx->ac, parts[part], in, num_params);

      if (si_is_multi_part_shader(ctx->shader) && part + 1 == next_shader_first_part) {
         ac_build_endif(&ctx->ac, 6506);

         /* The second half of the merged shader should use
          * the inputs from the toplevel (wrapper) function,
          * not the return value from the last call.
          *
          * That's because the last call was executed condi-
          * tionally, so we can't consume it in the main
          * block.
          */
         memcpy(out, initial, sizeof(initial));
         num_out = initial_num_out;
         num_out_sgpr = initial_num_out_sgpr;
         continue;
      }

      /* Extract the returned GPRs. i32 members count as SGPRs and must all
       * come before any other member (asserted). */
      ret_type = LLVMTypeOf(ret);
      num_out = 0;
      num_out_sgpr = 0;

      if (LLVMGetTypeKind(ret_type) != LLVMVoidTypeKind) {
         assert(LLVMGetTypeKind(ret_type) == LLVMStructTypeKind);

         unsigned ret_size = LLVMCountStructElementTypes(ret_type);

         for (unsigned i = 0; i < ret_size; ++i) {
            LLVMValueRef val = LLVMBuildExtractValue(builder, ret, i, "");

            assert(num_out < ARRAY_SIZE(out));
            out[num_out++] = val;

            if (LLVMTypeOf(val) == ctx->ac.i32) {
               assert(num_out_sgpr + 1 == num_out);
               num_out_sgpr = num_out;
            }
         }
      }
   }

   /* Return the value from the last part. */
   if (LLVMGetTypeKind(LLVMTypeOf(ret)) == LLVMVoidTypeKind)
      LLVMBuildRetVoid(builder);
   else
      LLVMBuildRet(builder, ret);
}
746