/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include <math.h>
#include "vtn_private.h"
#include "spirv_info.h"

/*
 * Normally, a column vector in SPIR-V corresponds to a single NIR SSA
 * definition.  But for matrix multiplies, we want one routine that
 * multiplies a matrix by a matrix, treating vectors as matrices with a
 * single column.  So we "wrap" vectors as one-column matrices and unwrap
 * the result before we send it off.
 */
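
/*
 * Illustrative example: OpMatrixTimesVector on a mat4 and a vec4 becomes a
 * 4x4 * 4x1 multiply below; the vec4 is wrapped as a one-column matrix and
 * the single result column is unwrapped back into a plain vector.
 */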

static struct vtn_ssa_value *
wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
{
   if (val == NULL)
      return NULL;

   if (glsl_type_is_matrix(val->type))
      return val;

   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
   dest->type = glsl_get_bare_type(val->type);
   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
   dest->elems[0] = val;

   return dest;
}

static struct vtn_ssa_value *
unwrap_matrix(struct vtn_ssa_value *val)
{
   if (glsl_type_is_matrix(val->type))
      return val;

   return val->elems[0];
}

static struct vtn_ssa_value *
matrix_multiply(struct vtn_builder *b,
                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
{
   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);

   unsigned src0_rows = glsl_get_vector_elements(src0->type);
   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
   unsigned src1_columns = glsl_get_matrix_columns(src1->type);

   const struct glsl_type *dest_type;
   if (src1_columns > 1) {
      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
                                   src0_rows, src1_columns);
   } else {
      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
   }
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);

   dest = wrap_matrix(b, dest);

   bool transpose_result = false;
   if (src0_transpose && src1_transpose) {
      /* transpose(A) * transpose(B) = transpose(B * A), so swap the
       * operands, drop the transposes, and transpose the final result.
       */
      src1 = src0_transpose;
      src0 = src1_transpose;
      src0_transpose = NULL;
      src1_transpose = NULL;
      transpose_result = true;
   }

   if (src0_transpose && !src1_transpose &&
       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
      /* We already have the rows of src0 and the columns of src1 available,
       * so we can just take the dot product of each row with each column to
       * get the result.
       */

      for (unsigned i = 0; i < src1_columns; i++) {
         nir_ssa_def *vec_src[4];
         for (unsigned j = 0; j < src0_rows; j++) {
            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
                                          src1->elems[i]->def);
         }
         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
      }
   } else {
      /* We don't handle the case where src1 is transposed but not src0,
       * since the general case only uses individual components of src1, so
       * the optimizer should chew through the transpose we emitted for src1.
       */

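      /*
       * Illustrative expansion (for src0_columns == 3): column i of the
       * result is built as
       *
       *    dest[i] = ffma(src0[0], src1[i].x,
       *                   ffma(src0[1], src1[i].y,
       *                        src0[2] * src1[i].z))
       *
       * i.e. a linear combination of the columns of src0 weighted by the
       * components of column i of src1.
       */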
      for (unsigned i = 0; i < src1_columns; i++) {
         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
         dest->elems[i]->def =
            nir_fmul(&b->nb, src0->elems[src0_columns - 1]->def,
                     nir_channel(&b->nb, src1->elems[i]->def, src0_columns - 1));
         for (int j = src0_columns - 2; j >= 0; j--) {
            dest->elems[i]->def =
               nir_ffma(&b->nb, src0->elems[j]->def,
                                nir_channel(&b->nb, src1->elems[i]->def, j),
                                dest->elems[i]->def);
         }
      }
   }

   dest = unwrap_matrix(dest);

   if (transpose_result)
      dest = vtn_ssa_transpose(b, dest);

   return dest;
}

static struct vtn_ssa_value *
mat_times_scalar(struct vtn_builder *b,
                 struct vtn_ssa_value *mat,
                 nir_ssa_def *scalar)
{
   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
      if (glsl_base_type_is_integer(glsl_get_base_type(mat->type)))
         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
      else
         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
   }

   return dest;
}

static struct vtn_ssa_value *
vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
{
   switch (opcode) {
   case SpvOpFNegate: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
      return dest;
   }

   case SpvOpFAdd: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpFSub: {
      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
      unsigned cols = glsl_get_matrix_columns(src0->type);
      for (unsigned i = 0; i < cols; i++)
         dest->elems[i]->def =
            nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
      return dest;
   }

   case SpvOpTranspose:
      return vtn_ssa_transpose(b, src0);

   case SpvOpMatrixTimesScalar:
      if (src0->transposed) {
         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
                                                         src1->def));
      } else {
         return mat_times_scalar(b, src0, src1->def);
      }
      break;

   case SpvOpVectorTimesMatrix:
   case SpvOpMatrixTimesVector:
   case SpvOpMatrixTimesMatrix:
      if (opcode == SpvOpVectorTimesMatrix) {
         /* v * M is equivalent to transpose(M) * v with v treated as a
          * column vector.
          */
         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
      } else {
         return matrix_multiply(b, src0, src1);
      }
      break;

   default:
      vtn_fail_with_opcode("unknown matrix opcode", opcode);
   }
}

static nir_alu_type
convert_op_src_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertFToS:
   case SpvOpConvertFToU:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertSToF:
   case SpvOpSatConvertSToU:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertUToF:
   case SpvOpSatConvertUToS:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}

static nir_alu_type
convert_op_dst_type(SpvOp opcode)
{
   switch (opcode) {
   case SpvOpFConvert:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
      return nir_type_float;
   case SpvOpSConvert:
   case SpvOpConvertFToS:
   case SpvOpSatConvertUToS:
      return nir_type_int;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpSatConvertSToU:
      return nir_type_uint;
   default:
      unreachable("Unhandled conversion op");
   }
}
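
/*
 * Illustrative example: for OpFConvert from a 32-bit float to a 16-bit
 * float, both helpers return nir_type_float.  Combined with the bit sizes
 * below, nir_type_conversion_op(nir_type_float32, nir_type_float16, ...)
 * then selects the matching f2f16 conversion opcode.
 */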

nir_op
vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
                                SpvOp opcode, bool *swap, bool *exact,
                                unsigned src_bit_size, unsigned dst_bit_size)
{
   /* Indicates that the first two arguments should be swapped.  This is
    * used for implementing greater-than and less-than-or-equal, e.g.
    * a > b is implemented as b < a.
    */
   *swap = false;

   *exact = false;

   switch (opcode) {
   case SpvOpSNegate:            return nir_op_ineg;
   case SpvOpFNegate:            return nir_op_fneg;
   case SpvOpNot:                return nir_op_inot;
   case SpvOpIAdd:               return nir_op_iadd;
   case SpvOpFAdd:               return nir_op_fadd;
   case SpvOpISub:               return nir_op_isub;
   case SpvOpFSub:               return nir_op_fsub;
   case SpvOpIMul:               return nir_op_imul;
   case SpvOpFMul:               return nir_op_fmul;
   case SpvOpUDiv:               return nir_op_udiv;
   case SpvOpSDiv:               return nir_op_idiv;
   case SpvOpFDiv:               return nir_op_fdiv;
   case SpvOpUMod:               return nir_op_umod;
   case SpvOpSMod:               return nir_op_imod;
   case SpvOpFMod:               return nir_op_fmod;
   case SpvOpSRem:               return nir_op_irem;
   case SpvOpFRem:               return nir_op_frem;

   case SpvOpShiftRightLogical:     return nir_op_ushr;
   case SpvOpShiftRightArithmetic:  return nir_op_ishr;
   case SpvOpShiftLeftLogical:      return nir_op_ishl;
   case SpvOpLogicalOr:             return nir_op_ior;
   case SpvOpLogicalEqual:          return nir_op_ieq;
   case SpvOpLogicalNotEqual:       return nir_op_ine;
   case SpvOpLogicalAnd:            return nir_op_iand;
   case SpvOpLogicalNot:            return nir_op_inot;
   case SpvOpBitwiseOr:             return nir_op_ior;
   case SpvOpBitwiseXor:            return nir_op_ixor;
   case SpvOpBitwiseAnd:            return nir_op_iand;
   case SpvOpSelect:                return nir_op_bcsel;
   case SpvOpIEqual:                return nir_op_ieq;

   case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
   case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
   case SpvOpBitFieldUExtract:      return nir_op_ubitfield_extract;
   case SpvOpBitReverse:            return nir_op_bitfield_reverse;

   case SpvOpUCountLeadingZerosINTEL: return nir_op_uclz;
   /* SpvOpUCountTrailingZerosINTEL is handled elsewhere. */
   case SpvOpAbsISubINTEL:          return nir_op_uabs_isub;
   case SpvOpAbsUSubINTEL:          return nir_op_uabs_usub;
   case SpvOpIAddSatINTEL:          return nir_op_iadd_sat;
   case SpvOpUAddSatINTEL:          return nir_op_uadd_sat;
   case SpvOpIAverageINTEL:         return nir_op_ihadd;
   case SpvOpUAverageINTEL:         return nir_op_uhadd;
   case SpvOpIAverageRoundedINTEL:  return nir_op_irhadd;
   case SpvOpUAverageRoundedINTEL:  return nir_op_urhadd;
   case SpvOpISubSatINTEL:          return nir_op_isub_sat;
   case SpvOpUSubSatINTEL:          return nir_op_usub_sat;
   case SpvOpIMul32x16INTEL:        return nir_op_imul_32x16;
   case SpvOpUMul32x16INTEL:        return nir_op_umul_32x16;

   /* The ordered/unordered comparisons need a special implementation in
    * addition to the ALU op returned here, since they must also check
    * whether the operands are ordered (i.e., neither is NaN).
    */
   case SpvOpFOrdEqual:                            *exact = true;  return nir_op_feq;
   case SpvOpFUnordEqual:                          *exact = true;  return nir_op_feq;
   case SpvOpINotEqual:                                            return nir_op_ine;
   case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
   case SpvOpFOrdNotEqual:                         *exact = true;  return nir_op_fneu;
   case SpvOpFUnordNotEqual:                       *exact = true;  return nir_op_fneu;
   case SpvOpULessThan:                                            return nir_op_ult;
   case SpvOpSLessThan:                                            return nir_op_ilt;
   case SpvOpFOrdLessThan:                         *exact = true;  return nir_op_flt;
   case SpvOpFUnordLessThan:                       *exact = true;  return nir_op_flt;
   case SpvOpUGreaterThan:          *swap = true;                  return nir_op_ult;
   case SpvOpSGreaterThan:          *swap = true;                  return nir_op_ilt;
   case SpvOpFOrdGreaterThan:       *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpFUnordGreaterThan:     *swap = true;  *exact = true;  return nir_op_flt;
   case SpvOpULessThanEqual:        *swap = true;                  return nir_op_uge;
   case SpvOpSLessThanEqual:        *swap = true;                  return nir_op_ige;
   case SpvOpFOrdLessThanEqual:     *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpFUnordLessThanEqual:   *swap = true;  *exact = true;  return nir_op_fge;
   case SpvOpUGreaterThanEqual:                                    return nir_op_uge;
   case SpvOpSGreaterThanEqual:                                    return nir_op_ige;
   case SpvOpFOrdGreaterThanEqual:                 *exact = true;  return nir_op_fge;
   case SpvOpFUnordGreaterThanEqual:               *exact = true;  return nir_op_fge;

   /* Conversions: */
   case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert: {
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;
      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
   }

   case SpvOpPtrCastToGeneric:   return nir_op_mov;
   case SpvOpGenericCastToPtr:   return nir_op_mov;

   /* Derivatives: */
   case SpvOpDPdx:         return nir_op_fddx;
   case SpvOpDPdy:         return nir_op_fddy;
   case SpvOpDPdxFine:     return nir_op_fddx_fine;
   case SpvOpDPdyFine:     return nir_op_fddy_fine;
   case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
   case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;

   case SpvOpIsNormal:     return nir_op_fisnormal;
   case SpvOpIsFinite:     return nir_op_fisfinite;

   default:
      vtn_fail("No NIR equivalent: %u", opcode);
   }
}

static void
handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
                      UNUSED int member, const struct vtn_decoration *dec,
                      UNUSED void *_void)
{
   vtn_assert(dec->scope == VTN_DEC_DECORATION);
   if (dec->decoration != SpvDecorationNoContraction)
      return;

   b->nb.exact = true;
}

void
vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val)
{
   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
}

nir_rounding_mode
vtn_rounding_mode_to_nir(struct vtn_builder *b, SpvFPRoundingMode mode)
{
   switch (mode) {
   case SpvFPRoundingModeRTE:
      return nir_rounding_mode_rtne;
   case SpvFPRoundingModeRTZ:
      return nir_rounding_mode_rtz;
   case SpvFPRoundingModeRTP:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTP is only supported in kernels");
      return nir_rounding_mode_ru;
   case SpvFPRoundingModeRTN:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "FPRoundingModeRTN is only supported in kernels");
      return nir_rounding_mode_rd;
   default:
      vtn_fail("Unsupported rounding mode: %s",
               spirv_fproundingmode_to_string(mode));
      break;
   }
}

struct conversion_opts {
   nir_rounding_mode rounding_mode;
   bool saturate;
};

static void
handle_conversion_opts(struct vtn_builder *b, UNUSED struct vtn_value *val,
                       UNUSED int member,
                       const struct vtn_decoration *dec, void *_opts)
{
   struct conversion_opts *opts = _opts;

   switch (dec->decoration) {
   case SpvDecorationFPRoundingMode:
      opts->rounding_mode = vtn_rounding_mode_to_nir(b, dec->operands[0]);
      break;

   case SpvDecorationSaturatedConversion:
      vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                  "Saturated conversions are only allowed in kernels");
      opts->saturate = true;
      break;

   default:
      break;
   }
}

static void
handle_no_wrap(UNUSED struct vtn_builder *b, UNUSED struct vtn_value *val,
               UNUSED int member,
               const struct vtn_decoration *dec, void *_alu)
{
   nir_alu_instr *alu = _alu;
   switch (dec->decoration) {
   case SpvDecorationNoSignedWrap:
      alu->no_signed_wrap = true;
      break;
   case SpvDecorationNoUnsignedWrap:
      alu->no_unsigned_wrap = true;
      break;
   default:
      /* Do nothing. */
      break;
   }
}

void
vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
               const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources */
   const unsigned num_inputs = count - 3;
   struct vtn_ssa_value *vtn_src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++)
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);

   if (glsl_type_is_matrix(vtn_src[0]->type) ||
       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
      vtn_push_ssa_value(b, w[2],
         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
      b->nb.exact = b->exact;
      return;
   }

   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
   nir_ssa_def *src[4] = { NULL, };
   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
      src[i] = vtn_src[i]->def;
   }

   switch (opcode) {
   case SpvOpAny:
      dest->def = nir_bany(&b->nb, src[0]);
      break;

   case SpvOpAll:
      dest->def = nir_ball(&b->nb, src[0]);
      break;

   case SpvOpOuterProduct: {
      for (unsigned i = 0; i < src[1]->num_components; i++) {
         dest->elems[i]->def =
            nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
      }
      break;
   }

   case SpvOpDot:
      dest->def = nir_fdot(&b->nb, src[0], src[1]);
      break;

   case SpvOpIAddCarry:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
      break;

   case SpvOpISubBorrow:
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
      break;

   case SpvOpUMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
      break;
   }

   case SpvOpSMulExtended: {
      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
      break;
   }

   case SpvOpFwidth:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
      break;
   case SpvOpFwidthFine:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
      break;
   case SpvOpFwidthCoarse:
      dest->def = nir_fadd(&b->nb,
                           nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                           nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
      break;

   case SpvOpVectorTimesScalar:
      /* The builder will take care of splatting for us. */
      dest->def = nir_fmul(&b->nb, src[0], src[1]);
      break;

   case SpvOpIsNan: {
      const bool save_exact = b->nb.exact;

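      /* Mark the comparison exact so the optimizer cannot fold the
       * self-comparison (a != a) to false, which would defeat the NaN check.
       */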
      b->nb.exact = true;
      dest->def = nir_fneu(&b->nb, src[0], src[0]);
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpOrdered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
                                   nir_feq(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUnordered: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;
      dest->def = nir_ior(&b->nb, nir_fneu(&b->nb, src[0], src[0]),
                                  nir_fneu(&b->nb, src[1], src[1]));
      b->nb.exact = save_exact;
      break;
   }

   case SpvOpIsInf: {
      nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
      break;
   }

   case SpvOpFUnordEqual: {
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as !(a < b || b < a).  If one or both
       * of the sources are numbers, later optimization passes can easily
       * eliminate the isnan() checks.  This may trim the sequence down to a
       * single (a == b) operation.  Otherwise, the optimizer can transform
       * whatever is left to !(a < b || b < a).  Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_ior(&b->nb,
                 nir_feq(&b->nb, src[0], src[1]),
                 nir_ior(&b->nb,
                         nir_fneu(&b->nb, src[0], src[0]),
                         nir_fneu(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpFUnordLessThan:
   case SpvOpFUnordGreaterThan:
   case SpvOpFUnordLessThanEqual:
   case SpvOpFUnordGreaterThanEqual: {
      bool swap;
      bool unused_exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &unused_exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* Use the property FUnordLessThan(a, b) ≡ !FOrdGreaterThanEqual(a, b):
       * the unordered comparison is true when either operand is NaN or the
       * comparison holds, which is exactly the negation of the ordered
       * comparison with the opposite relation.
       */
      switch (op) {
      case nir_op_fge: op = nir_op_flt; break;
      case nir_op_flt: op = nir_op_fge; break;
      default: unreachable("Impossible opcode.");
      }

      dest->def =
         nir_inot(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpLessOrGreater:
   case SpvOpFOrdNotEqual: {
      /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
       * from the ALU will probably already be false if the operands are not
       * ordered, so we don’t need to handle it specially.
       */
      const bool save_exact = b->nb.exact;

      b->nb.exact = true;

      /* This could also be implemented as (a < b || b < a).  If one or both
       * of the sources are numbers, later optimization passes can easily
       * eliminate the isnan() checks.  This may trim the sequence down to a
       * single (a != b) operation.  Otherwise, the optimizer can transform
       * whatever is left to (a < b || b < a).  Since some applications will
       * open-code this sequence, these optimizations are needed anyway.
       */
      dest->def =
         nir_iand(&b->nb,
                  nir_fneu(&b->nb, src[0], src[1]),
                  nir_iand(&b->nb,
                          nir_feq(&b->nb, src[0], src[0]),
                          nir_feq(&b->nb, src[1], src[1])));

      b->nb.exact = save_exact;
      break;
   }

   case SpvOpUConvert:
   case SpvOpConvertFToU:
   case SpvOpConvertFToS:
   case SpvOpConvertSToF:
   case SpvOpConvertUToF:
   case SpvOpSConvert:
   case SpvOpFConvert:
   case SpvOpSatConvertSToU:
   case SpvOpSatConvertUToS: {
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_alu_type src_type = convert_op_src_type(opcode) | src_bit_size;
      nir_alu_type dst_type = convert_op_dst_type(opcode) | dst_bit_size;

      struct conversion_opts opts = {
         .rounding_mode = nir_rounding_mode_undef,
         .saturate = false,
      };
      vtn_foreach_decoration(b, dest_val, handle_conversion_opts, &opts);

      if (opcode == SpvOpSatConvertSToU || opcode == SpvOpSatConvertUToS)
         opts.saturate = true;

      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
         if (opts.rounding_mode == nir_rounding_mode_undef && !opts.saturate) {
            nir_op op = nir_type_conversion_op(src_type, dst_type,
                                               nir_rounding_mode_undef);
            dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
         } else {
            dest->def = nir_convert_alu_types(&b->nb, dst_bit_size, src[0],
                                              src_type, dst_type,
                                              opts.rounding_mode, opts.saturate);
         }
      } else {
         vtn_fail_if(opts.rounding_mode != nir_rounding_mode_undef &&
                     dst_type != nir_type_float16,
                     "Rounding modes are only allowed on conversions to "
                     "16-bit float types");
         nir_op op = nir_type_conversion_op(src_type, dst_type,
                                            opts.rounding_mode);
         dest->def = nir_build_alu(&b->nb, op, src[0], NULL, NULL, NULL);
      }
      break;
   }

   case SpvOpBitFieldInsert:
   case SpvOpBitFieldSExtract:
   case SpvOpBitFieldUExtract:
   case SpvOpShiftLeftLogical:
   case SpvOpShiftRightArithmetic:
   case SpvOpShiftRightLogical: {
      bool swap;
      bool exact;
      unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &exact,
                                                  src0_bit_size, dst_bit_size);

      assert(!exact);

      assert(op == nir_op_ushr || op == nir_op_ishr || op == nir_op_ishl ||
             op == nir_op_bitfield_insert || op == nir_op_ubitfield_extract ||
             op == nir_op_ibitfield_extract);

      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         unsigned src_bit_size =
            nir_alu_type_get_type_size(nir_op_infos[op].input_types[i]);
         if (src_bit_size == 0)
            continue;
         if (src_bit_size != src[i]->bit_size) {
            assert(src_bit_size == 32);
            /* Convert the Shift, Offset and Count operands to 32 bits, which
             * is the bit size supported by the NIR instructions.  See
             * discussion here:
             *
             * https://lists.freedesktop.org/archives/mesa-dev/2018-April/193026.html
             */
            src[i] = nir_u2u32(&b->nb, src[i]);
         }
      }
      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
      break;
   }

   case SpvOpSignBitSet:
      dest->def = nir_i2b(&b->nb,
         nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
      break;

   case SpvOpUCountTrailingZerosINTEL:
      dest->def = nir_umin(&b->nb,
                           nir_find_lsb(&b->nb, src[0]),
                           nir_imm_int(&b->nb, 32u));
      break;

   case SpvOpBitCount: {
      /* bit_count always returns int32, but the SPIR-V opcode just says the
       * return value needs to be big enough to store the number of bits.
       */
      dest->def = nir_u2u(&b->nb, nir_bit_count(&b->nb, src[0]),
                          glsl_get_bit_size(dest_type));
      break;
   }

   case SpvOpSDotKHR:
   case SpvOpUDotKHR:
   case SpvOpSUDotKHR:
   case SpvOpSDotAccSatKHR:
   case SpvOpUDotAccSatKHR:
   case SpvOpSUDotAccSatKHR:
      unreachable("Should have called vtn_handle_integer_dot instead.");

   default: {
      bool swap;
      bool exact;
      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                  &exact,
                                                  src_bit_size, dst_bit_size);

      if (swap) {
         nir_ssa_def *tmp = src[0];
         src[0] = src[1];
         src[1] = tmp;
      }

      switch (op) {
      case nir_op_ishl:
      case nir_op_ishr:
      case nir_op_ushr:
         if (src[1]->bit_size != 32)
            src[1] = nir_u2u32(&b->nb, src[1]);
         break;
      default:
         break;
      }

      const bool save_exact = b->nb.exact;

      if (exact)
         b->nb.exact = true;

      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);

      b->nb.exact = save_exact;
      break;
   } /* default */
   }

   switch (opcode) {
   case SpvOpIAdd:
   case SpvOpIMul:
   case SpvOpISub:
   case SpvOpShiftLeftLogical:
   case SpvOpSNegate: {
      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
      break;
   }
   default:
      /* Do nothing. */
      break;
   }

   vtn_push_ssa_value(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
                       const uint32_t *w, unsigned count)
{
   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
   const unsigned dest_size = glsl_get_bit_size(dest_type);

   vtn_handle_no_contraction(b, dest_val);

   /* Collect the various SSA sources.
    *
    * Due to the optional "Packed Vector Format" field, determine the number
    * of inputs from the opcode.  This differs from vtn_handle_alu.
    */
   const unsigned num_inputs = (opcode == SpvOpSDotAccSatKHR ||
                                opcode == SpvOpUDotAccSatKHR ||
                                opcode == SpvOpSUDotAccSatKHR) ? 3 : 2;

   vtn_assert(count >= num_inputs + 3);

   struct vtn_ssa_value *vtn_src[3] = { NULL, };
   nir_ssa_def *src[3] = { NULL, };

   for (unsigned i = 0; i < num_inputs; i++) {
      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
      src[i] = vtn_src[i]->def;

      vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
   }

   /* For all of the opcodes *except* SpvOpSUDotKHR and SpvOpSUDotAccSatKHR,
    * the SPV_KHR_integer_dot_product spec says:
    *
    *    _Vector 1_ and _Vector 2_ must have the same type.
    *
    * The practical requirement is the same bit-size and the same number of
    * components.
    */
   vtn_fail_if(glsl_get_bit_size(vtn_src[0]->type) !=
               glsl_get_bit_size(vtn_src[1]->type) ||
               glsl_get_vector_elements(vtn_src[0]->type) !=
               glsl_get_vector_elements(vtn_src[1]->type),
               "Vector 1 and vector 2 source of opcode %s must have the same "
               "type",
               spirv_op_to_string(opcode));

   if (num_inputs == 3) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    The type of Accumulator must be the same as Result Type.
       *
       * The handling of SpvOpSDotAccSatKHR and friends with the packed 4x8
       * types (far below) assumes these types have the same size.
       */
      vtn_fail_if(dest_type != vtn_src[2]->type,
                  "Accumulator type must be the same as Result Type for "
                  "opcode %s",
                  spirv_op_to_string(opcode));
   }

   unsigned packed_bit_size = 8;
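
   /* Where the sources allow it, pack them so the specialized NIR
    * dot-product opcodes can be used: a vec4 of 8-bit values (or a vec2 of
    * 16-bit values) is packed into a single 32-bit scalar below.
    */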
   if (glsl_type_is_vector(vtn_src[0]->type)) {
      /* FINISHME: Is this actually as good or better for platforms that don't
       * have the special instructions (i.e., one or both of has_dot_4x8 or
       * has_sudot_4x8 is false)?
       */
      if (glsl_get_vector_elements(vtn_src[0]->type) == 4 &&
          glsl_get_bit_size(vtn_src[0]->type) == 8 &&
          glsl_get_bit_size(dest_type) <= 32) {
         src[0] = nir_pack_32_4x8(&b->nb, src[0]);
         src[1] = nir_pack_32_4x8(&b->nb, src[1]);
      } else if (glsl_get_vector_elements(vtn_src[0]->type) == 2 &&
                 glsl_get_bit_size(vtn_src[0]->type) == 16 &&
                 glsl_get_bit_size(dest_type) <= 32 &&
                 opcode != SpvOpSUDotKHR &&
                 opcode != SpvOpSUDotAccSatKHR) {
         src[0] = nir_pack_32_2x16(&b->nb, src[0]);
         src[1] = nir_pack_32_2x16(&b->nb, src[1]);
         packed_bit_size = 16;
      }
   } else if (glsl_type_is_scalar(vtn_src[0]->type) &&
              glsl_type_is_32bit(vtn_src[0]->type)) {
      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    When _Vector 1_ and _Vector 2_ are scalar integer types, _Packed
       *    Vector Format_ must be specified to select how the integers are to
       *    be interpreted as vectors.
       *
       * The "Packed Vector Format" value follows the last input.
       */
      vtn_assert(count == (num_inputs + 4));
      const SpvPackedVectorFormat pack_format = w[num_inputs + 3];
      vtn_fail_if(pack_format != SpvPackedVectorFormatPackedVectorFormat4x8BitKHR,
                  "Unsupported vector packing format %d for opcode %s",
                  pack_format, spirv_op_to_string(opcode));
   } else {
      vtn_fail_with_opcode("Invalid source types.", opcode);
   }

   nir_ssa_def *dest = NULL;

   if (src[0]->num_components > 1) {
      const nir_op s_conversion_op =
         nir_type_conversion_op(nir_type_int, nir_type_int | dest_size,
                                nir_rounding_mode_undef);

      const nir_op u_conversion_op =
         nir_type_conversion_op(nir_type_uint, nir_type_uint | dest_size,
                                nir_rounding_mode_undef);

      nir_op src0_conversion_op;
      nir_op src1_conversion_op;

      switch (opcode) {
      case SpvOpSDotKHR:
      case SpvOpSDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = s_conversion_op;
         break;

      case SpvOpUDotKHR:
      case SpvOpUDotAccSatKHR:
         src0_conversion_op = u_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      case SpvOpSUDotKHR:
      case SpvOpSUDotAccSatKHR:
         src0_conversion_op = s_conversion_op;
         src1_conversion_op = u_conversion_op;
         break;

      default:
         unreachable("Invalid opcode.");
      }

      /* The SPV_KHR_integer_dot_product spec says:
       *
       *    All components of the input vectors are sign-extended to the bit
       *    width of the result's type. The sign-extended input vectors are
       *    then multiplied component-wise and all components of the vector
       *    resulting from the component-wise multiplication are added
       *    together. The resulting value will equal the low-order N bits of
       *    the correct result R, where N is the result width and R is
       *    computed with enough precision to avoid overflow and underflow.
       */
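      /* Illustrative example (assuming 8-bit signed sources and a 32-bit
       * result): for vectors (-1, 2) and (3, 4), each component is
       * sign-extended to 32 bits and the result is (-1)*3 + 2*4 = 5.
       */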
      const unsigned vector_components =
         glsl_get_vector_elements(vtn_src[0]->type);

      for (unsigned i = 0; i < vector_components; i++) {
         nir_ssa_def *const src0 =
            nir_build_alu(&b->nb, src0_conversion_op,
                          nir_channel(&b->nb, src[0], i), NULL, NULL, NULL);

         nir_ssa_def *const src1 =
            nir_build_alu(&b->nb, src1_conversion_op,
                          nir_channel(&b->nb, src[1], i), NULL, NULL, NULL);

         nir_ssa_def *const mul_result = nir_imul(&b->nb, src0, src1);

         dest = (i == 0) ? mul_result : nir_iadd(&b->nb, dest, mul_result);
      }

      if (num_inputs == 3) {
         /* For SpvOpSDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Signed integer dot product of _Vector 1_ and _Vector 2_ and
          *    signed saturating addition of the result with _Accumulator_.
          *
          * For SpvOpUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Unsigned integer dot product of _Vector 1_ and _Vector 2_ and
          *    unsigned saturating addition of the result with _Accumulator_.
          *
          * For SpvOpSUDotAccSatKHR, the SPV_KHR_integer_dot_product spec says:
          *
          *    Mixed-signedness integer dot product of _Vector 1_ and _Vector
          *    2_ and signed saturating addition of the result with
          *    _Accumulator_.
          */
         dest = (opcode == SpvOpUDotAccSatKHR)
            ? nir_uadd_sat(&b->nb, dest, src[2])
            : nir_iadd_sat(&b->nb, dest, src[2]);
      }
   } else {
      assert(src[0]->num_components == 1 && src[1]->num_components == 1);
      assert(src[0]->bit_size == 32 && src[1]->bit_size == 32);

      nir_ssa_def *const zero = nir_imm_zero(&b->nb, 1, 32);
      bool is_signed = opcode == SpvOpSDotKHR || opcode == SpvOpSUDotKHR ||
                       opcode == SpvOpSDotAccSatKHR || opcode == SpvOpSUDotAccSatKHR;

      if (packed_bit_size == 16) {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_2x16_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_2x16_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_2x16_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_2x16_uadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      } else {
         switch (opcode) {
         case SpvOpSDotKHR:
            dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotKHR:
            dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotKHR:
            dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sdot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sdot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_udot_4x8_uadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_udot_4x8_uadd(&b->nb, src[0], src[1], zero);
            break;
         case SpvOpSUDotAccSatKHR:
            if (dest_size == 32)
               dest = nir_sudot_4x8_iadd_sat(&b->nb, src[0], src[1], src[2]);
            else
               dest = nir_sudot_4x8_iadd(&b->nb, src[0], src[1], zero);
            break;
         default:
            unreachable("Invalid opcode.");
         }
      }

      if (dest_size != 32) {
         /* When the accumulator is 32-bits, a NIR dot-product with saturate
          * is generated above.  In all other cases a regular dot-product is
          * generated above, and separate addition with saturate is generated
          * here.
          *
          * The SPV_KHR_integer_dot_product spec says:
          *
          *    If any of the multiplications or additions, with the exception
          *    of the final accumulation, overflow or underflow, the result of
          *    the instruction is undefined.
          *
          * Therefore it is safe to cast the dot-product result down to the
          * size of the accumulator before doing the addition.  Since the
          * result of the dot-product cannot overflow 32-bits, this is also
          * safe to cast up.
          */
         if (num_inputs == 3) {
            dest = is_signed
               ? nir_iadd_sat(&b->nb, nir_i2i(&b->nb, dest, dest_size), src[2])
               : nir_uadd_sat(&b->nb, nir_u2u(&b->nb, dest, dest_size), src[2]);
         } else {
            dest = is_signed
               ? nir_i2i(&b->nb, dest, dest_size)
               : nir_u2u(&b->nb, dest, dest_size);
         }
      }
   }

   vtn_push_nir_ssa(b, w[2], dest);

   b->nb.exact = b->exact;
}

void
vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
{
   vtn_assert(count == 4);
   /* From the definition of OpBitcast in the SPIR-V 1.2 spec:
    *
    *    "If Result Type has the same number of components as Operand, they
    *    must also have the same component width, and results are computed per
    *    component.
    *
    *    If Result Type has a different number of components than Operand, the
    *    total number of bits in Result Type must equal the total number of
    *    bits in Operand. Let L be the type, either Result Type or Operand’s
    *    type, that has the larger number of components. Let S be the other
    *    type, with the smaller number of components. The number of components
    *    in L must be an integer multiple of the number of components in S.
    *    The first component (that is, the only or lowest-numbered component)
    *    of S maps to the first components of L, and so on, up to the last
    *    component of S mapping to the last components of L. Within this
    *    mapping, any single component of S (mapping to multiple components of
    *    L) maps its lower-ordered bits to the lower-numbered components of L."
    */
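
   /* Illustrative example, following the mapping quoted above: bitcasting a
    * uvec2 to a 64-bit scalar maps component 0 to the low 32 bits and
    * component 1 to the high 32 bits; nir_bitcast_vector() below performs
    * the repacking.
    */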

   struct vtn_type *type = vtn_get_type(b, w[1]);
   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);

   vtn_fail_if(src->num_components * src->bit_size !=
               glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
               "Source and destination of OpBitcast must have the same "
               "total number of bits");
   nir_ssa_def *val =
      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
   vtn_push_nir_ssa(b, w[2], val);
}