1 /*
2  * Copyright 2018 Collabora Ltd.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * on the rights to use, copy, modify, merge, publish, distribute, sub
8  * license, and/or sell copies of the Software, and to permit persons to whom
9  * the Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL
18  * THE AUTHOR(S) AND/OR THEIR SUPPLIERS BE LIABLE FOR ANY CLAIM,
19  * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
20  * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
21  * USE OR OTHER DEALINGS IN THE SOFTWARE.
22  */
23 
24 #include "nir_to_spirv.h"
25 #include "spirv_builder.h"
26 
27 #include "nir.h"
28 #include "pipe/p_state.h"
29 #include "util/u_memory.h"
30 #include "util/hash_table.h"
31 
/* this consistently maps slots to a zero-indexed value to avoid wasting slots
 *
 * Indexed by gl_varying_slot for every slot below VARYING_SLOT_VAR0; the value
 * is the packed SPIR-V Location that handle_slot() assigns to that slot.
 * Entries set to UINT_MAX are slots that emit_input()/emit_output() translate
 * into SPIR-V BuiltIn decorations instead; if handle_slot() is ever asked for
 * one of them it prints an "unhandled varying slot" message.
 *
 * Locations 0 .. NTV_MIN_RESERVED_SLOTS - 1 hold the non-builtin legacy slots,
 * generic VARn varyings are packed right after them (see handle_slot()), and
 * TEX0..TEX7 are placed at the very top of the range (VARYING_SLOT_VAR0 - 1
 * down to VARYING_SLOT_VAR0 - 8).
 */
static unsigned slot_pack_map[] = {
   /* Position is builtin */
   [VARYING_SLOT_POS] = UINT_MAX,
   [VARYING_SLOT_COL0] = 0, /* input/output */
   [VARYING_SLOT_COL1] = 1, /* input/output */
   [VARYING_SLOT_FOGC] = 2, /* input/output */
   /* TEX0-7 are deprecated, so we put them at the end of the range and hope nobody uses them all */
   [VARYING_SLOT_TEX0] = VARYING_SLOT_VAR0 - 1, /* input/output */
   [VARYING_SLOT_TEX1] = VARYING_SLOT_VAR0 - 2,
   [VARYING_SLOT_TEX2] = VARYING_SLOT_VAR0 - 3,
   [VARYING_SLOT_TEX3] = VARYING_SLOT_VAR0 - 4,
   [VARYING_SLOT_TEX4] = VARYING_SLOT_VAR0 - 5,
   [VARYING_SLOT_TEX5] = VARYING_SLOT_VAR0 - 6,
   [VARYING_SLOT_TEX6] = VARYING_SLOT_VAR0 - 7,
   [VARYING_SLOT_TEX7] = VARYING_SLOT_VAR0 - 8,

   /* PointSize is builtin */
   [VARYING_SLOT_PSIZ] = UINT_MAX,

   [VARYING_SLOT_BFC0] = 3, /* output only */
   [VARYING_SLOT_BFC1] = 4, /* output only */
   [VARYING_SLOT_EDGE] = 5, /* output only */
   [VARYING_SLOT_CLIP_VERTEX] = 6, /* output only */

   /* ClipDistance is builtin */
   [VARYING_SLOT_CLIP_DIST0] = UINT_MAX,
   [VARYING_SLOT_CLIP_DIST1] = UINT_MAX,

   /* CullDistance is builtin */
   [VARYING_SLOT_CULL_DIST0] = UINT_MAX, /* input/output */
   [VARYING_SLOT_CULL_DIST1] = UINT_MAX, /* never actually used */

   /* PrimitiveId is builtin */
   [VARYING_SLOT_PRIMITIVE_ID] = UINT_MAX,

   /* Layer is builtin */
   [VARYING_SLOT_LAYER] = UINT_MAX, /* input/output */

   /* ViewportIndex is builtin */
   [VARYING_SLOT_VIEWPORT] =  UINT_MAX, /* input/output */

   /* FrontFacing is builtin */
   [VARYING_SLOT_FACE] = UINT_MAX,

   /* PointCoord is builtin */
   [VARYING_SLOT_PNTC] = UINT_MAX, /* input only */

   /* TessLevelOuter is builtin */
   [VARYING_SLOT_TESS_LEVEL_OUTER] = UINT_MAX,
   /* TessLevelInner is builtin */
   [VARYING_SLOT_TESS_LEVEL_INNER] = UINT_MAX,

   [VARYING_SLOT_BOUNDING_BOX0] = 7, /* Only appears as TCS output. */
   [VARYING_SLOT_BOUNDING_BOX1] = 8, /* Only appears as TCS output. */
   [VARYING_SLOT_VIEW_INDEX] = 9, /* input/output */
   [VARYING_SLOT_VIEWPORT_MASK] = 10, /* output only */
};
/* Number of low Locations (0..10 above) taken by non-builtin legacy slots;
 * handle_slot() maps VARYING_SLOT_VAR0 to this Location. */
#define NTV_MIN_RESERVED_SLOTS 11
91 
/* Per-shader state for the NIR -> SPIR-V translation. */
struct ntv_context {
   /* memory context for translation-time allocations
    * (presumably a ralloc parent — confirm against the allocation sites) */
   void *mem_ctx;

   /* SPIR-V module being built */
   struct spirv_builder builder;

   /* id of the GLSL.std.450 extended instruction set import (assumed from the name) */
   SpvId GLSL_std_450;

   gl_shader_stage stage;

   /* uniform-block variables, appended by emit_ubo() */
   SpvId ubos[128];
   size_t num_ubos;
   /* indexed by sampler binding; filled by emit_sampler() */
   SpvId image_types[PIPE_MAX_SAMPLERS];
   SpvId samplers[PIPE_MAX_SAMPLERS];
   /* bit N set once sampler binding N has been declared */
   unsigned samplers_used : PIPE_MAX_SAMPLERS;
   /* input/output variable ids appended by emit_input()/emit_output();
    * presumably used as the entry-point interface list */
   SpvId entry_ifaces[PIPE_MAX_SHADER_INPUTS * 4 + PIPE_MAX_SHADER_OUTPUTS * 4];
   size_t num_entry_ifaces;

   /* nir_ssa_def::index -> SpvId of the value (see get_src_ssa/store_ssa_def) */
   SpvId *defs;
   size_t num_defs;

   /* nir_register::index -> SpvId of the variable backing the register;
    * reads go through spirv_builder_emit_load, writes through emit_store */
   SpvId *regs;
   size_t num_regs;

   struct hash_table *vars; /* nir_variable -> SpvId */
   struct hash_table *so_outputs; /* pipe_stream_output -> SpvId */
   /* per varying slot, filled for vertex-stage outputs by emit_output() */
   unsigned outputs[VARYING_SLOT_MAX];
   const struct glsl_type *so_output_gl_types[VARYING_SLOT_MAX];
   SpvId so_output_types[VARYING_SLOT_MAX];

   /* nir_block::index -> label id (see block_label) */
   const SpvId *block_ids;
   size_t num_blocks;
   /* control-flow bookkeeping; not used in the helpers shown here — TODO document */
   bool block_started;
   SpvId loop_break, loop_cont;

   /* presumably lazily created system-value variables — confirm at use sites */
   SpvId front_face_var, instance_id_var, vertex_id_var;
#ifndef NDEBUG
   bool seen_texcoord[8]; //whether we've seen a VARYING_SLOT_TEX[n] this pass
#endif
};
131 
/* Forward declarations for helpers defined further down in this file; the
 * type/conversion helpers below (e.g. uvec_to_bvec, emit_select, the
 * bitcast helpers) call them before their definitions. */
static SpvId
get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
                  unsigned num_components, float value);

static SpvId
get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
                  unsigned num_components, uint32_t value);

static SpvId
get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
                  unsigned num_components, int32_t value);

static SpvId
emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src);

static SpvId
emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
           SpvId src0, SpvId src1);

static SpvId
emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
           SpvId src0, SpvId src1, SpvId src2);
154 
155 static SpvId
get_bvec_type(struct ntv_context * ctx,int num_components)156 get_bvec_type(struct ntv_context *ctx, int num_components)
157 {
158    SpvId bool_type = spirv_builder_type_bool(&ctx->builder);
159    if (num_components > 1)
160       return spirv_builder_type_vector(&ctx->builder, bool_type,
161                                        num_components);
162 
163    assert(num_components == 1);
164    return bool_type;
165 }
166 
167 static SpvId
block_label(struct ntv_context * ctx,nir_block * block)168 block_label(struct ntv_context *ctx, nir_block *block)
169 {
170    assert(block->index < ctx->num_blocks);
171    return ctx->block_ids[block->index];
172 }
173 
174 static SpvId
emit_float_const(struct ntv_context * ctx,int bit_size,float value)175 emit_float_const(struct ntv_context *ctx, int bit_size, float value)
176 {
177    assert(bit_size == 32);
178    return spirv_builder_const_float(&ctx->builder, bit_size, value);
179 }
180 
181 static SpvId
emit_uint_const(struct ntv_context * ctx,int bit_size,uint32_t value)182 emit_uint_const(struct ntv_context *ctx, int bit_size, uint32_t value)
183 {
184    assert(bit_size == 32);
185    return spirv_builder_const_uint(&ctx->builder, bit_size, value);
186 }
187 
188 static SpvId
emit_int_const(struct ntv_context * ctx,int bit_size,int32_t value)189 emit_int_const(struct ntv_context *ctx, int bit_size, int32_t value)
190 {
191    assert(bit_size == 32);
192    return spirv_builder_const_int(&ctx->builder, bit_size, value);
193 }
194 
195 static SpvId
get_fvec_type(struct ntv_context * ctx,unsigned bit_size,unsigned num_components)196 get_fvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
197 {
198    assert(bit_size == 32); // only 32-bit floats supported so far
199 
200    SpvId float_type = spirv_builder_type_float(&ctx->builder, bit_size);
201    if (num_components > 1)
202       return spirv_builder_type_vector(&ctx->builder, float_type,
203                                        num_components);
204 
205    assert(num_components == 1);
206    return float_type;
207 }
208 
209 static SpvId
get_ivec_type(struct ntv_context * ctx,unsigned bit_size,unsigned num_components)210 get_ivec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
211 {
212    assert(bit_size == 32); // only 32-bit ints supported so far
213 
214    SpvId int_type = spirv_builder_type_int(&ctx->builder, bit_size);
215    if (num_components > 1)
216       return spirv_builder_type_vector(&ctx->builder, int_type,
217                                        num_components);
218 
219    assert(num_components == 1);
220    return int_type;
221 }
222 
223 static SpvId
get_uvec_type(struct ntv_context * ctx,unsigned bit_size,unsigned num_components)224 get_uvec_type(struct ntv_context *ctx, unsigned bit_size, unsigned num_components)
225 {
226    assert(bit_size == 32); // only 32-bit uints supported so far
227 
228    SpvId uint_type = spirv_builder_type_uint(&ctx->builder, bit_size);
229    if (num_components > 1)
230       return spirv_builder_type_vector(&ctx->builder, uint_type,
231                                        num_components);
232 
233    assert(num_components == 1);
234    return uint_type;
235 }
236 
237 static SpvId
get_dest_uvec_type(struct ntv_context * ctx,nir_dest * dest)238 get_dest_uvec_type(struct ntv_context *ctx, nir_dest *dest)
239 {
240    unsigned bit_size = MAX2(nir_dest_bit_size(*dest), 32);
241    return get_uvec_type(ctx, bit_size, nir_dest_num_components(*dest));
242 }
243 
244 static SpvId
get_glsl_basetype(struct ntv_context * ctx,enum glsl_base_type type)245 get_glsl_basetype(struct ntv_context *ctx, enum glsl_base_type type)
246 {
247    switch (type) {
248    case GLSL_TYPE_BOOL:
249       return spirv_builder_type_bool(&ctx->builder);
250 
251    case GLSL_TYPE_FLOAT:
252       return spirv_builder_type_float(&ctx->builder, 32);
253 
254    case GLSL_TYPE_INT:
255       return spirv_builder_type_int(&ctx->builder, 32);
256 
257    case GLSL_TYPE_UINT:
258       return spirv_builder_type_uint(&ctx->builder, 32);
259    /* TODO: handle more types */
260 
261    default:
262       unreachable("unknown GLSL type");
263    }
264 }
265 
266 static SpvId
get_glsl_type(struct ntv_context * ctx,const struct glsl_type * type)267 get_glsl_type(struct ntv_context *ctx, const struct glsl_type *type)
268 {
269    assert(type);
270    if (glsl_type_is_scalar(type))
271       return get_glsl_basetype(ctx, glsl_get_base_type(type));
272 
273    if (glsl_type_is_vector(type))
274       return spirv_builder_type_vector(&ctx->builder,
275          get_glsl_basetype(ctx, glsl_get_base_type(type)),
276          glsl_get_vector_elements(type));
277 
278    if (glsl_type_is_array(type)) {
279       SpvId ret = spirv_builder_type_array(&ctx->builder,
280          get_glsl_type(ctx, glsl_get_array_element(type)),
281          emit_uint_const(ctx, 32, glsl_get_length(type)));
282       uint32_t stride = glsl_get_explicit_stride(type);
283       if (stride)
284          spirv_builder_emit_array_stride(&ctx->builder, ret, stride);
285       return ret;
286    }
287 
288 
289    unreachable("we shouldn't get here, I think...");
290 }
291 
292 static inline unsigned
handle_slot(struct ntv_context * ctx,unsigned slot)293 handle_slot(struct ntv_context *ctx, unsigned slot)
294 {
295    unsigned orig = slot;
296    if (slot < VARYING_SLOT_VAR0) {
297 #ifndef NDEBUG
298       if (slot >= VARYING_SLOT_TEX0 && slot <= VARYING_SLOT_TEX7)
299          ctx->seen_texcoord[slot - VARYING_SLOT_TEX0] = true;
300 #endif
301       slot = slot_pack_map[slot];
302       if (slot == UINT_MAX)
303          debug_printf("unhandled varying slot: %s\n", gl_varying_slot_name(orig));
304    } else {
305       slot -= VARYING_SLOT_VAR0 - NTV_MIN_RESERVED_SLOTS;
306       assert(slot <= VARYING_SLOT_VAR0 - 8 ||
307              !ctx->seen_texcoord[VARYING_SLOT_VAR0 - slot - 1]);
308 
309    }
310    assert(slot < VARYING_SLOT_VAR0);
311    return slot;
312 }
313 
/* Expands to a `case VARYING_SLOT_<SLOT>:` label that decorates the current
 * variable as SPIR-V BuiltIn <BUILTIN> and leaves the switch.
 * NOTE: the expansion uses `ctx` and `var_id` from the enclosing scope, and it
 * omits the final semicolon so each call site supplies it. */
#define HANDLE_EMIT_BUILTIN(SLOT, BUILTIN) \
      case VARYING_SLOT_##SLOT: \
         spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltIn##BUILTIN); \
         break
318 
319 
320 static void
emit_input(struct ntv_context * ctx,struct nir_variable * var)321 emit_input(struct ntv_context *ctx, struct nir_variable *var)
322 {
323    SpvId var_type = get_glsl_type(ctx, var->type);
324    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
325                                                    SpvStorageClassInput,
326                                                    var_type);
327    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
328                                          SpvStorageClassInput);
329 
330    if (var->name)
331       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
332 
333    if (ctx->stage == MESA_SHADER_FRAGMENT) {
334       unsigned slot = var->data.location;
335       switch (slot) {
336       HANDLE_EMIT_BUILTIN(POS, FragCoord);
337       HANDLE_EMIT_BUILTIN(PNTC, PointCoord);
338       HANDLE_EMIT_BUILTIN(LAYER, Layer);
339       HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
340       HANDLE_EMIT_BUILTIN(CLIP_DIST0, ClipDistance);
341       HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
342       HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
343       HANDLE_EMIT_BUILTIN(FACE, FrontFacing);
344 
345       default:
346          slot = handle_slot(ctx, slot);
347          spirv_builder_emit_location(&ctx->builder, var_id, slot);
348       }
349    } else {
350       spirv_builder_emit_location(&ctx->builder, var_id,
351                                   var->data.driver_location);
352    }
353 
354    if (var->data.location_frac)
355       spirv_builder_emit_component(&ctx->builder, var_id,
356                                    var->data.location_frac);
357 
358    if (var->data.interpolation == INTERP_MODE_FLAT)
359       spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
360 
361    _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
362 
363    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
364    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
365 }
366 
367 static void
emit_output(struct ntv_context * ctx,struct nir_variable * var)368 emit_output(struct ntv_context *ctx, struct nir_variable *var)
369 {
370    SpvId var_type = get_glsl_type(ctx, var->type);
371    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
372                                                    SpvStorageClassOutput,
373                                                    var_type);
374    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
375                                          SpvStorageClassOutput);
376    if (var->name)
377       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
378 
379 
380    if (ctx->stage == MESA_SHADER_VERTEX) {
381       unsigned slot = var->data.location;
382       switch (slot) {
383       HANDLE_EMIT_BUILTIN(POS, Position);
384       HANDLE_EMIT_BUILTIN(PSIZ, PointSize);
385       HANDLE_EMIT_BUILTIN(LAYER, Layer);
386       HANDLE_EMIT_BUILTIN(PRIMITIVE_ID, PrimitiveId);
387       HANDLE_EMIT_BUILTIN(CULL_DIST0, CullDistance);
388       HANDLE_EMIT_BUILTIN(VIEWPORT, ViewportIndex);
389       HANDLE_EMIT_BUILTIN(TESS_LEVEL_OUTER, TessLevelOuter);
390       HANDLE_EMIT_BUILTIN(TESS_LEVEL_INNER, TessLevelInner);
391 
392       case VARYING_SLOT_CLIP_DIST0:
393          assert(glsl_type_is_array(var->type));
394          spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInClipDistance);
395          /* this can be as large as 2x vec4, which requires 2 slots */
396          ctx->outputs[VARYING_SLOT_CLIP_DIST1] = var_id;
397          ctx->so_output_gl_types[VARYING_SLOT_CLIP_DIST1] = var->type;
398          ctx->so_output_types[VARYING_SLOT_CLIP_DIST1] = var_type;
399          break;
400 
401       default:
402          slot = handle_slot(ctx, slot);
403          spirv_builder_emit_location(&ctx->builder, var_id, slot);
404       }
405       ctx->outputs[var->data.location] = var_id;
406       ctx->so_output_gl_types[var->data.location] = var->type;
407       ctx->so_output_types[var->data.location] = var_type;
408    } else if (ctx->stage == MESA_SHADER_FRAGMENT) {
409       if (var->data.location >= FRAG_RESULT_DATA0) {
410          spirv_builder_emit_location(&ctx->builder, var_id,
411                                      var->data.location - FRAG_RESULT_DATA0);
412          spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
413       } else {
414          switch (var->data.location) {
415          case FRAG_RESULT_COLOR:
416             unreachable("gl_FragColor should be lowered by now");
417 
418          case FRAG_RESULT_DEPTH:
419             spirv_builder_emit_builtin(&ctx->builder, var_id, SpvBuiltInFragDepth);
420             break;
421 
422          default:
423             spirv_builder_emit_location(&ctx->builder, var_id,
424                                         var->data.driver_location);
425             spirv_builder_emit_index(&ctx->builder, var_id, var->data.index);
426          }
427       }
428    }
429 
430    if (var->data.location_frac)
431       spirv_builder_emit_component(&ctx->builder, var_id,
432                                    var->data.location_frac);
433 
434    switch (var->data.interpolation) {
435    case INTERP_MODE_NONE:
436    case INTERP_MODE_SMOOTH: /* XXX spirv doesn't seem to have anything for this */
437       break;
438    case INTERP_MODE_FLAT:
439       spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationFlat);
440       break;
441    case INTERP_MODE_EXPLICIT:
442       spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationExplicitInterpAMD);
443       break;
444    case INTERP_MODE_NOPERSPECTIVE:
445       spirv_builder_emit_decoration(&ctx->builder, var_id, SpvDecorationNoPerspective);
446       break;
447    default:
448       unreachable("unknown interpolation value");
449    }
450 
451    _mesa_hash_table_insert(ctx->vars, var, (void *)(intptr_t)var_id);
452 
453    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
454    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
455 }
456 
457 static SpvDim
type_to_dim(enum glsl_sampler_dim gdim,bool * is_ms)458 type_to_dim(enum glsl_sampler_dim gdim, bool *is_ms)
459 {
460    *is_ms = false;
461    switch (gdim) {
462    case GLSL_SAMPLER_DIM_1D:
463       return SpvDim1D;
464    case GLSL_SAMPLER_DIM_2D:
465       return SpvDim2D;
466    case GLSL_SAMPLER_DIM_3D:
467       return SpvDim3D;
468    case GLSL_SAMPLER_DIM_CUBE:
469       return SpvDimCube;
470    case GLSL_SAMPLER_DIM_RECT:
471       return SpvDim2D;
472    case GLSL_SAMPLER_DIM_BUF:
473       return SpvDimBuffer;
474    case GLSL_SAMPLER_DIM_EXTERNAL:
475       return SpvDim2D; /* seems dodgy... */
476    case GLSL_SAMPLER_DIM_MS:
477       *is_ms = true;
478       return SpvDim2D;
479    default:
480       fprintf(stderr, "unknown sampler type %d\n", gdim);
481       break;
482    }
483    return SpvDim2D;
484 }
485 
486 uint32_t
zink_binding(gl_shader_stage stage,VkDescriptorType type,int index)487 zink_binding(gl_shader_stage stage, VkDescriptorType type, int index)
488 {
489    if (stage == MESA_SHADER_NONE ||
490        stage >= MESA_SHADER_COMPUTE) {
491       unreachable("not supported");
492    } else {
493       uint32_t stage_offset = (uint32_t)stage * (PIPE_MAX_CONSTANT_BUFFERS +
494                                                  PIPE_MAX_SHADER_SAMPLER_VIEWS);
495 
496       switch (type) {
497       case VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER:
498          assert(index < PIPE_MAX_CONSTANT_BUFFERS);
499          return stage_offset + index;
500 
501       case VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER:
502          assert(index < PIPE_MAX_SHADER_SAMPLER_VIEWS);
503          return stage_offset + PIPE_MAX_CONSTANT_BUFFERS + index;
504 
505       default:
506          unreachable("unexpected type");
507       }
508    }
509 }
510 
511 static void
emit_sampler(struct ntv_context * ctx,struct nir_variable * var)512 emit_sampler(struct ntv_context *ctx, struct nir_variable *var)
513 {
514    const struct glsl_type *type = glsl_without_array(var->type);
515 
516    bool is_ms;
517    SpvDim dimension = type_to_dim(glsl_get_sampler_dim(type), &is_ms);
518 
519    SpvId result_type = get_glsl_basetype(ctx, glsl_get_sampler_result_type(type));
520    SpvId image_type = spirv_builder_type_image(&ctx->builder, result_type,
521                                                dimension, false,
522                                                glsl_sampler_type_is_array(type),
523                                                is_ms, 1,
524                                                SpvImageFormatUnknown);
525 
526    SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
527                                                          image_type);
528    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
529                                                    SpvStorageClassUniformConstant,
530                                                    sampled_type);
531 
532    if (glsl_type_is_array(var->type)) {
533       for (int i = 0; i < glsl_get_length(var->type); ++i) {
534          SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
535                                                SpvStorageClassUniformConstant);
536 
537          if (var->name) {
538             char element_name[100];
539             snprintf(element_name, sizeof(element_name), "%s_%d", var->name, i);
540             spirv_builder_emit_name(&ctx->builder, var_id, var->name);
541          }
542 
543          int index = var->data.binding + i;
544          assert(!(ctx->samplers_used & (1 << index)));
545          assert(!ctx->image_types[index]);
546          ctx->image_types[index] = image_type;
547          ctx->samplers[index] = var_id;
548          ctx->samplers_used |= 1 << index;
549 
550          spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
551                                            var->data.descriptor_set);
552          int binding = zink_binding(ctx->stage,
553                                     VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
554                                     var->data.binding + i);
555          spirv_builder_emit_binding(&ctx->builder, var_id, binding);
556       }
557    } else {
558       SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
559                                             SpvStorageClassUniformConstant);
560 
561       if (var->name)
562          spirv_builder_emit_name(&ctx->builder, var_id, var->name);
563 
564       int index = var->data.binding;
565       assert(!(ctx->samplers_used & (1 << index)));
566       assert(!ctx->image_types[index]);
567       ctx->image_types[index] = image_type;
568       ctx->samplers[index] = var_id;
569       ctx->samplers_used |= 1 << index;
570 
571       spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
572                                         var->data.descriptor_set);
573       int binding = zink_binding(ctx->stage,
574                                  VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER,
575                                  var->data.binding);
576       spirv_builder_emit_binding(&ctx->builder, var_id, binding);
577    }
578 }
579 
580 static void
emit_ubo(struct ntv_context * ctx,struct nir_variable * var)581 emit_ubo(struct ntv_context *ctx, struct nir_variable *var)
582 {
583    uint32_t size = glsl_count_attribute_slots(var->type, false);
584    SpvId vec4_type = get_uvec_type(ctx, 32, 4);
585    SpvId array_length = emit_uint_const(ctx, 32, size);
586    SpvId array_type = spirv_builder_type_array(&ctx->builder, vec4_type,
587                                                array_length);
588    spirv_builder_emit_array_stride(&ctx->builder, array_type, 16);
589 
590    // wrap UBO-array in a struct
591    SpvId struct_type = spirv_builder_type_struct(&ctx->builder, &array_type, 1);
592    if (var->name) {
593       char struct_name[100];
594       snprintf(struct_name, sizeof(struct_name), "struct_%s", var->name);
595       spirv_builder_emit_name(&ctx->builder, struct_type, struct_name);
596    }
597 
598    spirv_builder_emit_decoration(&ctx->builder, struct_type,
599                                  SpvDecorationBlock);
600    spirv_builder_emit_member_offset(&ctx->builder, struct_type, 0, 0);
601 
602 
603    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
604                                                    SpvStorageClassUniform,
605                                                    struct_type);
606 
607    SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
608                                          SpvStorageClassUniform);
609    if (var->name)
610       spirv_builder_emit_name(&ctx->builder, var_id, var->name);
611 
612    assert(ctx->num_ubos < ARRAY_SIZE(ctx->ubos));
613    ctx->ubos[ctx->num_ubos++] = var_id;
614 
615    spirv_builder_emit_descriptor_set(&ctx->builder, var_id,
616                                      var->data.descriptor_set);
617    int binding = zink_binding(ctx->stage,
618                               VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER,
619                               var->data.binding);
620    spirv_builder_emit_binding(&ctx->builder, var_id, binding);
621 }
622 
623 static void
emit_uniform(struct ntv_context * ctx,struct nir_variable * var)624 emit_uniform(struct ntv_context *ctx, struct nir_variable *var)
625 {
626    if (var->data.mode == nir_var_mem_ubo)
627       emit_ubo(ctx, var);
628    else {
629       assert(var->data.mode == nir_var_uniform);
630       if (glsl_type_is_sampler(glsl_without_array(var->type)))
631          emit_sampler(ctx, var);
632    }
633 }
634 
635 static SpvId
get_vec_from_bit_size(struct ntv_context * ctx,uint32_t bit_size,uint32_t num_components)636 get_vec_from_bit_size(struct ntv_context *ctx, uint32_t bit_size, uint32_t num_components)
637 {
638    if (bit_size == 1)
639       return get_bvec_type(ctx, num_components);
640    if (bit_size == 32)
641       return get_uvec_type(ctx, bit_size, num_components);
642    unreachable("unhandled register bit size");
643    return 0;
644 }
645 
646 static SpvId
get_src_ssa(struct ntv_context * ctx,const nir_ssa_def * ssa)647 get_src_ssa(struct ntv_context *ctx, const nir_ssa_def *ssa)
648 {
649    assert(ssa->index < ctx->num_defs);
650    assert(ctx->defs[ssa->index] != 0);
651    return ctx->defs[ssa->index];
652 }
653 
654 static SpvId
get_var_from_reg(struct ntv_context * ctx,nir_register * reg)655 get_var_from_reg(struct ntv_context *ctx, nir_register *reg)
656 {
657    assert(reg->index < ctx->num_regs);
658    assert(ctx->regs[reg->index] != 0);
659    return ctx->regs[reg->index];
660 }
661 
662 static SpvId
get_src_reg(struct ntv_context * ctx,const nir_reg_src * reg)663 get_src_reg(struct ntv_context *ctx, const nir_reg_src *reg)
664 {
665    assert(reg->reg);
666    assert(!reg->indirect);
667    assert(!reg->base_offset);
668 
669    SpvId var = get_var_from_reg(ctx, reg->reg);
670    SpvId type = get_vec_from_bit_size(ctx, reg->reg->bit_size, reg->reg->num_components);
671    return spirv_builder_emit_load(&ctx->builder, type, var);
672 }
673 
674 static SpvId
get_src(struct ntv_context * ctx,nir_src * src)675 get_src(struct ntv_context *ctx, nir_src *src)
676 {
677    if (src->is_ssa)
678       return get_src_ssa(ctx, src->ssa);
679    else
680       return get_src_reg(ctx, &src->reg);
681 }
682 
683 static SpvId
get_alu_src_raw(struct ntv_context * ctx,nir_alu_instr * alu,unsigned src)684 get_alu_src_raw(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
685 {
686    assert(!alu->src[src].negate);
687    assert(!alu->src[src].abs);
688 
689    SpvId def = get_src(ctx, &alu->src[src].src);
690 
691    unsigned used_channels = 0;
692    bool need_swizzle = false;
693    for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
694       if (!nir_alu_instr_channel_used(alu, src, i))
695          continue;
696 
697       used_channels++;
698 
699       if (alu->src[src].swizzle[i] != i)
700          need_swizzle = true;
701    }
702    assert(used_channels != 0);
703 
704    unsigned live_channels = nir_src_num_components(alu->src[src].src);
705    if (used_channels != live_channels)
706       need_swizzle = true;
707 
708    if (!need_swizzle)
709       return def;
710 
711    int bit_size = nir_src_bit_size(alu->src[src].src);
712    assert(bit_size == 1 || bit_size == 32);
713 
714    SpvId raw_type = bit_size == 1 ? spirv_builder_type_bool(&ctx->builder) :
715                                     spirv_builder_type_uint(&ctx->builder, bit_size);
716 
717    if (used_channels == 1) {
718       uint32_t indices[] =  { alu->src[src].swizzle[0] };
719       return spirv_builder_emit_composite_extract(&ctx->builder, raw_type,
720                                                   def, indices,
721                                                   ARRAY_SIZE(indices));
722    } else if (live_channels == 1) {
723       SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
724                                                      raw_type,
725                                                      used_channels);
726 
727       SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
728       for (unsigned i = 0; i < used_channels; ++i)
729         constituents[i] = def;
730 
731       return spirv_builder_emit_composite_construct(&ctx->builder,
732                                                     raw_vec_type,
733                                                     constituents,
734                                                     used_channels);
735    } else {
736       SpvId raw_vec_type = spirv_builder_type_vector(&ctx->builder,
737                                                      raw_type,
738                                                      used_channels);
739 
740       uint32_t components[NIR_MAX_VEC_COMPONENTS] = {0};
741       size_t num_components = 0;
742       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
743          if (!nir_alu_instr_channel_used(alu, src, i))
744             continue;
745 
746          components[num_components++] = alu->src[src].swizzle[i];
747       }
748 
749       return spirv_builder_emit_vector_shuffle(&ctx->builder, raw_vec_type,
750                                                def, def, components,
751                                                num_components);
752    }
753 }
754 
755 static void
store_ssa_def(struct ntv_context * ctx,nir_ssa_def * ssa,SpvId result)756 store_ssa_def(struct ntv_context *ctx, nir_ssa_def *ssa, SpvId result)
757 {
758    assert(result != 0);
759    assert(ssa->index < ctx->num_defs);
760    ctx->defs[ssa->index] = result;
761 }
762 
763 static SpvId
emit_select(struct ntv_context * ctx,SpvId type,SpvId cond,SpvId if_true,SpvId if_false)764 emit_select(struct ntv_context *ctx, SpvId type, SpvId cond,
765             SpvId if_true, SpvId if_false)
766 {
767    return emit_triop(ctx, SpvOpSelect, type, cond, if_true, if_false);
768 }
769 
770 static SpvId
uvec_to_bvec(struct ntv_context * ctx,SpvId value,unsigned num_components)771 uvec_to_bvec(struct ntv_context *ctx, SpvId value, unsigned num_components)
772 {
773    SpvId type = get_bvec_type(ctx, num_components);
774    SpvId zero = get_uvec_constant(ctx, 32, num_components, 0);
775    return emit_binop(ctx, SpvOpINotEqual, type, value, zero);
776 }
777 
778 static SpvId
emit_bitcast(struct ntv_context * ctx,SpvId type,SpvId value)779 emit_bitcast(struct ntv_context *ctx, SpvId type, SpvId value)
780 {
781    return emit_unop(ctx, SpvOpBitcast, type, value);
782 }
783 
784 static SpvId
bitcast_to_uvec(struct ntv_context * ctx,SpvId value,unsigned bit_size,unsigned num_components)785 bitcast_to_uvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
786                 unsigned num_components)
787 {
788    SpvId type = get_uvec_type(ctx, bit_size, num_components);
789    return emit_bitcast(ctx, type, value);
790 }
791 
792 static SpvId
bitcast_to_ivec(struct ntv_context * ctx,SpvId value,unsigned bit_size,unsigned num_components)793 bitcast_to_ivec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
794                 unsigned num_components)
795 {
796    SpvId type = get_ivec_type(ctx, bit_size, num_components);
797    return emit_bitcast(ctx, type, value);
798 }
799 
800 static SpvId
bitcast_to_fvec(struct ntv_context * ctx,SpvId value,unsigned bit_size,unsigned num_components)801 bitcast_to_fvec(struct ntv_context *ctx, SpvId value, unsigned bit_size,
802                unsigned num_components)
803 {
804    SpvId type = get_fvec_type(ctx, bit_size, num_components);
805    return emit_bitcast(ctx, type, value);
806 }
807 
808 static void
store_reg_def(struct ntv_context * ctx,nir_reg_dest * reg,SpvId result)809 store_reg_def(struct ntv_context *ctx, nir_reg_dest *reg, SpvId result)
810 {
811    SpvId var = get_var_from_reg(ctx, reg->reg);
812    assert(var);
813    spirv_builder_emit_store(&ctx->builder, var, result);
814 }
815 
816 static void
store_dest_raw(struct ntv_context * ctx,nir_dest * dest,SpvId result)817 store_dest_raw(struct ntv_context *ctx, nir_dest *dest, SpvId result)
818 {
819    if (dest->is_ssa)
820       store_ssa_def(ctx, &dest->ssa, result);
821    else
822       store_reg_def(ctx, &dest->reg, result);
823 }
824 
825 static SpvId
store_dest(struct ntv_context * ctx,nir_dest * dest,SpvId result,nir_alu_type type)826 store_dest(struct ntv_context *ctx, nir_dest *dest, SpvId result, nir_alu_type type)
827 {
828    unsigned num_components = nir_dest_num_components(*dest);
829    unsigned bit_size = nir_dest_bit_size(*dest);
830 
831    if (bit_size != 1) {
832       switch (nir_alu_type_get_base_type(type)) {
833       case nir_type_bool:
834          assert("bool should have bit-size 1");
835 
836       case nir_type_uint:
837          break; /* nothing to do! */
838 
839       case nir_type_int:
840       case nir_type_float:
841          result = bitcast_to_uvec(ctx, result, bit_size, num_components);
842          break;
843 
844       default:
845          unreachable("unsupported nir_alu_type");
846       }
847    }
848 
849    store_dest_raw(ctx, dest, result);
850    return result;
851 }
852 
853 static SpvId
emit_unop(struct ntv_context * ctx,SpvOp op,SpvId type,SpvId src)854 emit_unop(struct ntv_context *ctx, SpvOp op, SpvId type, SpvId src)
855 {
856    return spirv_builder_emit_unop(&ctx->builder, op, type, src);
857 }
858 
859 /* return the intended xfb output vec type based on base type and vector size */
860 static SpvId
get_output_type(struct ntv_context * ctx,unsigned register_index,unsigned num_components)861 get_output_type(struct ntv_context *ctx, unsigned register_index, unsigned num_components)
862 {
863    const struct glsl_type *out_type = ctx->so_output_gl_types[register_index];
864    enum glsl_base_type base_type = glsl_get_base_type(out_type);
865    if (base_type == GLSL_TYPE_ARRAY)
866       base_type = glsl_get_base_type(glsl_without_array(out_type));
867 
868    switch (base_type) {
869    case GLSL_TYPE_BOOL:
870       return get_bvec_type(ctx, num_components);
871 
872    case GLSL_TYPE_FLOAT:
873       return get_fvec_type(ctx, 32, num_components);
874 
875    case GLSL_TYPE_INT:
876       return get_ivec_type(ctx, 32, num_components);
877 
878    case GLSL_TYPE_UINT:
879       return get_uvec_type(ctx, 32, num_components);
880 
881    default:
882       break;
883    }
884    unreachable("unknown type");
885    return 0;
886 }
887 
888 /* for streamout create new outputs, as streamout can be done on individual components,
889    from complete outputs, so we just can't use the created packed outputs */
890 static void
emit_so_info(struct ntv_context * ctx,unsigned max_output_location,const struct pipe_stream_output_info * so_info,struct pipe_stream_output_info * local_so_info)891 emit_so_info(struct ntv_context *ctx, unsigned max_output_location,
892              const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
893 {
894    for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
895       struct pipe_stream_output so_output = local_so_info->output[i];
896       SpvId out_type = get_output_type(ctx, so_output.register_index, so_output.num_components);
897       SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
898                                                       SpvStorageClassOutput,
899                                                       out_type);
900       SpvId var_id = spirv_builder_emit_var(&ctx->builder, pointer_type,
901                                             SpvStorageClassOutput);
902       char name[10];
903 
904       snprintf(name, 10, "xfb%d", i);
905       spirv_builder_emit_name(&ctx->builder, var_id, name);
906       spirv_builder_emit_offset(&ctx->builder, var_id, (so_output.dst_offset * 4));
907       spirv_builder_emit_xfb_buffer(&ctx->builder, var_id, so_output.output_buffer);
908       spirv_builder_emit_xfb_stride(&ctx->builder, var_id, so_info->stride[so_output.output_buffer] * 4);
909 
910       /* output location is incremented by VARYING_SLOT_VAR0 for non-builtins in vtn,
911        * so we need to ensure that the new xfb location slot doesn't conflict with any previously-emitted
912        * outputs.
913        *
914        * if there's no previous outputs that take up user slots (VAR0+) then we can start right after the
915        * glsl builtin reserved slots, otherwise we start just after the adjusted user output slot
916        */
917       uint32_t location = NTV_MIN_RESERVED_SLOTS + i;
918       if (max_output_location >= VARYING_SLOT_VAR0)
919          location = max_output_location - VARYING_SLOT_VAR0 + 1 + i;
920       assert(location < VARYING_SLOT_VAR0);
921       assert(location <= VARYING_SLOT_VAR0 - 8 ||
922              !ctx->seen_texcoord[VARYING_SLOT_VAR0 - location - 1]);
923       spirv_builder_emit_location(&ctx->builder, var_id, location);
924 
925       /* note: gl_ClipDistance[4] can the 0-indexed member of VARYING_SLOT_CLIP_DIST1 here,
926        * so this is still the 0 component
927        */
928       if (so_output.start_component)
929          spirv_builder_emit_component(&ctx->builder, var_id, so_output.start_component);
930 
931       uint32_t *key = ralloc_size(ctx->mem_ctx, sizeof(uint32_t));
932       *key = (uint32_t)so_output.register_index << 2 | so_output.start_component;
933       _mesa_hash_table_insert(ctx->so_outputs, key, (void *)(intptr_t)var_id);
934 
935       assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
936       ctx->entry_ifaces[ctx->num_entry_ifaces++] = var_id;
937    }
938 }
939 
/* Copy the streamed-out components of each regular output into the dedicated
 * xfb variables created by emit_so_info().
 *
 * For every entry in local_so_info, this function:
 * - looks up its xfb variable in ctx->so_outputs, using the key
 *   (register_index << 2 | start_component);
 * - loads the full original output. Each register is loaded at most once and
 *   cached in loaded_outputs;
 * - extracts the requested components and stores them to the xfb variable.
 *
 * NOTE(review): `so_info` is not read in this function; only local_so_info is used.
 */
static void
emit_so_outputs(struct ntv_context *ctx,
                const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
{
   /* per-register cache of the loaded original output (0 = not loaded yet) */
   SpvId loaded_outputs[VARYING_SLOT_MAX] = {};
   for (unsigned i = 0; i < local_so_info->num_outputs; i++) {
      uint32_t components[NIR_MAX_VEC_COMPONENTS];
      struct pipe_stream_output so_output = local_so_info->output[i];
      /* same key layout as the one inserted by emit_so_info() */
      uint32_t so_key = (uint32_t) so_output.register_index << 2 | so_output.start_component;
      struct hash_entry *he = _mesa_hash_table_search(ctx->so_outputs, &so_key);
      assert(he);
      SpvId so_output_var_id = (SpvId)(intptr_t)he->data;

      /* type: what the xfb variable holds; output_type: the original output's type */
      SpvId type = get_output_type(ctx, so_output.register_index, so_output.num_components);
      SpvId output = ctx->outputs[so_output.register_index];
      SpvId output_type = ctx->so_output_types[so_output.register_index];
      const struct glsl_type *out_type = ctx->so_output_gl_types[so_output.register_index];

      if (!loaded_outputs[so_output.register_index])
         loaded_outputs[so_output.register_index] = spirv_builder_emit_load(&ctx->builder, output_type, output);
      SpvId src = loaded_outputs[so_output.register_index];

      SpvId result;

      /* component indices to select from the loaded output */
      for (unsigned c = 0; c < so_output.num_components; c++) {
         components[c] = so_output.start_component + c;
         /* this is the second half of a 2 * vec4 array */
         if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
            components[c] += 4;
      }

      /* if we're emitting a scalar or the type we're emitting matches the output's original type and we're
       * emitting the same number of components, then we can skip any sort of conversion here
       */
      /* NOTE(review): confirm glsl_get_length() returns the component count for
       * vector types; if it does not, vectors never take this fast path (they
       * still work through the shuffle path below).
       */
      if (glsl_type_is_scalar(out_type) || (type == output_type && glsl_get_length(out_type) == so_output.num_components))
         result = src;
      else {
         /* OpCompositeExtract can only extract scalars for our use here */
         if (so_output.num_components == 1) {
            result = spirv_builder_emit_composite_extract(&ctx->builder, type, src, components, so_output.num_components);
         } else if (glsl_type_is_vector(out_type)) {
            /* OpVectorShuffle can select vector members into a differently-sized vector */
            result = spirv_builder_emit_vector_shuffle(&ctx->builder, type,
                                                             src, src,
                                                             components, so_output.num_components);
            /* NOTE(review): `type` is both the shuffle result type and the
             * bitcast target, so this OpBitcast looks like a same-type no-op;
             * confirm whether it is needed.
             */
            result = emit_unop(ctx, SpvOpBitcast, type, result);
         } else {
             /* for arrays, we need to manually extract each desired member
              * and re-pack them into the desired output type
              */
             for (unsigned c = 0; c < so_output.num_components; c++) {
                uint32_t member[] = { so_output.start_component + c };
                SpvId base_type = get_glsl_type(ctx, glsl_without_array(out_type));

                if (ctx->stage == MESA_SHADER_VERTEX && so_output.register_index == VARYING_SLOT_CLIP_DIST1)
                   member[0] += 4;
                /* components[] is reused here to hold the extracted member ids */
                components[c] = spirv_builder_emit_composite_extract(&ctx->builder, base_type, src, member, 1);
             }
             result = spirv_builder_emit_composite_construct(&ctx->builder, type, components, so_output.num_components);
         }
      }

      spirv_builder_emit_store(&ctx->builder, so_output_var_id, result);
   }
}
1005 
1006 static SpvId
emit_binop(struct ntv_context * ctx,SpvOp op,SpvId type,SpvId src0,SpvId src1)1007 emit_binop(struct ntv_context *ctx, SpvOp op, SpvId type,
1008            SpvId src0, SpvId src1)
1009 {
1010    return spirv_builder_emit_binop(&ctx->builder, op, type, src0, src1);
1011 }
1012 
1013 static SpvId
emit_triop(struct ntv_context * ctx,SpvOp op,SpvId type,SpvId src0,SpvId src1,SpvId src2)1014 emit_triop(struct ntv_context *ctx, SpvOp op, SpvId type,
1015            SpvId src0, SpvId src1, SpvId src2)
1016 {
1017    return spirv_builder_emit_triop(&ctx->builder, op, type, src0, src1, src2);
1018 }
1019 
1020 static SpvId
emit_builtin_unop(struct ntv_context * ctx,enum GLSLstd450 op,SpvId type,SpvId src)1021 emit_builtin_unop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1022                   SpvId src)
1023 {
1024    SpvId args[] = { src };
1025    return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1026                                       op, args, ARRAY_SIZE(args));
1027 }
1028 
1029 static SpvId
emit_builtin_binop(struct ntv_context * ctx,enum GLSLstd450 op,SpvId type,SpvId src0,SpvId src1)1030 emit_builtin_binop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1031                    SpvId src0, SpvId src1)
1032 {
1033    SpvId args[] = { src0, src1 };
1034    return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1035                                       op, args, ARRAY_SIZE(args));
1036 }
1037 
1038 static SpvId
emit_builtin_triop(struct ntv_context * ctx,enum GLSLstd450 op,SpvId type,SpvId src0,SpvId src1,SpvId src2)1039 emit_builtin_triop(struct ntv_context *ctx, enum GLSLstd450 op, SpvId type,
1040                    SpvId src0, SpvId src1, SpvId src2)
1041 {
1042    SpvId args[] = { src0, src1, src2 };
1043    return spirv_builder_emit_ext_inst(&ctx->builder, type, ctx->GLSL_std_450,
1044                                       op, args, ARRAY_SIZE(args));
1045 }
1046 
1047 static SpvId
get_fvec_constant(struct ntv_context * ctx,unsigned bit_size,unsigned num_components,float value)1048 get_fvec_constant(struct ntv_context *ctx, unsigned bit_size,
1049                   unsigned num_components, float value)
1050 {
1051    assert(bit_size == 32);
1052 
1053    SpvId result = emit_float_const(ctx, bit_size, value);
1054    if (num_components == 1)
1055       return result;
1056 
1057    assert(num_components > 1);
1058    SpvId components[num_components];
1059    for (int i = 0; i < num_components; i++)
1060       components[i] = result;
1061 
1062    SpvId type = get_fvec_type(ctx, bit_size, num_components);
1063    return spirv_builder_const_composite(&ctx->builder, type, components,
1064                                         num_components);
1065 }
1066 
1067 static SpvId
get_uvec_constant(struct ntv_context * ctx,unsigned bit_size,unsigned num_components,uint32_t value)1068 get_uvec_constant(struct ntv_context *ctx, unsigned bit_size,
1069                   unsigned num_components, uint32_t value)
1070 {
1071    assert(bit_size == 32);
1072 
1073    SpvId result = emit_uint_const(ctx, bit_size, value);
1074    if (num_components == 1)
1075       return result;
1076 
1077    assert(num_components > 1);
1078    SpvId components[num_components];
1079    for (int i = 0; i < num_components; i++)
1080       components[i] = result;
1081 
1082    SpvId type = get_uvec_type(ctx, bit_size, num_components);
1083    return spirv_builder_const_composite(&ctx->builder, type, components,
1084                                         num_components);
1085 }
1086 
1087 static SpvId
get_ivec_constant(struct ntv_context * ctx,unsigned bit_size,unsigned num_components,int32_t value)1088 get_ivec_constant(struct ntv_context *ctx, unsigned bit_size,
1089                   unsigned num_components, int32_t value)
1090 {
1091    assert(bit_size == 32);
1092 
1093    SpvId result = emit_int_const(ctx, bit_size, value);
1094    if (num_components == 1)
1095       return result;
1096 
1097    assert(num_components > 1);
1098    SpvId components[num_components];
1099    for (int i = 0; i < num_components; i++)
1100       components[i] = result;
1101 
1102    SpvId type = get_ivec_type(ctx, bit_size, num_components);
1103    return spirv_builder_const_composite(&ctx->builder, type, components,
1104                                         num_components);
1105 }
1106 
1107 static inline unsigned
alu_instr_src_components(const nir_alu_instr * instr,unsigned src)1108 alu_instr_src_components(const nir_alu_instr *instr, unsigned src)
1109 {
1110    if (nir_op_infos[instr->op].input_sizes[src] > 0)
1111       return nir_op_infos[instr->op].input_sizes[src];
1112 
1113    if (instr->dest.dest.is_ssa)
1114       return instr->dest.dest.ssa.num_components;
1115    else
1116       return instr->dest.dest.reg.reg->num_components;
1117 }
1118 
1119 static SpvId
get_alu_src(struct ntv_context * ctx,nir_alu_instr * alu,unsigned src)1120 get_alu_src(struct ntv_context *ctx, nir_alu_instr *alu, unsigned src)
1121 {
1122    SpvId raw_value = get_alu_src_raw(ctx, alu, src);
1123 
1124    unsigned num_components = alu_instr_src_components(alu, src);
1125    unsigned bit_size = nir_src_bit_size(alu->src[src].src);
1126    nir_alu_type type = nir_op_infos[alu->op].input_types[src];
1127 
1128    if (bit_size == 1)
1129       return raw_value;
1130    else {
1131       switch (nir_alu_type_get_base_type(type)) {
1132       case nir_type_bool:
1133          unreachable("bool should have bit-size 1");
1134 
1135       case nir_type_int:
1136          return bitcast_to_ivec(ctx, raw_value, bit_size, num_components);
1137 
1138       case nir_type_uint:
1139          return raw_value;
1140 
1141       case nir_type_float:
1142          return bitcast_to_fvec(ctx, raw_value, bit_size, num_components);
1143 
1144       default:
1145          unreachable("unknown nir_alu_type");
1146       }
1147    }
1148 }
1149 
1150 static SpvId
store_alu_result(struct ntv_context * ctx,nir_alu_instr * alu,SpvId result)1151 store_alu_result(struct ntv_context *ctx, nir_alu_instr *alu, SpvId result)
1152 {
1153    assert(!alu->dest.saturate);
1154    return store_dest(ctx, &alu->dest.dest, result,
1155                      nir_op_infos[alu->op].output_type);
1156 }
1157 
1158 static SpvId
get_dest_type(struct ntv_context * ctx,nir_dest * dest,nir_alu_type type)1159 get_dest_type(struct ntv_context *ctx, nir_dest *dest, nir_alu_type type)
1160 {
1161    unsigned num_components = nir_dest_num_components(*dest);
1162    unsigned bit_size = nir_dest_bit_size(*dest);
1163 
1164    if (bit_size == 1)
1165       return get_bvec_type(ctx, num_components);
1166 
1167    switch (nir_alu_type_get_base_type(type)) {
1168    case nir_type_bool:
1169       unreachable("bool should have bit-size 1");
1170 
1171    case nir_type_int:
1172       return get_ivec_type(ctx, bit_size, num_components);
1173 
1174    case nir_type_uint:
1175       return get_uvec_type(ctx, bit_size, num_components);
1176 
1177    case nir_type_float:
1178       return get_fvec_type(ctx, bit_size, num_components);
1179 
1180    default:
1181       unreachable("unsupported nir_alu_type");
1182    }
1183 }
1184 
1185 static void
emit_alu(struct ntv_context * ctx,nir_alu_instr * alu)1186 emit_alu(struct ntv_context *ctx, nir_alu_instr *alu)
1187 {
1188    SpvId src[nir_op_infos[alu->op].num_inputs];
1189    unsigned in_bit_sizes[nir_op_infos[alu->op].num_inputs];
1190    for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) {
1191       src[i] = get_alu_src(ctx, alu, i);
1192       in_bit_sizes[i] = nir_src_bit_size(alu->src[i].src);
1193    }
1194 
1195    SpvId dest_type = get_dest_type(ctx, &alu->dest.dest,
1196                                    nir_op_infos[alu->op].output_type);
1197    unsigned bit_size = nir_dest_bit_size(alu->dest.dest);
1198    unsigned num_components = nir_dest_num_components(alu->dest.dest);
1199 
1200    SpvId result = 0;
1201    switch (alu->op) {
1202    case nir_op_mov:
1203       assert(nir_op_infos[alu->op].num_inputs == 1);
1204       result = src[0];
1205       break;
1206 
1207 #define UNOP(nir_op, spirv_op) \
1208    case nir_op: \
1209       assert(nir_op_infos[alu->op].num_inputs == 1); \
1210       result = emit_unop(ctx, spirv_op, dest_type, src[0]); \
1211       break;
1212 
1213    UNOP(nir_op_ineg, SpvOpSNegate)
1214    UNOP(nir_op_fneg, SpvOpFNegate)
1215    UNOP(nir_op_fddx, SpvOpDPdx)
1216    UNOP(nir_op_fddx_coarse, SpvOpDPdxCoarse)
1217    UNOP(nir_op_fddx_fine, SpvOpDPdxFine)
1218    UNOP(nir_op_fddy, SpvOpDPdy)
1219    UNOP(nir_op_fddy_coarse, SpvOpDPdyCoarse)
1220    UNOP(nir_op_fddy_fine, SpvOpDPdyFine)
1221    UNOP(nir_op_f2i32, SpvOpConvertFToS)
1222    UNOP(nir_op_f2u32, SpvOpConvertFToU)
1223    UNOP(nir_op_i2f32, SpvOpConvertSToF)
1224    UNOP(nir_op_u2f32, SpvOpConvertUToF)
1225    UNOP(nir_op_bitfield_reverse, SpvOpBitReverse)
1226 #undef UNOP
1227 
1228    case nir_op_inot:
1229       if (bit_size == 1)
1230          result = emit_unop(ctx, SpvOpLogicalNot, dest_type, src[0]);
1231       else
1232          result = emit_unop(ctx, SpvOpNot, dest_type, src[0]);
1233       break;
1234 
1235    case nir_op_b2i32:
1236       assert(nir_op_infos[alu->op].num_inputs == 1);
1237       result = emit_select(ctx, dest_type, src[0],
1238                            get_ivec_constant(ctx, 32, num_components, 1),
1239                            get_ivec_constant(ctx, 32, num_components, 0));
1240       break;
1241 
1242    case nir_op_b2f32:
1243       assert(nir_op_infos[alu->op].num_inputs == 1);
1244       result = emit_select(ctx, dest_type, src[0],
1245                            get_fvec_constant(ctx, 32, num_components, 1),
1246                            get_fvec_constant(ctx, 32, num_components, 0));
1247       break;
1248 
1249 #define BUILTIN_UNOP(nir_op, spirv_op) \
1250    case nir_op: \
1251       assert(nir_op_infos[alu->op].num_inputs == 1); \
1252       result = emit_builtin_unop(ctx, spirv_op, dest_type, src[0]); \
1253       break;
1254 
1255    BUILTIN_UNOP(nir_op_iabs, GLSLstd450SAbs)
1256    BUILTIN_UNOP(nir_op_fabs, GLSLstd450FAbs)
1257    BUILTIN_UNOP(nir_op_fsqrt, GLSLstd450Sqrt)
1258    BUILTIN_UNOP(nir_op_frsq, GLSLstd450InverseSqrt)
1259    BUILTIN_UNOP(nir_op_flog2, GLSLstd450Log2)
1260    BUILTIN_UNOP(nir_op_fexp2, GLSLstd450Exp2)
1261    BUILTIN_UNOP(nir_op_ffract, GLSLstd450Fract)
1262    BUILTIN_UNOP(nir_op_ffloor, GLSLstd450Floor)
1263    BUILTIN_UNOP(nir_op_fceil, GLSLstd450Ceil)
1264    BUILTIN_UNOP(nir_op_ftrunc, GLSLstd450Trunc)
1265    BUILTIN_UNOP(nir_op_fround_even, GLSLstd450RoundEven)
1266    BUILTIN_UNOP(nir_op_fsign, GLSLstd450FSign)
1267    BUILTIN_UNOP(nir_op_isign, GLSLstd450SSign)
1268    BUILTIN_UNOP(nir_op_fsin, GLSLstd450Sin)
1269    BUILTIN_UNOP(nir_op_fcos, GLSLstd450Cos)
1270 #undef BUILTIN_UNOP
1271 
1272    case nir_op_frcp:
1273       assert(nir_op_infos[alu->op].num_inputs == 1);
1274       result = emit_binop(ctx, SpvOpFDiv, dest_type,
1275                           get_fvec_constant(ctx, bit_size, num_components, 1),
1276                           src[0]);
1277       break;
1278 
1279    case nir_op_f2b1:
1280       assert(nir_op_infos[alu->op].num_inputs == 1);
1281       result = emit_binop(ctx, SpvOpFOrdNotEqual, dest_type, src[0],
1282                           get_fvec_constant(ctx,
1283                                             nir_src_bit_size(alu->src[0].src),
1284                                             num_components, 0));
1285       break;
1286    case nir_op_i2b1:
1287       assert(nir_op_infos[alu->op].num_inputs == 1);
1288       result = emit_binop(ctx, SpvOpINotEqual, dest_type, src[0],
1289                           get_ivec_constant(ctx,
1290                                             nir_src_bit_size(alu->src[0].src),
1291                                             num_components, 0));
1292       break;
1293 
1294 
1295 #define BINOP(nir_op, spirv_op) \
1296    case nir_op: \
1297       assert(nir_op_infos[alu->op].num_inputs == 2); \
1298       result = emit_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1299       break;
1300 
1301    BINOP(nir_op_iadd, SpvOpIAdd)
1302    BINOP(nir_op_isub, SpvOpISub)
1303    BINOP(nir_op_imul, SpvOpIMul)
1304    BINOP(nir_op_idiv, SpvOpSDiv)
1305    BINOP(nir_op_udiv, SpvOpUDiv)
1306    BINOP(nir_op_umod, SpvOpUMod)
1307    BINOP(nir_op_fadd, SpvOpFAdd)
1308    BINOP(nir_op_fsub, SpvOpFSub)
1309    BINOP(nir_op_fmul, SpvOpFMul)
1310    BINOP(nir_op_fdiv, SpvOpFDiv)
1311    BINOP(nir_op_fmod, SpvOpFMod)
1312    BINOP(nir_op_ilt, SpvOpSLessThan)
1313    BINOP(nir_op_ige, SpvOpSGreaterThanEqual)
1314    BINOP(nir_op_ult, SpvOpULessThan)
1315    BINOP(nir_op_uge, SpvOpUGreaterThanEqual)
1316    BINOP(nir_op_flt, SpvOpFOrdLessThan)
1317    BINOP(nir_op_fge, SpvOpFOrdGreaterThanEqual)
1318    BINOP(nir_op_feq, SpvOpFOrdEqual)
1319    BINOP(nir_op_fne, SpvOpFUnordNotEqual)
1320    BINOP(nir_op_ishl, SpvOpShiftLeftLogical)
1321    BINOP(nir_op_ishr, SpvOpShiftRightArithmetic)
1322    BINOP(nir_op_ushr, SpvOpShiftRightLogical)
1323    BINOP(nir_op_ixor, SpvOpBitwiseXor)
1324 #undef BINOP
1325 
1326 #define BINOP_LOG(nir_op, spv_op, spv_log_op) \
1327    case nir_op: \
1328       assert(nir_op_infos[alu->op].num_inputs == 2); \
1329       if (nir_src_bit_size(alu->src[0].src) == 1) \
1330          result = emit_binop(ctx, spv_log_op, dest_type, src[0], src[1]); \
1331       else \
1332          result = emit_binop(ctx, spv_op, dest_type, src[0], src[1]); \
1333       break;
1334 
1335    BINOP_LOG(nir_op_iand, SpvOpBitwiseAnd, SpvOpLogicalAnd)
1336    BINOP_LOG(nir_op_ior, SpvOpBitwiseOr, SpvOpLogicalOr)
1337    BINOP_LOG(nir_op_ieq, SpvOpIEqual, SpvOpLogicalEqual)
1338    BINOP_LOG(nir_op_ine, SpvOpINotEqual, SpvOpLogicalNotEqual)
1339 #undef BINOP_LOG
1340 
1341 #define BUILTIN_BINOP(nir_op, spirv_op) \
1342    case nir_op: \
1343       assert(nir_op_infos[alu->op].num_inputs == 2); \
1344       result = emit_builtin_binop(ctx, spirv_op, dest_type, src[0], src[1]); \
1345       break;
1346 
1347    BUILTIN_BINOP(nir_op_fmin, GLSLstd450FMin)
1348    BUILTIN_BINOP(nir_op_fmax, GLSLstd450FMax)
1349    BUILTIN_BINOP(nir_op_imin, GLSLstd450SMin)
1350    BUILTIN_BINOP(nir_op_imax, GLSLstd450SMax)
1351    BUILTIN_BINOP(nir_op_umin, GLSLstd450UMin)
1352    BUILTIN_BINOP(nir_op_umax, GLSLstd450UMax)
1353 #undef BUILTIN_BINOP
1354 
1355    case nir_op_fdot2:
1356    case nir_op_fdot3:
1357    case nir_op_fdot4:
1358       assert(nir_op_infos[alu->op].num_inputs == 2);
1359       result = emit_binop(ctx, SpvOpDot, dest_type, src[0], src[1]);
1360       break;
1361 
1362    case nir_op_fdph:
1363       unreachable("should already be lowered away");
1364 
1365    case nir_op_seq:
1366    case nir_op_sne:
1367    case nir_op_slt:
1368    case nir_op_sge: {
1369       assert(nir_op_infos[alu->op].num_inputs == 2);
1370       int num_components = nir_dest_num_components(alu->dest.dest);
1371       SpvId bool_type = get_bvec_type(ctx, num_components);
1372 
1373       SpvId zero = emit_float_const(ctx, bit_size, 0.0f);
1374       SpvId one = emit_float_const(ctx, bit_size, 1.0f);
1375       if (num_components > 1) {
1376          SpvId zero_comps[num_components], one_comps[num_components];
1377          for (int i = 0; i < num_components; i++) {
1378             zero_comps[i] = zero;
1379             one_comps[i] = one;
1380          }
1381 
1382          zero = spirv_builder_const_composite(&ctx->builder, dest_type,
1383                                               zero_comps, num_components);
1384          one = spirv_builder_const_composite(&ctx->builder, dest_type,
1385                                              one_comps, num_components);
1386       }
1387 
1388       SpvOp op;
1389       switch (alu->op) {
1390       case nir_op_seq: op = SpvOpFOrdEqual; break;
1391       case nir_op_sne: op = SpvOpFOrdNotEqual; break;
1392       case nir_op_slt: op = SpvOpFOrdLessThan; break;
1393       case nir_op_sge: op = SpvOpFOrdGreaterThanEqual; break;
1394       default: unreachable("unexpected op");
1395       }
1396 
1397       result = emit_binop(ctx, op, bool_type, src[0], src[1]);
1398       result = emit_select(ctx, dest_type, result, one, zero);
1399       }
1400       break;
1401 
1402    case nir_op_flrp:
1403       assert(nir_op_infos[alu->op].num_inputs == 3);
1404       result = emit_builtin_triop(ctx, GLSLstd450FMix, dest_type,
1405                                   src[0], src[1], src[2]);
1406       break;
1407 
1408    case nir_op_fcsel:
1409       result = emit_binop(ctx, SpvOpFOrdGreaterThan,
1410                           get_bvec_type(ctx, num_components),
1411                           src[0],
1412                           get_fvec_constant(ctx,
1413                                             nir_src_bit_size(alu->src[0].src),
1414                                             num_components, 0));
1415       result = emit_select(ctx, dest_type, result, src[1], src[2]);
1416       break;
1417 
1418    case nir_op_bcsel:
1419       assert(nir_op_infos[alu->op].num_inputs == 3);
1420       result = emit_select(ctx, dest_type, src[0], src[1], src[2]);
1421       break;
1422 
1423    case nir_op_bany_fnequal2:
1424    case nir_op_bany_fnequal3:
1425    case nir_op_bany_fnequal4: {
1426       assert(nir_op_infos[alu->op].num_inputs == 2);
1427       assert(alu_instr_src_components(alu, 0) ==
1428              alu_instr_src_components(alu, 1));
1429       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1430       /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1431       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpFOrdNotEqual;
1432       result = emit_binop(ctx, op,
1433                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1434                           src[0], src[1]);
1435       result = emit_unop(ctx, SpvOpAny, dest_type, result);
1436       break;
1437    }
1438 
1439    case nir_op_ball_fequal2:
1440    case nir_op_ball_fequal3:
1441    case nir_op_ball_fequal4: {
1442       assert(nir_op_infos[alu->op].num_inputs == 2);
1443       assert(alu_instr_src_components(alu, 0) ==
1444              alu_instr_src_components(alu, 1));
1445       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1446       /* The type of Operand 1 and Operand 2 must be a scalar or vector of floating-point type. */
1447       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpFOrdEqual;
1448       result = emit_binop(ctx, op,
1449                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1450                           src[0], src[1]);
1451       result = emit_unop(ctx, SpvOpAll, dest_type, result);
1452       break;
1453    }
1454 
1455    case nir_op_bany_inequal2:
1456    case nir_op_bany_inequal3:
1457    case nir_op_bany_inequal4: {
1458       assert(nir_op_infos[alu->op].num_inputs == 2);
1459       assert(alu_instr_src_components(alu, 0) ==
1460              alu_instr_src_components(alu, 1));
1461       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1462       /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1463       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalNotEqual : SpvOpINotEqual;
1464       result = emit_binop(ctx, op,
1465                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1466                           src[0], src[1]);
1467       result = emit_unop(ctx, SpvOpAny, dest_type, result);
1468       break;
1469    }
1470 
1471    case nir_op_ball_iequal2:
1472    case nir_op_ball_iequal3:
1473    case nir_op_ball_iequal4: {
1474       assert(nir_op_infos[alu->op].num_inputs == 2);
1475       assert(alu_instr_src_components(alu, 0) ==
1476              alu_instr_src_components(alu, 1));
1477       assert(in_bit_sizes[0] == in_bit_sizes[1]);
1478       /* The type of Operand 1 and Operand 2 must be a scalar or vector of integer type. */
1479       SpvOp op = in_bit_sizes[0] == 1 ? SpvOpLogicalEqual : SpvOpIEqual;
1480       result = emit_binop(ctx, op,
1481                           get_bvec_type(ctx, alu_instr_src_components(alu, 0)),
1482                           src[0], src[1]);
1483       result = emit_unop(ctx, SpvOpAll, dest_type, result);
1484       break;
1485    }
1486 
1487    case nir_op_vec2:
1488    case nir_op_vec3:
1489    case nir_op_vec4: {
1490       int num_inputs = nir_op_infos[alu->op].num_inputs;
1491       assert(2 <= num_inputs && num_inputs <= 4);
1492       result = spirv_builder_emit_composite_construct(&ctx->builder, dest_type,
1493                                                       src, num_inputs);
1494    }
1495    break;
1496 
1497    default:
1498       fprintf(stderr, "emit_alu: not implemented (%s)\n",
1499               nir_op_infos[alu->op].name);
1500 
1501       unreachable("unsupported opcode");
1502       return;
1503    }
1504 
1505    store_alu_result(ctx, alu, result);
1506 }
1507 
1508 static void
emit_load_const(struct ntv_context * ctx,nir_load_const_instr * load_const)1509 emit_load_const(struct ntv_context *ctx, nir_load_const_instr *load_const)
1510 {
1511    unsigned bit_size = load_const->def.bit_size;
1512    unsigned num_components = load_const->def.num_components;
1513 
1514    SpvId constant;
1515    if (num_components > 1) {
1516       SpvId components[num_components];
1517       SpvId type = get_vec_from_bit_size(ctx, bit_size, num_components);
1518       if (bit_size == 1) {
1519          for (int i = 0; i < num_components; i++)
1520             components[i] = spirv_builder_const_bool(&ctx->builder,
1521                                                      load_const->value[i].b);
1522 
1523       } else {
1524          for (int i = 0; i < num_components; i++)
1525             components[i] = emit_uint_const(ctx, bit_size,
1526                                             load_const->value[i].u32);
1527 
1528       }
1529       constant = spirv_builder_const_composite(&ctx->builder, type,
1530                                                components, num_components);
1531    } else {
1532       assert(num_components == 1);
1533       if (bit_size == 1)
1534          constant = spirv_builder_const_bool(&ctx->builder,
1535                                              load_const->value[0].b);
1536       else
1537          constant = emit_uint_const(ctx, bit_size, load_const->value[0].u32);
1538    }
1539 
1540    store_ssa_def(ctx, &load_const->def, constant);
1541 }
1542 
1543 static void
emit_load_ubo(struct ntv_context * ctx,nir_intrinsic_instr * intr)1544 emit_load_ubo(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1545 {
1546    nir_const_value *const_block_index = nir_src_as_const_value(intr->src[0]);
1547    assert(const_block_index); // no dynamic indexing for now
1548    assert(const_block_index->u32 == 0); // we only support the default UBO for now
1549 
1550    nir_const_value *const_offset = nir_src_as_const_value(intr->src[1]);
1551    if (const_offset) {
1552       SpvId uvec4_type = get_uvec_type(ctx, 32, 4);
1553       SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1554                                                       SpvStorageClassUniform,
1555                                                       uvec4_type);
1556 
1557       unsigned idx = const_offset->u32;
1558       SpvId member = emit_uint_const(ctx, 32, 0);
1559       SpvId offset = emit_uint_const(ctx, 32, idx);
1560       SpvId offsets[] = { member, offset };
1561       SpvId ptr = spirv_builder_emit_access_chain(&ctx->builder, pointer_type,
1562                                                   ctx->ubos[0], offsets,
1563                                                   ARRAY_SIZE(offsets));
1564       SpvId result = spirv_builder_emit_load(&ctx->builder, uvec4_type, ptr);
1565 
1566       SpvId type = get_dest_uvec_type(ctx, &intr->dest);
1567       unsigned num_components = nir_dest_num_components(intr->dest);
1568       if (num_components == 1) {
1569          uint32_t components[] = { 0 };
1570          result = spirv_builder_emit_composite_extract(&ctx->builder,
1571                                                        type,
1572                                                        result, components,
1573                                                        1);
1574       } else if (num_components < 4) {
1575          SpvId constituents[num_components];
1576          SpvId uint_type = spirv_builder_type_uint(&ctx->builder, 32);
1577          for (uint32_t i = 0; i < num_components; ++i)
1578             constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1579                                                                    uint_type,
1580                                                                    result, &i,
1581                                                                    1);
1582 
1583          result = spirv_builder_emit_composite_construct(&ctx->builder,
1584                                                          type,
1585                                                          constituents,
1586                                                          num_components);
1587       }
1588 
1589       if (nir_dest_bit_size(intr->dest) == 1)
1590          result = uvec_to_bvec(ctx, result, num_components);
1591 
1592       store_dest(ctx, &intr->dest, result, nir_type_uint);
1593    } else
1594       unreachable("uniform-addressing not yet supported");
1595 }
1596 
1597 static void
emit_discard(struct ntv_context * ctx,nir_intrinsic_instr * intr)1598 emit_discard(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1599 {
1600    assert(ctx->block_started);
1601    spirv_builder_emit_kill(&ctx->builder);
1602    /* discard is weird in NIR, so let's just create an unreachable block after
1603       it and hope that the vulkan driver will DCE any instructinos in it. */
1604    spirv_builder_label(&ctx->builder, spirv_builder_new_id(&ctx->builder));
1605 }
1606 
1607 static void
emit_load_deref(struct ntv_context * ctx,nir_intrinsic_instr * intr)1608 emit_load_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1609 {
1610    SpvId ptr = get_src(ctx, intr->src);
1611 
1612    SpvId result = spirv_builder_emit_load(&ctx->builder,
1613                                           get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type),
1614                                           ptr);
1615    unsigned num_components = nir_dest_num_components(intr->dest);
1616    unsigned bit_size = nir_dest_bit_size(intr->dest);
1617    result = bitcast_to_uvec(ctx, result, bit_size, num_components);
1618    store_dest(ctx, &intr->dest, result, nir_type_uint);
1619 }
1620 
1621 static void
emit_store_deref(struct ntv_context * ctx,nir_intrinsic_instr * intr)1622 emit_store_deref(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1623 {
1624    SpvId ptr = get_src(ctx, &intr->src[0]);
1625    SpvId src = get_src(ctx, &intr->src[1]);
1626 
1627    SpvId type = get_glsl_type(ctx, nir_src_as_deref(intr->src[0])->type);
1628    SpvId result = emit_bitcast(ctx, type, src);
1629    spirv_builder_emit_store(&ctx->builder, ptr, result);
1630 }
1631 
1632 static SpvId
create_builtin_var(struct ntv_context * ctx,SpvId var_type,SpvStorageClass storage_class,const char * name,SpvBuiltIn builtin)1633 create_builtin_var(struct ntv_context *ctx, SpvId var_type,
1634                    SpvStorageClass storage_class,
1635                    const char *name, SpvBuiltIn builtin)
1636 {
1637    SpvId pointer_type = spirv_builder_type_pointer(&ctx->builder,
1638                                                    storage_class,
1639                                                    var_type);
1640    SpvId var = spirv_builder_emit_var(&ctx->builder, pointer_type,
1641                                       storage_class);
1642    spirv_builder_emit_name(&ctx->builder, var, name);
1643    spirv_builder_emit_builtin(&ctx->builder, var, builtin);
1644 
1645    assert(ctx->num_entry_ifaces < ARRAY_SIZE(ctx->entry_ifaces));
1646    ctx->entry_ifaces[ctx->num_entry_ifaces++] = var;
1647    return var;
1648 }
1649 
1650 static void
emit_load_front_face(struct ntv_context * ctx,nir_intrinsic_instr * intr)1651 emit_load_front_face(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1652 {
1653    SpvId var_type = spirv_builder_type_bool(&ctx->builder);
1654    if (!ctx->front_face_var)
1655       ctx->front_face_var = create_builtin_var(ctx, var_type,
1656                                                SpvStorageClassInput,
1657                                                "gl_FrontFacing",
1658                                                SpvBuiltInFrontFacing);
1659 
1660    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1661                                           ctx->front_face_var);
1662    assert(1 == nir_dest_num_components(intr->dest));
1663    store_dest(ctx, &intr->dest, result, nir_type_bool);
1664 }
1665 
1666 static void
emit_load_instance_id(struct ntv_context * ctx,nir_intrinsic_instr * intr)1667 emit_load_instance_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1668 {
1669    SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1670    if (!ctx->instance_id_var)
1671       ctx->instance_id_var = create_builtin_var(ctx, var_type,
1672                                                SpvStorageClassInput,
1673                                                "gl_InstanceId",
1674                                                SpvBuiltInInstanceIndex);
1675 
1676    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1677                                           ctx->instance_id_var);
1678    assert(1 == nir_dest_num_components(intr->dest));
1679    store_dest(ctx, &intr->dest, result, nir_type_uint);
1680 }
1681 
1682 static void
emit_load_vertex_id(struct ntv_context * ctx,nir_intrinsic_instr * intr)1683 emit_load_vertex_id(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1684 {
1685    SpvId var_type = spirv_builder_type_uint(&ctx->builder, 32);
1686    if (!ctx->vertex_id_var)
1687       ctx->vertex_id_var = create_builtin_var(ctx, var_type,
1688                                                SpvStorageClassInput,
1689                                                "gl_VertexID",
1690                                                SpvBuiltInVertexIndex);
1691 
1692    SpvId result = spirv_builder_emit_load(&ctx->builder, var_type,
1693                                           ctx->vertex_id_var);
1694    assert(1 == nir_dest_num_components(intr->dest));
1695    store_dest(ctx, &intr->dest, result, nir_type_uint);
1696 }
1697 
1698 static void
emit_intrinsic(struct ntv_context * ctx,nir_intrinsic_instr * intr)1699 emit_intrinsic(struct ntv_context *ctx, nir_intrinsic_instr *intr)
1700 {
1701    switch (intr->intrinsic) {
1702    case nir_intrinsic_load_ubo:
1703       emit_load_ubo(ctx, intr);
1704       break;
1705 
1706    case nir_intrinsic_discard:
1707       emit_discard(ctx, intr);
1708       break;
1709 
1710    case nir_intrinsic_load_deref:
1711       emit_load_deref(ctx, intr);
1712       break;
1713 
1714    case nir_intrinsic_store_deref:
1715       emit_store_deref(ctx, intr);
1716       break;
1717 
1718    case nir_intrinsic_load_front_face:
1719       emit_load_front_face(ctx, intr);
1720       break;
1721 
1722    case nir_intrinsic_load_instance_id:
1723       emit_load_instance_id(ctx, intr);
1724       break;
1725 
1726    case nir_intrinsic_load_vertex_id:
1727       emit_load_vertex_id(ctx, intr);
1728       break;
1729 
1730    default:
1731       fprintf(stderr, "emit_intrinsic: not implemented (%s)\n",
1732               nir_intrinsic_infos[intr->intrinsic].name);
1733       unreachable("unsupported intrinsic");
1734    }
1735 }
1736 
1737 static void
emit_undef(struct ntv_context * ctx,nir_ssa_undef_instr * undef)1738 emit_undef(struct ntv_context *ctx, nir_ssa_undef_instr *undef)
1739 {
1740    SpvId type = get_uvec_type(ctx, undef->def.bit_size,
1741                               undef->def.num_components);
1742 
1743    store_ssa_def(ctx, &undef->def,
1744                  spirv_builder_emit_undef(&ctx->builder, type));
1745 }
1746 
1747 static SpvId
get_src_float(struct ntv_context * ctx,nir_src * src)1748 get_src_float(struct ntv_context *ctx, nir_src *src)
1749 {
1750    SpvId def = get_src(ctx, src);
1751    unsigned num_components = nir_src_num_components(*src);
1752    unsigned bit_size = nir_src_bit_size(*src);
1753    return bitcast_to_fvec(ctx, def, bit_size, num_components);
1754 }
1755 
1756 static SpvId
get_src_int(struct ntv_context * ctx,nir_src * src)1757 get_src_int(struct ntv_context *ctx, nir_src *src)
1758 {
1759    SpvId def = get_src(ctx, src);
1760    unsigned num_components = nir_src_num_components(*src);
1761    unsigned bit_size = nir_src_bit_size(*src);
1762    return bitcast_to_ivec(ctx, def, bit_size, num_components);
1763 }
1764 
1765 static inline bool
tex_instr_is_lod_allowed(nir_tex_instr * tex)1766 tex_instr_is_lod_allowed(nir_tex_instr *tex)
1767 {
1768    /* This can only be used with an OpTypeImage that has a Dim operand of 1D, 2D, 3D, or Cube
1769     * - SPIR-V: 3.14. Image Operands
1770     */
1771 
1772    return (tex->sampler_dim == GLSL_SAMPLER_DIM_1D ||
1773            tex->sampler_dim == GLSL_SAMPLER_DIM_2D ||
1774            tex->sampler_dim == GLSL_SAMPLER_DIM_3D ||
1775            tex->sampler_dim == GLSL_SAMPLER_DIM_CUBE);
1776 }
1777 
1778 static SpvId
pad_coord_vector(struct ntv_context * ctx,SpvId orig,unsigned old_size,unsigned new_size)1779 pad_coord_vector(struct ntv_context *ctx, SpvId orig, unsigned old_size, unsigned new_size)
1780 {
1781     SpvId int_type = spirv_builder_type_int(&ctx->builder, 32);
1782     SpvId type = get_ivec_type(ctx, 32, new_size);
1783     SpvId constituents[NIR_MAX_VEC_COMPONENTS] = {0};
1784     SpvId zero = emit_int_const(ctx, 32, 0);
1785     assert(new_size < NIR_MAX_VEC_COMPONENTS);
1786 
1787     if (old_size == 1)
1788        constituents[0] = orig;
1789     else {
1790        for (unsigned i = 0; i < old_size; i++)
1791           constituents[i] = spirv_builder_emit_vector_extract(&ctx->builder, int_type, orig, i);
1792     }
1793 
1794     for (unsigned i = old_size; i < new_size; i++)
1795        constituents[i] = zero;
1796 
1797     return spirv_builder_emit_composite_construct(&ctx->builder, type,
1798                                                   constituents, new_size);
1799 }
1800 
1801 static void
emit_tex(struct ntv_context * ctx,nir_tex_instr * tex)1802 emit_tex(struct ntv_context *ctx, nir_tex_instr *tex)
1803 {
1804    assert(tex->op == nir_texop_tex ||
1805           tex->op == nir_texop_txb ||
1806           tex->op == nir_texop_txl ||
1807           tex->op == nir_texop_txd ||
1808           tex->op == nir_texop_txf ||
1809           tex->op == nir_texop_txf_ms ||
1810           tex->op == nir_texop_txs);
1811    assert(tex->texture_index == tex->sampler_index);
1812 
1813    SpvId coord = 0, proj = 0, bias = 0, lod = 0, dref = 0, dx = 0, dy = 0,
1814          offset = 0, sample = 0;
1815    unsigned coord_components = 0, coord_bitsize = 0, offset_components = 0;
1816    for (unsigned i = 0; i < tex->num_srcs; i++) {
1817       switch (tex->src[i].src_type) {
1818       case nir_tex_src_coord:
1819          if (tex->op == nir_texop_txf ||
1820              tex->op == nir_texop_txf_ms)
1821             coord = get_src_int(ctx, &tex->src[i].src);
1822          else
1823             coord = get_src_float(ctx, &tex->src[i].src);
1824          coord_components = nir_src_num_components(tex->src[i].src);
1825          coord_bitsize = nir_src_bit_size(tex->src[i].src);
1826          break;
1827 
1828       case nir_tex_src_projector:
1829          assert(nir_src_num_components(tex->src[i].src) == 1);
1830          proj = get_src_float(ctx, &tex->src[i].src);
1831          assert(proj != 0);
1832          break;
1833 
1834       case nir_tex_src_offset:
1835          offset = get_src_int(ctx, &tex->src[i].src);
1836          offset_components = nir_src_num_components(tex->src[i].src);
1837          break;
1838 
1839       case nir_tex_src_bias:
1840          assert(tex->op == nir_texop_txb);
1841          bias = get_src_float(ctx, &tex->src[i].src);
1842          assert(bias != 0);
1843          break;
1844 
1845       case nir_tex_src_lod:
1846          assert(nir_src_num_components(tex->src[i].src) == 1);
1847          if (tex->op == nir_texop_txf ||
1848              tex->op == nir_texop_txf_ms ||
1849              tex->op == nir_texop_txs)
1850             lod = get_src_int(ctx, &tex->src[i].src);
1851          else
1852             lod = get_src_float(ctx, &tex->src[i].src);
1853          assert(lod != 0);
1854          break;
1855 
1856       case nir_tex_src_ms_index:
1857          assert(nir_src_num_components(tex->src[i].src) == 1);
1858          sample = get_src_int(ctx, &tex->src[i].src);
1859          break;
1860 
1861       case nir_tex_src_comparator:
1862          assert(nir_src_num_components(tex->src[i].src) == 1);
1863          dref = get_src_float(ctx, &tex->src[i].src);
1864          assert(dref != 0);
1865          break;
1866 
1867       case nir_tex_src_ddx:
1868          dx = get_src_float(ctx, &tex->src[i].src);
1869          assert(dx != 0);
1870          break;
1871 
1872       case nir_tex_src_ddy:
1873          dy = get_src_float(ctx, &tex->src[i].src);
1874          assert(dy != 0);
1875          break;
1876 
1877       default:
1878          fprintf(stderr, "texture source: %d\n", tex->src[i].src_type);
1879          unreachable("unknown texture source");
1880       }
1881    }
1882 
1883    if (lod == 0 && ctx->stage != MESA_SHADER_FRAGMENT) {
1884       lod = emit_float_const(ctx, 32, 0.0f);
1885       assert(lod != 0);
1886    }
1887 
1888    SpvId image_type = ctx->image_types[tex->texture_index];
1889    SpvId sampled_type = spirv_builder_type_sampled_image(&ctx->builder,
1890                                                          image_type);
1891 
1892    assert(ctx->samplers_used & (1u << tex->texture_index));
1893    SpvId load = spirv_builder_emit_load(&ctx->builder, sampled_type,
1894                                         ctx->samplers[tex->texture_index]);
1895 
1896    SpvId dest_type = get_dest_type(ctx, &tex->dest, tex->dest_type);
1897 
1898    if (!tex_instr_is_lod_allowed(tex))
1899       lod = 0;
1900    if (tex->op == nir_texop_txs) {
1901       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1902       SpvId result = spirv_builder_emit_image_query_size(&ctx->builder,
1903                                                          dest_type, image,
1904                                                          lod);
1905       store_dest(ctx, &tex->dest, result, tex->dest_type);
1906       return;
1907    }
1908 
1909    if (proj && coord_components > 0) {
1910       SpvId constituents[coord_components + 1];
1911       if (coord_components == 1)
1912          constituents[0] = coord;
1913       else {
1914          assert(coord_components > 1);
1915          SpvId float_type = spirv_builder_type_float(&ctx->builder, 32);
1916          for (uint32_t i = 0; i < coord_components; ++i)
1917             constituents[i] = spirv_builder_emit_composite_extract(&ctx->builder,
1918                                                  float_type,
1919                                                  coord,
1920                                                  &i, 1);
1921       }
1922 
1923       constituents[coord_components++] = proj;
1924 
1925       SpvId vec_type = get_fvec_type(ctx, 32, coord_components);
1926       coord = spirv_builder_emit_composite_construct(&ctx->builder,
1927                                                             vec_type,
1928                                                             constituents,
1929                                                             coord_components);
1930    }
1931 
1932    SpvId actual_dest_type = dest_type;
1933    if (dref)
1934       actual_dest_type = spirv_builder_type_float(&ctx->builder, 32);
1935 
1936    SpvId result;
1937    if (tex->op == nir_texop_txf ||
1938        tex->op == nir_texop_txf_ms) {
1939       SpvId image = spirv_builder_emit_image(&ctx->builder, image_type, load);
1940       if (offset) {
1941          /* SPIRV requires matched length vectors for OpIAdd, so if a shader
1942           * uses vecs of differing sizes we need to make a new vec padded with zeroes
1943           * to mimic how GLSL does this implicitly
1944           */
1945          if (offset_components > coord_components)
1946             coord = pad_coord_vector(ctx, coord, coord_components, offset_components);
1947          else if (coord_components > offset_components)
1948             offset = pad_coord_vector(ctx, offset, offset_components, coord_components);
1949          coord = emit_binop(ctx, SpvOpIAdd,
1950                             get_ivec_type(ctx, coord_bitsize, coord_components),
1951                             coord, offset);
1952       }
1953       result = spirv_builder_emit_image_fetch(&ctx->builder, dest_type,
1954                                               image, coord, lod, sample);
1955    } else {
1956       result = spirv_builder_emit_image_sample(&ctx->builder,
1957                                                actual_dest_type, load,
1958                                                coord,
1959                                                proj != 0,
1960                                                lod, bias, dref, dx, dy,
1961                                                offset);
1962    }
1963 
1964    spirv_builder_emit_decoration(&ctx->builder, result,
1965                                  SpvDecorationRelaxedPrecision);
1966 
1967    if (dref && nir_dest_num_components(tex->dest) > 1) {
1968       SpvId components[4] = { result, result, result, result };
1969       result = spirv_builder_emit_composite_construct(&ctx->builder,
1970                                                       dest_type,
1971                                                       components,
1972                                                       4);
1973    }
1974 
1975    store_dest(ctx, &tex->dest, result, tex->dest_type);
1976 }
1977 
1978 static void
start_block(struct ntv_context * ctx,SpvId label)1979 start_block(struct ntv_context *ctx, SpvId label)
1980 {
1981    /* terminate previous block if needed */
1982    if (ctx->block_started)
1983       spirv_builder_emit_branch(&ctx->builder, label);
1984 
1985    /* start new block */
1986    spirv_builder_label(&ctx->builder, label);
1987    ctx->block_started = true;
1988 }
1989 
1990 static void
branch(struct ntv_context * ctx,SpvId label)1991 branch(struct ntv_context *ctx, SpvId label)
1992 {
1993    assert(ctx->block_started);
1994    spirv_builder_emit_branch(&ctx->builder, label);
1995    ctx->block_started = false;
1996 }
1997 
1998 static void
branch_conditional(struct ntv_context * ctx,SpvId condition,SpvId then_id,SpvId else_id)1999 branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
2000                    SpvId else_id)
2001 {
2002    assert(ctx->block_started);
2003    spirv_builder_emit_branch_conditional(&ctx->builder, condition,
2004                                          then_id, else_id);
2005    ctx->block_started = false;
2006 }
2007 
2008 static void
emit_jump(struct ntv_context * ctx,nir_jump_instr * jump)2009 emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
2010 {
2011    switch (jump->type) {
2012    case nir_jump_break:
2013       assert(ctx->loop_break);
2014       branch(ctx, ctx->loop_break);
2015       break;
2016 
2017    case nir_jump_continue:
2018       assert(ctx->loop_cont);
2019       branch(ctx, ctx->loop_cont);
2020       break;
2021 
2022    default:
2023       unreachable("Unsupported jump type\n");
2024    }
2025 }
2026 
2027 static void
emit_deref_var(struct ntv_context * ctx,nir_deref_instr * deref)2028 emit_deref_var(struct ntv_context *ctx, nir_deref_instr *deref)
2029 {
2030    assert(deref->deref_type == nir_deref_type_var);
2031 
2032    struct hash_entry *he = _mesa_hash_table_search(ctx->vars, deref->var);
2033    assert(he);
2034    SpvId result = (SpvId)(intptr_t)he->data;
2035    store_dest_raw(ctx, &deref->dest, result);
2036 }
2037 
2038 static void
emit_deref_array(struct ntv_context * ctx,nir_deref_instr * deref)2039 emit_deref_array(struct ntv_context *ctx, nir_deref_instr *deref)
2040 {
2041    assert(deref->deref_type == nir_deref_type_array);
2042    nir_variable *var = nir_deref_instr_get_variable(deref);
2043 
2044    SpvStorageClass storage_class;
2045    switch (var->data.mode) {
2046    case nir_var_shader_in:
2047       storage_class = SpvStorageClassInput;
2048       break;
2049 
2050    case nir_var_shader_out:
2051       storage_class = SpvStorageClassOutput;
2052       break;
2053 
2054    default:
2055       unreachable("Unsupported nir_variable_mode\n");
2056    }
2057 
2058    SpvId index = get_src(ctx, &deref->arr.index);
2059 
2060    SpvId ptr_type = spirv_builder_type_pointer(&ctx->builder,
2061                                                storage_class,
2062                                                get_glsl_type(ctx, deref->type));
2063 
2064    SpvId result = spirv_builder_emit_access_chain(&ctx->builder,
2065                                                   ptr_type,
2066                                                   get_src(ctx, &deref->parent),
2067                                                   &index, 1);
2068    /* uint is a bit of a lie here, it's really just an opaque type */
2069    store_dest(ctx, &deref->dest, result, nir_type_uint);
2070 }
2071 
2072 static void
emit_deref(struct ntv_context * ctx,nir_deref_instr * deref)2073 emit_deref(struct ntv_context *ctx, nir_deref_instr *deref)
2074 {
2075    switch (deref->deref_type) {
2076    case nir_deref_type_var:
2077       emit_deref_var(ctx, deref);
2078       break;
2079 
2080    case nir_deref_type_array:
2081       emit_deref_array(ctx, deref);
2082       break;
2083 
2084    default:
2085       unreachable("unexpected deref_type");
2086    }
2087 }
2088 
2089 static void
emit_block(struct ntv_context * ctx,struct nir_block * block)2090 emit_block(struct ntv_context *ctx, struct nir_block *block)
2091 {
2092    start_block(ctx, block_label(ctx, block));
2093    nir_foreach_instr(instr, block) {
2094       switch (instr->type) {
2095       case nir_instr_type_alu:
2096          emit_alu(ctx, nir_instr_as_alu(instr));
2097          break;
2098       case nir_instr_type_intrinsic:
2099          emit_intrinsic(ctx, nir_instr_as_intrinsic(instr));
2100          break;
2101       case nir_instr_type_load_const:
2102          emit_load_const(ctx, nir_instr_as_load_const(instr));
2103          break;
2104       case nir_instr_type_ssa_undef:
2105          emit_undef(ctx, nir_instr_as_ssa_undef(instr));
2106          break;
2107       case nir_instr_type_tex:
2108          emit_tex(ctx, nir_instr_as_tex(instr));
2109          break;
2110       case nir_instr_type_phi:
2111          unreachable("nir_instr_type_phi not supported");
2112          break;
2113       case nir_instr_type_jump:
2114          emit_jump(ctx, nir_instr_as_jump(instr));
2115          break;
2116       case nir_instr_type_call:
2117          unreachable("nir_instr_type_call not supported");
2118          break;
2119       case nir_instr_type_parallel_copy:
2120          unreachable("nir_instr_type_parallel_copy not supported");
2121          break;
2122       case nir_instr_type_deref:
2123          emit_deref(ctx, nir_instr_as_deref(instr));
2124          break;
2125       }
2126    }
2127 }
2128 
/* Forward declaration: emit_if() and emit_loop() below recurse into nested
 * control-flow lists before emit_cf_list() itself is defined. */
static void
emit_cf_list(struct ntv_context *ctx, struct exec_list *list);
2131 
2132 static SpvId
get_src_bool(struct ntv_context * ctx,nir_src * src)2133 get_src_bool(struct ntv_context *ctx, nir_src *src)
2134 {
2135    assert(nir_src_bit_size(*src) == 1);
2136    return get_src(ctx, src);
2137 }
2138 
/* Lower a NIR if-statement to SPIR-V structured control flow.
 *
 * A dedicated header block is emitted holding OpSelectionMerge (merge block
 * endif_id) and a conditional branch. The true edge goes to the first
 * then-block. The false edge goes to the first else-block, or straight to
 * the merge block when the else-list is empty. The then- and else-lists are
 * then emitted in order. Finally the merge block is opened, and it stays open
 * for whatever code follows the if.
 *
 * NOTE: fresh IDs are allocated in a fixed order (header, endif), and blocks
 * are emitted in program order. Reordering statements here changes the
 * generated SPIR-V.
 */
static void
emit_if(struct ntv_context *ctx, nir_if *if_stmt)
{
   SpvId condition = get_src_bool(ctx, &if_stmt->condition);

   SpvId header_id = spirv_builder_new_id(&ctx->builder);
   SpvId then_id = block_label(ctx, nir_if_first_then_block(if_stmt));
   SpvId endif_id = spirv_builder_new_id(&ctx->builder);
   /* with no else-list, the false edge goes directly to the merge block */
   SpvId else_id = endif_id;

   bool has_else = !exec_list_is_empty(&if_stmt->else_list);
   if (has_else) {
      assert(nir_if_first_else_block(if_stmt)->index < ctx->num_blocks);
      else_id = block_label(ctx, nir_if_first_else_block(if_stmt));
   }

   /* create a header-block */
   start_block(ctx, header_id);
   spirv_builder_emit_selection_merge(&ctx->builder, endif_id,
                                      SpvSelectionControlMaskNone);
   branch_conditional(ctx, condition, then_id, else_id);

   emit_cf_list(ctx, &if_stmt->then_list);

   if (has_else) {
      /* close the then-path with a jump to the merge block, unless it was
       * already terminated (e.g. by a break/continue in emit_jump()) */
      if (ctx->block_started)
         branch(ctx, endif_id);

      emit_cf_list(ctx, &if_stmt->else_list);
   }

   /* start_block() also emits the fall-through branch into the merge block
    * from whichever path is still open */
   start_block(ctx, endif_id);
}
2172 
/* Lower a NIR loop to SPIR-V structured control flow.
 *
 * Emitted layout:
 *   header:   spirv_builder_loop_merge(break_id, cont_id), then a branch to
 *             the first body block
 *   body:     the loop's CF list
 *   cont_id:  continue target, branches back to the header
 *   break_id: merge block, left open for the code after the loop
 *
 * While the body is emitted, ctx->loop_break / ctx->loop_cont point at this
 * loop's merge and continue blocks (emit_jump() reads them). They are
 * restored afterwards, so nested loops resolve to their own targets.
 *
 * NOTE: ID allocation and block emission follow statement order. Reordering
 * statements here changes the generated SPIR-V.
 */
static void
emit_loop(struct ntv_context *ctx, nir_loop *loop)
{
   SpvId header_id = spirv_builder_new_id(&ctx->builder);
   SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
   SpvId break_id = spirv_builder_new_id(&ctx->builder);
   SpvId cont_id = spirv_builder_new_id(&ctx->builder);

   /* create a header-block */
   start_block(ctx, header_id);
   spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
   branch(ctx, begin_id);

   /* save the enclosing loop's targets (if any) for restoration below */
   SpvId save_break = ctx->loop_break;
   SpvId save_cont = ctx->loop_cont;
   ctx->loop_break = break_id;
   ctx->loop_cont = cont_id;

   emit_cf_list(ctx, &loop->body);

   ctx->loop_break = save_break;
   ctx->loop_cont = save_cont;

   /* loop->body may have already ended our block */
   if (ctx->block_started)
      branch(ctx, cont_id);
   /* continue block: jump back to the header for the next iteration */
   start_block(ctx, cont_id);
   branch(ctx, header_id);

   /* merge block; remains open for code following the loop */
   start_block(ctx, break_id);
}
2204 
2205 static void
emit_cf_list(struct ntv_context * ctx,struct exec_list * list)2206 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
2207 {
2208    foreach_list_typed(nir_cf_node, node, node, list) {
2209       switch (node->type) {
2210       case nir_cf_node_block:
2211          emit_block(ctx, nir_cf_node_as_block(node));
2212          break;
2213 
2214       case nir_cf_node_if:
2215          emit_if(ctx, nir_cf_node_as_if(node));
2216          break;
2217 
2218       case nir_cf_node_loop:
2219          emit_loop(ctx, nir_cf_node_as_loop(node));
2220          break;
2221 
2222       case nir_cf_node_function:
2223          unreachable("nir_cf_node_function not supported");
2224          break;
2225       }
2226    }
2227 }
2228 
2229 struct spirv_shader *
nir_to_spirv(struct nir_shader * s,const struct pipe_stream_output_info * so_info,struct pipe_stream_output_info * local_so_info)2230 nir_to_spirv(struct nir_shader *s, const struct pipe_stream_output_info *so_info, struct pipe_stream_output_info *local_so_info)
2231 {
2232    struct spirv_shader *ret = NULL;
2233 
2234    struct ntv_context ctx = {};
2235    ctx.mem_ctx = ralloc_context(NULL);
2236    ctx.builder.mem_ctx = ctx.mem_ctx;
2237 
2238    switch (s->info.stage) {
2239    case MESA_SHADER_VERTEX:
2240    case MESA_SHADER_FRAGMENT:
2241    case MESA_SHADER_COMPUTE:
2242       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityShader);
2243       break;
2244 
2245    case MESA_SHADER_TESS_CTRL:
2246    case MESA_SHADER_TESS_EVAL:
2247       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTessellation);
2248       break;
2249 
2250    case MESA_SHADER_GEOMETRY:
2251       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityGeometry);
2252       break;
2253 
2254    default:
2255       unreachable("invalid stage");
2256    }
2257 
2258    // TODO: only enable when needed
2259    if (s->info.stage == MESA_SHADER_FRAGMENT) {
2260       spirv_builder_emit_cap(&ctx.builder, SpvCapabilitySampled1D);
2261       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityImageQuery);
2262       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityDerivativeControl);
2263    }
2264 
2265    ctx.stage = s->info.stage;
2266    ctx.GLSL_std_450 = spirv_builder_import(&ctx.builder, "GLSL.std.450");
2267    spirv_builder_emit_source(&ctx.builder, SpvSourceLanguageGLSL, 450);
2268 
2269    spirv_builder_emit_mem_model(&ctx.builder, SpvAddressingModelLogical,
2270                                 SpvMemoryModelGLSL450);
2271 
2272    SpvExecutionModel exec_model;
2273    switch (s->info.stage) {
2274    case MESA_SHADER_VERTEX:
2275       exec_model = SpvExecutionModelVertex;
2276       break;
2277    case MESA_SHADER_TESS_CTRL:
2278       exec_model = SpvExecutionModelTessellationControl;
2279       break;
2280    case MESA_SHADER_TESS_EVAL:
2281       exec_model = SpvExecutionModelTessellationEvaluation;
2282       break;
2283    case MESA_SHADER_GEOMETRY:
2284       exec_model = SpvExecutionModelGeometry;
2285       break;
2286    case MESA_SHADER_FRAGMENT:
2287       exec_model = SpvExecutionModelFragment;
2288       break;
2289    case MESA_SHADER_COMPUTE:
2290       exec_model = SpvExecutionModelGLCompute;
2291       break;
2292    default:
2293       unreachable("invalid stage");
2294    }
2295 
2296    SpvId type_void = spirv_builder_type_void(&ctx.builder);
2297    SpvId type_main = spirv_builder_type_function(&ctx.builder, type_void,
2298                                                  NULL, 0);
2299    SpvId entry_point = spirv_builder_new_id(&ctx.builder);
2300    spirv_builder_emit_name(&ctx.builder, entry_point, "main");
2301 
2302    ctx.vars = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_pointer,
2303                                       _mesa_key_pointer_equal);
2304 
2305    ctx.so_outputs = _mesa_hash_table_create(ctx.mem_ctx, _mesa_hash_u32,
2306                                             _mesa_key_u32_equal);
2307 
2308    nir_foreach_shader_in_variable(var, s)
2309       emit_input(&ctx, var);
2310 
2311    nir_foreach_shader_out_variable(var, s)
2312       emit_output(&ctx, var);
2313 
2314    if (so_info)
2315       emit_so_info(&ctx, util_last_bit64(s->info.outputs_written), so_info, local_so_info);
2316    nir_foreach_variable_with_modes(var, s, nir_var_uniform |
2317                                            nir_var_mem_ubo |
2318                                            nir_var_mem_ssbo)
2319       emit_uniform(&ctx, var);
2320 
2321    if (s->info.stage == MESA_SHADER_FRAGMENT) {
2322       spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2323                                    SpvExecutionModeOriginUpperLeft);
2324       if (s->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_DEPTH))
2325          spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2326                                       SpvExecutionModeDepthReplacing);
2327    }
2328 
2329    if (so_info && so_info->num_outputs) {
2330       spirv_builder_emit_cap(&ctx.builder, SpvCapabilityTransformFeedback);
2331       spirv_builder_emit_exec_mode(&ctx.builder, entry_point,
2332                                    SpvExecutionModeXfb);
2333    }
2334 
2335    spirv_builder_function(&ctx.builder, entry_point, type_void,
2336                                             SpvFunctionControlMaskNone,
2337                                             type_main);
2338 
2339    nir_function_impl *entry = nir_shader_get_entrypoint(s);
2340    nir_metadata_require(entry, nir_metadata_block_index);
2341 
2342    ctx.defs = ralloc_array_size(ctx.mem_ctx,
2343                                 sizeof(SpvId), entry->ssa_alloc);
2344    if (!ctx.defs)
2345       goto fail;
2346    ctx.num_defs = entry->ssa_alloc;
2347 
2348    nir_index_local_regs(entry);
2349    ctx.regs = ralloc_array_size(ctx.mem_ctx,
2350                                 sizeof(SpvId), entry->reg_alloc);
2351    if (!ctx.regs)
2352       goto fail;
2353    ctx.num_regs = entry->reg_alloc;
2354 
2355    SpvId *block_ids = ralloc_array_size(ctx.mem_ctx,
2356                                         sizeof(SpvId), entry->num_blocks);
2357    if (!block_ids)
2358       goto fail;
2359 
2360    for (int i = 0; i < entry->num_blocks; ++i)
2361       block_ids[i] = spirv_builder_new_id(&ctx.builder);
2362 
2363    ctx.block_ids = block_ids;
2364    ctx.num_blocks = entry->num_blocks;
2365 
2366    /* emit a block only for the variable declarations */
2367    start_block(&ctx, spirv_builder_new_id(&ctx.builder));
2368    foreach_list_typed(nir_register, reg, node, &entry->registers) {
2369       SpvId type = get_vec_from_bit_size(&ctx, reg->bit_size, reg->num_components);
2370       SpvId pointer_type = spirv_builder_type_pointer(&ctx.builder,
2371                                                       SpvStorageClassFunction,
2372                                                       type);
2373       SpvId var = spirv_builder_emit_var(&ctx.builder, pointer_type,
2374                                          SpvStorageClassFunction);
2375 
2376       ctx.regs[reg->index] = var;
2377    }
2378 
2379    emit_cf_list(&ctx, &entry->body);
2380 
2381    if (so_info)
2382       emit_so_outputs(&ctx, so_info, local_so_info);
2383 
2384    spirv_builder_return(&ctx.builder); // doesn't belong here, but whatevz
2385    spirv_builder_function_end(&ctx.builder);
2386 
2387    spirv_builder_emit_entry_point(&ctx.builder, exec_model, entry_point,
2388                                   "main", ctx.entry_ifaces,
2389                                   ctx.num_entry_ifaces);
2390 
2391    size_t num_words = spirv_builder_get_num_words(&ctx.builder);
2392 
2393    ret = CALLOC_STRUCT(spirv_shader);
2394    if (!ret)
2395       goto fail;
2396 
2397    ret->words = MALLOC(sizeof(uint32_t) * num_words);
2398    if (!ret->words)
2399       goto fail;
2400 
2401    ret->num_words = spirv_builder_get_words(&ctx.builder, ret->words, num_words);
2402    assert(ret->num_words == num_words);
2403 
2404    ralloc_free(ctx.mem_ctx);
2405 
2406    return ret;
2407 
2408 fail:
2409    ralloc_free(ctx.mem_ctx);
2410 
2411    if (ret)
2412       spirv_shader_delete(ret);
2413 
2414    return NULL;
2415 }
2416 
2417 void
spirv_shader_delete(struct spirv_shader * s)2418 spirv_shader_delete(struct spirv_shader *s)
2419 {
2420    FREE(s->words);
2421    FREE(s);
2422 }
2423