1#
2# Copyright (C) 2014 Connor Abbott
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22#
23# Authors:
24#    Connor Abbott (cwabbott0@gmail.com)
25
26import re
27
28# Class that represents all the information we have about the opcode
29# NOTE: this must be kept in sync with nir_op_info
30
class Opcode(object):
   """Class that represents all the information we have about the opcode
   NOTE: this must be kept in sync with nir_op_info
   """
   def __init__(self, name, output_size, output_type, input_sizes,
                input_types, is_conversion, algebraic_properties, const_expr):
      """Parameters:

      - name is the name of the opcode (prepend nir_op_ for the enum name)
      - output_size is the vector size of the destination: 0, 1-5, 8 or 16
      - all types are strings that get nir_type_ prepended to them
      - input_sizes is a list with one vector size per source, using the
        same encoding as output_size.  If output_size is 0 every input size
        must be 0; otherwise every input size must be non-zero.
      - input_types is a list of types, one per source (same length as
        input_sizes)
      - is_conversion is true if this opcode represents a type conversion
      - algebraic_properties is a space-separated string, where nir_op_is_ is
        prepended before each entry
      - const_expr is an expression or series of statements that computes the
        constant value of the opcode given the constant values of its inputs.

      Constant expressions are formed from the variables src0, src1, ...,
      src(N-1), where N is the number of arguments.  The output of the
      expression should be stored in the dst variable.  Per-component input
      and output variables will be scalars and non-per-component input and
      output variables will be a struct with fields named x, y, z, and w
      all of the correct type.  Input and output variables can be assumed
      to already be of the correct type and need no conversion.  In
      particular, the conversion from the C bool type to/from NIR_TRUE and
      NIR_FALSE happens automatically.

      For per-component instructions, the entire expression will be
      executed once for each component.  For non-per-component
      instructions, the expression is expected to store the correct values
      in dst.x, dst.y, etc.  If "dst" does not exist anywhere in the
      constant expression, an assignment to dst will happen automatically
      and the result will be equivalent to "dst = <expression>" for
      per-component instructions and "dst.x = dst.y = ... = <expression>"
      for non-per-component instructions.

      Raises AssertionError when an argument has the wrong type, when the
      size/type lists differ in length, or when a vector size is invalid.
      """
      assert isinstance(name, str)
      assert isinstance(output_size, int)
      assert isinstance(output_type, str)
      assert isinstance(input_sizes, list)
      # Validate every element, not just the first: a bad later entry would
      # otherwise slip through (or fail with an unrelated TypeError), and an
      # empty list would raise IndexError.
      assert all(isinstance(size, int) for size in input_sizes)
      assert isinstance(input_types, list)
      assert all(isinstance(type_, str) for type_ in input_types)
      assert isinstance(is_conversion, bool)
      assert isinstance(algebraic_properties, str)
      assert isinstance(const_expr, str)
      assert len(input_sizes) == len(input_types)
      assert self._is_valid_size(output_size)
      for size in input_sizes:
         assert self._is_valid_size(size)
         # Size-0 and non-zero operands cannot be mixed: either the
         # destination and every source are 0, or none of them are.
         assert (size == 0) == (output_size == 0)
      self.name = name
      self.num_inputs = len(input_sizes)
      self.output_size = output_size
      self.output_type = output_type
      self.input_sizes = input_sizes
      self.input_types = input_types
      self.is_conversion = is_conversion
      self.algebraic_properties = algebraic_properties
      self.const_expr = const_expr

   @staticmethod
   def _is_valid_size(size):
      """Return True if size is an accepted vector size (0-5, 8 or 16)."""
      return 0 <= size <= 5 or size in (8, 16)
94
# Shorthand names for NIR type strings (nir_type_ is prepended to them later).
# The unsized base types come first; explicitly sized variants follow,
# grouped by base type.
tfloat = "float"
tint = "int"
tuint = "uint"
tbool = "bool"

tbool1 = "bool1"
tbool8 = "bool8"
tbool16 = "bool16"
tbool32 = "bool32"

tint16 = "int16"
tint32 = "int32"
tint64 = "int64"

tuint8 = "uint8"
tuint16 = "uint16"
tuint32 = "uint32"
tuint64 = "uint64"

tfloat16 = "float16"
tfloat32 = "float32"
tfloat64 = "float64"
114
# Splits a NIR type string such as "uint16" into its base type ("uint") and
# an optional bit size ("16").
_TYPE_SPLIT_RE = re.compile(r'(?P<type>int|uint|float|bool)(?P<bits>\d+)?')

def _split_type(type_):
    """Match type_ against _TYPE_SPLIT_RE, asserting it is a NIR type."""
    match = _TYPE_SPLIT_RE.match(type_)
    assert match is not None, 'Invalid NIR type string: "{}"'.format(type_)
    return match

def type_has_size(type_):
    """Return True if type_ carries an explicit bit size (e.g. "int32")."""
    return _split_type(type_).group('bits') is not None

def type_size(type_):
    """Return the explicit bit size of type_ as an int; asserts one exists."""
    bits = _split_type(type_).group('bits')
    assert bits is not None, \
           'NIR type string has no bit size: "{}"'.format(type_)
    return int(bits)

def type_sizes(type_):
    """Return the list of bit sizes an opcode of type type_ is built for.

    A sized type yields just its own size; unsized bool and float have
    their own lists, and every other unsized type gets [1, 8, 16, 32, 64].
    """
    if type_has_size(type_):
        return [type_size(type_)]
    unsized_sizes = {
        'bool': [1, 8, 16, 32],
        'float': [16, 32, 64],
    }
    return unsized_sizes.get(type_, [1, 8, 16, 32, 64])

def type_base_type(type_):
    """Return the base type name of type_ (e.g. "uint" for "uint8")."""
    return _split_type(type_).group('type')
143
# Algebraic property flags.  Each string ends in a space so that several flags
# can be concatenated into the space-separated algebraic_properties string an
# Opcode takes.
#
# _2src_commutative marks an operation whose first two sources commute.  For
# 2-source operations this is just mathematical commutativity; some 3-source
# operations, like ffma, are only commutative in their first two sources.
_2src_commutative = "2src_commutative "
associative = "associative "

# Global registry of every opcode, keyed by opcode name.  Entries are added in
# the order the opcodes are declared in this file.
opcodes = dict()
154
def opcode(name, output_size, output_type, input_sizes, input_types,
           is_conversion, algebraic_properties, const_expr):
   """Build an Opcode from the arguments and add it to `opcodes`.

   Each name may be registered only once; a duplicate trips the assertion
   before anything is constructed.
   """
   assert name not in opcodes
   new_op = Opcode(name, output_size, output_type, input_sizes, input_types,
                   is_conversion, algebraic_properties, const_expr)
   opcodes[name] = new_op
161
def unop_convert(name, out_type, in_type, const_expr):
   """Register a size-0 unary opcode taking in_type and producing out_type."""
   opcode(name, output_size=0, output_type=out_type, input_sizes=[0],
          input_types=[in_type], is_conversion=False,
          algebraic_properties="", const_expr=const_expr)
164
def unop(name, ty, const_expr):
   """Register a size-0 unary opcode whose source and result are both ty."""
   unop_convert(name, ty, ty, const_expr)
167
def unop_horiz(name, output_size, output_type, input_size, input_type,
               const_expr):
   """Register a unary opcode with explicit vector sizes.

   The destination has output_size components and the single source has
   input_size components.
   """
   source_sizes = [input_size]
   source_types = [input_type]
   opcode(name, output_size, output_type, source_sizes, source_types,
          False, "", const_expr)
172
def unop_reduce(name, output_size, output_type, input_type, prereduce_expr,
                reduce_expr, final_expr):
   """Register <name>2, <name>3 and <name>4 horizontal reductions of src0.

   Each component src0.x .. src0.w is wrapped by prereduce_expr (placeholder
   {src}) and parenthesised.  The terms are folded with reduce_expr
   (placeholders {src0}/{src1}), and final_expr (placeholder {src}) is
   applied to the parenthesised total.
   """
   def combine(lhs, rhs):
      return reduce_expr.format(src0=lhs, src1=rhs)

   terms = ["(%s)" % prereduce_expr.format(src="src0." + comp)
            for comp in "xyzw"]
   x_term, y_term, z_term, w_term = terms
   xy = combine(x_term, y_term)
   reductions = (
      (2, xy),
      (3, combine(xy, z_term)),
      (4, combine(xy, combine(z_term, w_term))),
   )
   for width, body in reductions:
      unop_horiz(name + str(width), output_size, output_type, width,
                 input_type, final_expr.format(src="(%s)" % body))
191
def unop_numeric_convert(name, out_type, in_type, const_expr):
   """Register a size-0 unary opcode flagged as a numeric type conversion."""
   marks_conversion = True
   opcode(name, 0, out_type, [0], [in_type], marks_conversion, "",
          const_expr)
194
# Unary opcodes registered with unop(): source and destination share one type.
# Expressions that test bit_size choose the double-precision C function for
# 64-bit operands and the single-precision one otherwise.
unop("mov", tuint, "src0")

unop("ineg", tint, "-src0")
unop("fneg", tfloat, "-src0")
unop("inot", tint, "~src0") # invert every bit of the integer

# nir_op_fsign roughly implements the OpenGL / Vulkan rules for sign(float).
# The GLSL.std.450 FSign instruction is defined as:
#
#    Result is 1.0 if x > 0, 0.0 if x = 0, or -1.0 if x < 0.
#
# If the source is equal to zero, there is a preference for the result to have
# the same sign, but this is not required (it is required by OpenCL).  If the
# source is not a number, there is a preference for the result to be +0.0, but
# this is not required (it is required by OpenCL).  If the source is not a
# number, and the result is not +0.0, the result should definitely **not** be
# NaN.
#
# The values returned for constant folding match the behavior required by
# OpenCL.
unop("fsign", tfloat, ("bit_size == 64 ? " +
                       "(isnan(src0) ? 0.0  : ((src0 == 0.0 ) ? src0 : (src0 > 0.0 ) ? 1.0  : -1.0 )) : " +
                       "(isnan(src0) ? 0.0f : ((src0 == 0.0f) ? src0 : (src0 > 0.0f) ? 1.0f : -1.0f))"))
unop("isign", tint, "(src0 == 0) ? 0 : ((src0 > 0) ? 1 : -1)")
unop("iabs", tint, "(src0 < 0) ? -src0 : src0")
unop("fabs", tfloat, "fabs(src0)")
# Clamp src0 to the range [0.0, 1.0].
unop("fsat", tfloat, ("fmin(fmax(src0, 0.0), 1.0)"))
unop("frcp", tfloat, "bit_size == 64 ? 1.0 / src0 : 1.0f / src0")
unop("frsq", tfloat, "bit_size == 64 ? 1.0 / sqrt(src0) : 1.0f / sqrtf(src0)")
unop("fsqrt", tfloat, "bit_size == 64 ? sqrt(src0) : sqrtf(src0)")
# NOTE(review): unlike frcp/frsq/fsqrt above, fexp2 and flog2 always fold with
# the single-precision exp2f()/log2f(), even for 64-bit operands; confirm the
# precision loss for double constants is acceptable.
unop("fexp2", tfloat, "exp2f(src0)")
unop("flog2", tfloat, "log2f(src0)")
227
# Generate all of the numeric conversion opcodes
#
# Opcode names are "<s>2<d><bits>[<rounding>]", where <s> and <d> are the first
# letters of the source and destination type names, e.g. i2f32, f2u16, b2i8,
# f2f16_rtne.  The dst_types lists below decide which pairs exist; for
# example, uint sources only convert to float and uint, so there is no u2b.
for src_t in [tint, tuint, tfloat, tbool]:
   if src_t == tbool:
      dst_types = [tfloat, tint, tbool]
   elif src_t == tint:
      dst_types = [tfloat, tint, tbool]
   elif src_t == tuint:
      dst_types = [tfloat, tuint]
   elif src_t == tfloat:
      dst_types = [tint, tuint, tfloat, tbool]

   for dst_t in dst_types:
      for dst_bit_size in type_sizes(dst_t):
          # float -> float16 gets explicit round-to-nearest-even (_rtne) and
          # round-toward-zero (_rtz) variants next to the plain f2f16.  The
          # rounding variants only round when narrowing (bit_size > 16); a
          # 16-bit source is passed through unchanged.
          if dst_bit_size == 16 and dst_t == tfloat and src_t == tfloat:
              rnd_modes = ['_rtne', '_rtz', '']
              for rnd_mode in rnd_modes:
                  if rnd_mode == '_rtne':
                      conv_expr = """
                      if (bit_size > 16) {
                         dst = _mesa_half_to_float(_mesa_float_to_float16_rtne(src0));
                      } else {
                         dst = src0;
                      }
                      """
                  elif rnd_mode == '_rtz':
                      conv_expr = """
                      if (bit_size > 16) {
                         dst = _mesa_half_to_float(_mesa_float_to_float16_rtz(src0));
                      } else {
                         dst = src0;
                      }
                      """
                  else:
                      conv_expr = "src0"

                  unop_numeric_convert("{0}2{1}{2}{3}".format(src_t[0],
                                                              dst_t[0],
                                                              dst_bit_size,
                                                              rnd_mode),
                                       dst_t + str(dst_bit_size),
                                       src_t, conv_expr)
          # float -> float32: when narrowing from a wider float, honour a
          # round-toward-zero rounding mode requested by the execution mode.
          elif dst_bit_size == 32 and dst_t == tfloat and src_t == tfloat:
              conv_expr = """
              if (bit_size > 32 && nir_is_rounding_mode_rtz(execution_mode, 32)) {
                 dst = _mesa_double_to_float_rtz(src0);
              } else {
                 dst = src0;
              }
              """
              unop_numeric_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
                                                       dst_bit_size),
                                   dst_t + str(dst_bit_size), src_t, conv_expr)
          else:
              # Conversions to bool compare against zero; every other
              # conversion is just "src0".
              conv_expr = "src0 != 0" if dst_t == tbool else "src0"
              unop_numeric_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
                                                       dst_bit_size),
                                   dst_t + str(dst_bit_size), src_t, conv_expr)
285
# Special opcode that is the same as f2f16, i2i16, u2u16 except that it is safe
# to remove it if the result is immediately converted back to 32 bits again.
# This is generated as part of the precision lowering pass. mp stands for medium
# precision.
#
# Each *mp opcode reuses the constant-folding expression of its full-width
# counterpart generated above, so the two always fold identically.
unop_numeric_convert("f2fmp", tfloat16, tfloat32, opcodes["f2f16"].const_expr)
unop_numeric_convert("i2imp", tint16, tint32, opcodes["i2i16"].const_expr)
# u2ump isn't defined, because the behavior is equal to i2imp
unop_numeric_convert("f2imp", tint16, tfloat32, opcodes["f2i16"].const_expr)
unop_numeric_convert("f2ump", tuint16, tfloat32, opcodes["f2u16"].const_expr)
unop_numeric_convert("i2fmp", tfloat16, tint32, opcodes["i2f16"].const_expr)
unop_numeric_convert("u2fmp", tfloat16, tuint32, opcodes["u2f16"].const_expr)
297
# Unary floating-point rounding operations.
# ffract is src0 - floor(src0); fround_even rounds halfway cases to even.


unop("ftrunc", tfloat, "bit_size == 64 ? trunc(src0) : truncf(src0)")
unop("fceil", tfloat, "bit_size == 64 ? ceil(src0) : ceilf(src0)")
unop("ffloor", tfloat, "bit_size == 64 ? floor(src0) : floorf(src0)")
unop("ffract", tfloat, "src0 - (bit_size == 64 ? floor(src0) : floorf(src0))")
unop("fround_even", tfloat, "bit_size == 64 ? _mesa_roundeven(src0) : _mesa_roundevenf(src0)")

# Quantize src0 through float16 (_mesa_float_to_half then _mesa_half_to_float)
# while keeping the original type.  Magnitudes below 2^-14, the smallest normal
# float16, become a zero that carries the sign of src0.
unop("fquantize2f16", tfloat, "(fabs(src0) < ldexpf(1.0, -14)) ? copysignf(0.0f, src0) : _mesa_half_to_float(_mesa_float_to_half(src0))")

# Trigonometric operations.


unop("fsin", tfloat, "bit_size == 64 ? sin(src0) : sinf(src0)")
unop("fcos", tfloat, "bit_size == 64 ? cos(src0) : cosf(src0)")

# dfrexp
# frexp_exp yields the exponent and frexp_sig the significand that C frexp()
# splits src0 into.
unop_convert("frexp_exp", tint32, tfloat, "frexp(src0, &dst);")
unop_convert("frexp_sig", tfloat, tfloat, "int n; dst = frexp(src0, &n);")

# Partial derivatives.
# Constant folding only ever sees constant sources, so every variant folds to 0.


unop("fddx", tfloat, "0.0") # the derivative of a constant is 0.
unop("fddy", tfloat, "0.0")
unop("fddx_fine", tfloat, "0.0")
unop("fddy_fine", tfloat, "0.0")
unop("fddx_coarse", tfloat, "0.0")
unop("fddy_coarse", tfloat, "0.0")
328
329
330# Floating point pack and unpack operations.
331
def pack_2x16(fmt):
   """Register pack_<fmt>_2x16, packing two float32 components into a uint32.

   Each component is encoded by the C helper pack_<fmt>_1x16().  src0.x's
   encoding is placed in the low bits and src0.y's is shifted left by 16.
   """
   helper = "pack_" + fmt + "_1x16"
   const_expr = ("\n"
                 "dst.x = (uint32_t) " + helper + "(src0.x);\n"
                 "dst.x |= ((uint32_t) " + helper + "(src0.y)) << 16;\n")
   unop_horiz("pack_" + fmt + "_2x16", 1, tuint32, 2, tfloat32, const_expr)
337
def pack_4x8(fmt):
   """Register pack_<fmt>_4x8, packing four float32 components into a uint32.

   Each component is encoded by the C helper pack_<fmt>_1x8(); src0.x lands
   in the lowest byte and src0.y/z/w are shifted left by 8/16/24.
   """
   statements = ["dst.x = (uint32_t) pack_{0}_1x8(src0.x);".format(fmt)]
   for comp, shift in (("y", 8), ("z", 16), ("w", 24)):
      statements.append(
         "dst.x |= ((uint32_t) pack_{0}_1x8(src0.{1})) << {2};".format(
            fmt, comp, shift))
   const_expr = "\n" + "\n".join(statements) + "\n"
   unop_horiz("pack_" + fmt + "_4x8", 1, tuint32, 4, tfloat32, const_expr)
345
def unpack_2x16(fmt):
   """Register unpack_<fmt>_2x16, splitting a uint32 into two float32 values.

   The low 16 bits of src0.x are decoded into dst.x and the high 16 bits into
   dst.y, each by the C helper unpack_<fmt>_1x16().  This mirrors
   pack_2x16(), which shifts the y component left by 16.
   """
   # The high half must be shifted *down* before the uint16_t cast; shifting
   # left by 16 would truncate to 0 and always decode a zero for dst.y.
   unop_horiz("unpack_" + fmt + "_2x16", 2, tfloat32, 1, tuint32, """
dst.x = unpack_fmt_1x16((uint16_t)(src0.x & 0xffff));
dst.y = unpack_fmt_1x16((uint16_t)(src0.x >> 16));
""".replace("fmt", fmt))
351
def unpack_4x8(fmt):
   """Register unpack_<fmt>_4x8, splitting a uint32 into four float32 values.

   Byte 0 (least significant) goes to dst.x through byte 3 to dst.w; each
   byte is decoded by the C helper unpack_<fmt>_1x8().
   """
   helper = "unpack_{}_1x8".format(fmt)
   statements = [
      "dst.x = {}((uint8_t)(src0.x & 0xff));",
      "dst.y = {}((uint8_t)((src0.x >> 8) & 0xff));",
      "dst.z = {}((uint8_t)((src0.x >> 16) & 0xff));",
      "dst.w = {}((uint8_t)(src0.x >> 24));",
   ]
   const_expr = "".join("\n" + stmt.format(helper) for stmt in statements) + "\n"
   unop_horiz("unpack_" + fmt + "_4x8", 4, tfloat32, 1, tuint32, const_expr)
359
360
# Instantiate the float pack/unpack opcodes for each format.  snorm and unorm
# get both the 2x16 and 4x8 variants; "half" is only registered for 2x16.
pack_2x16("snorm")
pack_4x8("snorm")
pack_2x16("unorm")
pack_4x8("unorm")
pack_2x16("half")
unpack_2x16("snorm")
unpack_4x8("snorm")
unpack_2x16("unorm")
unpack_4x8("unorm")
unpack_2x16("half")
371
# Low 16 bits of src0.x in the low half, src0.y shifted into the high half.
unop_horiz("pack_uvec2_to_uint", 1, tuint32, 2, tuint32, """
dst.x = (src0.x & 0xffff) | (src0.y << 16);
""")

# One byte per component, x in the lowest byte.
# NOTE(review): components are not masked to 8 bits before shifting, so
# out-of-range values bleed into neighbouring bytes; presumably callers only
# pass 8-bit values — confirm.
unop_horiz("pack_uvec4_to_uint", 1, tuint32, 4, tuint32, """
dst.x = (src0.x <<  0) |
        (src0.y <<  8) |
        (src0.z << 16) |
        (src0.w << 24);
""")

# Integer packs: component x occupies the least-significant bits and each
# following component is shifted up by the component width.
unop_horiz("pack_32_4x8", 1, tuint32, 4, tuint8,
           "dst.x = src0.x | ((uint32_t)src0.y << 8) | ((uint32_t)src0.z << 16) | ((uint32_t)src0.w << 24);")

unop_horiz("pack_32_2x16", 1, tuint32, 2, tuint16,
           "dst.x = src0.x | ((uint32_t)src0.y << 16);")

unop_horiz("pack_64_2x32", 1, tuint64, 2, tuint32,
           "dst.x = src0.x | ((uint64_t)src0.y << 32);")

unop_horiz("pack_64_4x16", 1, tuint64, 4, tuint16,
           "dst.x = src0.x | ((uint64_t)src0.y << 16) | ((uint64_t)src0.z << 32) | ((uint64_t)src0.w << 48);")

# Inverse of pack_64_2x32: dst.x receives the low half, dst.y the high half.
unop_horiz("unpack_64_2x32", 2, tuint32, 1, tuint64,
           "dst.x = src0.x; dst.y = src0.x >> 32;")
397
# Inverse of pack_64_4x16: split the single 64-bit source component into four
# 16-bit pieces, least significant first.  Every piece comes from src0.x
# because the source has only one component (src0.w does not exist).
unop_horiz("unpack_64_4x16", 4, tuint16, 1, tuint64,
           "dst.x = src0.x; dst.y = src0.x >> 16; dst.z = src0.x >> 32; dst.w = src0.x >> 48;")
400
# Inverses of pack_32_2x16 / pack_32_4x8: split the single 32-bit source into
# components, least significant first.
unop_horiz("unpack_32_2x16", 2, tuint16, 1, tuint32,
           "dst.x = src0.x; dst.y = src0.x >> 16;")

unop_horiz("unpack_32_4x8", 4, tuint8, 1, tuint32,
           "dst.x = src0.x; dst.y = src0.x >> 8; dst.z = src0.x >> 16; dst.w = src0.x >> 24;")
406
# Like unpack_half_2x16, but each half is decoded with the
# unpack_half_1x16_flush_to_zero() helper.  The low 16 bits go to dst.x and
# the high 16 bits to dst.y; the high half is shifted *down* before the
# uint16_t cast (a left shift would truncate to 0).
unop_horiz("unpack_half_2x16_flush_to_zero", 2, tfloat32, 1, tuint32, """
dst.x = unpack_half_1x16_flush_to_zero((uint16_t)(src0.x & 0xffff));
dst.y = unpack_half_1x16_flush_to_zero((uint16_t)(src0.x >> 16));
""")
411
# Lowered floating point unpacking operations.
# Each *_split_x opcode extracts the low half of src0 and each *_split_y the
# high half, as separate per-component opcodes.

unop_convert("unpack_half_2x16_split_x", tfloat32, tuint32,
             "unpack_half_1x16((uint16_t)(src0 & 0xffff))")
unop_convert("unpack_half_2x16_split_y", tfloat32, tuint32,
             "unpack_half_1x16((uint16_t)(src0 >> 16))")

unop_convert("unpack_half_2x16_split_x_flush_to_zero", tfloat32, tuint32,
             "unpack_half_1x16_flush_to_zero((uint16_t)(src0 & 0xffff))")
unop_convert("unpack_half_2x16_split_y_flush_to_zero", tfloat32, tuint32,
             "unpack_half_1x16_flush_to_zero((uint16_t)(src0 >> 16))")

unop_convert("unpack_32_2x16_split_x", tuint16, tuint32, "src0")
unop_convert("unpack_32_2x16_split_y", tuint16, tuint32, "src0 >> 16")

unop_convert("unpack_64_2x32_split_x", tuint32, tuint64, "src0")
unop_convert("unpack_64_2x32_split_y", tuint32, tuint64, "src0 >> 32")
429
# Bit operations, part of ARB_gpu_shader5.


# Reverse the order of the 32 bits of src0 (bit 0 <-> bit 31, and so on).
unop("bitfield_reverse", tuint32, """
/* we're not winning any awards for speed here, but that's ok */
dst = 0;
for (unsigned bit = 0; bit < 32; bit++)
   dst |= ((src0 >> bit) & 1) << (31 - bit);
""")
# Number of set bits in src0 for any source bit size; the result is 32-bit.
unop_convert("bit_count", tuint32, tuint, """
dst = 0;
for (unsigned bit = 0; bit < bit_size; bit++) {
   if ((src0 >> bit) & 1)
      dst++;
}
""")

# Index of the most significant set bit of src0, or -1 when src0 == 0.
unop_convert("ufind_msb", tint32, tuint, """
dst = -1;
for (int bit = bit_size - 1; bit >= 0; bit--) {
   if ((src0 >> bit) & 1) {
      dst = bit;
      break;
   }
}
""")

# Like ufind_msb but counted from the top: the number of zero bits above the
# most significant set bit, or -1 when src0 == 0.
# NOTE(review): the 0x80000000 test assumes a 32-bit source even though the
# source type is the unsized tuint; confirm other bit sizes never reach this.
unop_convert("ufind_msb_rev", tint32, tuint, """
dst = -1;
for (int bit = 0; bit < bit_size; bit++) {
   if ((src0 << bit) & 0x80000000) {
      dst = bit;
      break;
   }
}
""")

# Count of leading zero bits in a 32-bit value.  When src0 == 0 the loop
# leaves bit == -1, so the result is 32.
unop("uclz", tuint32, """
int bit;
for (bit = bit_size - 1; bit >= 0; bit--) {
   if ((src0 & (1u << bit)) != 0)
      break;
}
dst = (unsigned)(31 - bit);
""")

# Index of the most significant bit that differs from the sign bit, or -1
# when src0 is 0 or -1.
unop("ifind_msb", tint32, """
dst = -1;
for (int bit = 31; bit >= 0; bit--) {
   /* If src0 < 0, we're looking for the first 0 bit.
    * if src0 >= 0, we're looking for the first 1 bit.
    */
   if ((((src0 >> bit) & 1) && (src0 >= 0)) ||
      (!((src0 >> bit) & 1) && (src0 < 0))) {
      dst = bit;
      break;
   }
}
""")

# Reversed counterpart of ifind_msb: bits are scanned from bit 30 downwards,
# so for src0 other than 0 and -1 the result is 30 - ifind_msb(src0).
# NOTE(review): the 0x40000000 mask assumes 32-bit operands although the
# source type is the unsized tint — confirm.
unop_convert("ifind_msb_rev", tint32, tint, """
dst = -1;
if (src0 != 0 && src0 != -1) {
   for (int bit = 0; bit < 31; bit++) {
      /* If src0 < 0, we're looking for the first 0 bit.
       * if src0 >= 0, we're looking for the first 1 bit.
       */
      if ((((src0 << bit) & 0x40000000) && (src0 >= 0)) ||
          ((!((src0 << bit) & 0x40000000)) && (src0 < 0))) {
         dst = bit;
         break;
      }
   }
}
""")

# Index of the least significant set bit of src0, or -1 when src0 == 0.
unop_convert("find_lsb", tint32, tint, """
dst = -1;
for (unsigned bit = 0; bit < bit_size; bit++) {
   if ((src0 >> bit) & 1) {
      dst = bit;
      break;
   }
}
""")
515
# AMD_gcn_shader extended instructions
#
# cube_face_coord_amd picks the major axis (the component with the largest
# magnitude), takes the other two components with face-dependent signs, and
# returns coord * (1 / ma) + 0.5 for each.  On magnitude ties the later if
# statements win (z over y over x).
# NOTE(review): ma is 2 * the *signed* major-axis component, not 2 * its
# magnitude; confirm this matches the intended cubema semantics.
unop_horiz("cube_face_coord_amd", 2, tfloat32, 3, tfloat32, """
dst.x = dst.y = 0.0;
float absX = fabsf(src0.x);
float absY = fabsf(src0.y);
float absZ = fabsf(src0.z);

float ma = 0.0;
if (absX >= absY && absX >= absZ) { ma = 2 * src0.x; }
if (absY >= absX && absY >= absZ) { ma = 2 * src0.y; }
if (absZ >= absX && absZ >= absY) { ma = 2 * src0.z; }

if (src0.x >= 0 && absX >= absY && absX >= absZ) { dst.x = -src0.z; dst.y = -src0.y; }
if (src0.x < 0 && absX >= absY && absX >= absZ) { dst.x = src0.z; dst.y = -src0.y; }
if (src0.y >= 0 && absY >= absX && absY >= absZ) { dst.x = src0.x; dst.y = src0.z; }
if (src0.y < 0 && absY >= absX && absY >= absZ) { dst.x = src0.x; dst.y = -src0.z; }
if (src0.z >= 0 && absZ >= absX && absZ >= absY) { dst.x = src0.x; dst.y = -src0.y; }
if (src0.z < 0 && absZ >= absX && absZ >= absY) { dst.x = -src0.x; dst.y = -src0.y; }

dst.x = dst.x * (1.0f / ma) + 0.5f;
dst.y = dst.y * (1.0f / ma) + 0.5f;
""")

# Cube face index as a float: 0/1 for +X/-X, 2/3 for +Y/-Y, 4/5 for +Z/-Z,
# chosen by the major axis.  On magnitude ties the later checks win.
unop_horiz("cube_face_index_amd", 1, tfloat32, 3, tfloat32, """
dst.x = 0.0;
float absX = fabsf(src0.x);
float absY = fabsf(src0.y);
float absZ = fabsf(src0.z);
if (src0.x >= 0 && absX >= absY && absX >= absZ) dst.x = 0;
if (src0.x < 0 && absX >= absY && absX >= absZ) dst.x = 1;
if (src0.y >= 0 && absY >= absX && absY >= absZ) dst.x = 2;
if (src0.y < 0 && absY >= absX && absY >= absZ) dst.x = 3;
if (src0.z >= 0 && absZ >= absX && absZ >= absY) dst.x = 4;
if (src0.z < 0 && absZ >= absX && absZ >= absY) dst.x = 5;
""")

# Sum of vector components
# Registers fsum2, fsum3 and fsum4, each producing a single component.
unop_reduce("fsum", 1, tfloat, tfloat, "{src}", "{src0} + {src1}", "{src}")
554
def binop_convert(name, out_type, in_type, alg_props, const_expr):
   """Register a size-0 binary opcode: two in_type sources -> out_type."""
   num_srcs = 2
   opcode(name, 0, out_type, [0] * num_srcs, [in_type] * num_srcs,
          False, alg_props, const_expr)
558
def binop(name, ty, alg_props, const_expr):
   """Register a size-0 binary opcode whose sources and result are all ty."""
   binop_convert(name, out_type=ty, in_type=ty, alg_props=alg_props,
                 const_expr=const_expr)
561
def binop_compare(name, ty, alg_props, const_expr):
   """Register a comparison of two ty sources with a bool1 result."""
   binop_convert(name, out_type=tbool1, in_type=ty, alg_props=alg_props,
                 const_expr=const_expr)

def binop_compare8(name, ty, alg_props, const_expr):
   """Register a comparison of two ty sources with a bool8 result."""
   binop_convert(name, out_type=tbool8, in_type=ty, alg_props=alg_props,
                 const_expr=const_expr)

def binop_compare16(name, ty, alg_props, const_expr):
   """Register a comparison of two ty sources with a bool16 result."""
   binop_convert(name, out_type=tbool16, in_type=ty, alg_props=alg_props,
                 const_expr=const_expr)

def binop_compare32(name, ty, alg_props, const_expr):
   """Register a comparison of two ty sources with a bool32 result."""
   binop_convert(name, out_type=tbool32, in_type=ty, alg_props=alg_props,
                 const_expr=const_expr)
573
def binop_compare_all_sizes(name, ty, alg_props, const_expr):
   """Register name (bool1), name8, name16 and name32 comparisons, in order."""
   binop_compare(name, ty, alg_props, const_expr)
   sized_variants = (("8", binop_compare8),
                     ("16", binop_compare16),
                     ("32", binop_compare32))
   for bits, register in sized_variants:
      register(name + bits, ty, alg_props, const_expr)
579
def binop_horiz(name, out_size, out_type, src1_size, src1_type, src2_size,
                src2_type, const_expr):
   """Register a binary opcode with explicit per-source vector sizes/types."""
   sizes = [src1_size, src2_size]
   types = [src1_type, src2_type]
   opcode(name, out_size, out_type, sizes, types, False, "", const_expr)
584
def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
                 reduce_expr, final_expr, suffix=""):
   """Register <name><N><suffix> reductions over two N-component sources.

   Widths 2, 4, 8, 16, 3 and 5 are registered in that order.  Each component
   pair is combined with prereduce_expr (placeholders {src0}/{src1}) and
   parenthesised.  The terms are folded with reduce_expr ({src0}/{src1}), and
   final_expr ({src}) wraps the parenthesised result.  Every variant is
   marked _2src_commutative.
   """
   def combine(hi, lo):
      return reduce_expr.format(src0=hi, src1=lo)

   terms = ["(%s)" % prereduce_expr.format(src0="src0." + c, src1="src1." + c)
            for c in "xyzwefghijklmnop"]

   def tree(width):
      # Fold adjacent pairs level by level with the higher-indexed operand
      # first.  For power-of-two widths this builds a balanced tree over
      # terms[0:width].
      level = terms[:width]
      while len(level) > 1:
         level = [combine(level[i + 1], level[i])
                  for i in range(0, len(level), 2)]
      return level[0]

   variants = [(width, tree(width)) for width in (2, 4, 8, 16)]
   variants.append((3, combine(combine(terms[2], terms[1]), terms[0])))
   variants.append((5, combine(terms[4],
                               combine(combine(terms[3], terms[2]),
                                       combine(terms[1], terms[0])))))

   for width, body in variants:
      opcode(name + str(width) + suffix, output_size, output_type,
             [width, width], [src_type, src_type], False, _2src_commutative,
             final_expr.format(src="(%s)" % body))
608
def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
                           reduce_expr, final_expr):
   """Register binop_reduce families for bool1, bool8, bool16 and bool32.

   The bool1 family keeps name; the others replace its first character with
   "b8", "b16" or "b32" (e.g. ball_fequal -> b8all_fequal).
   """
   tail = name[1:]
   families = [
      (name, tbool1),
      ("b8" + tail, tbool8),
      ("b16" + tail, tbool16),
      ("b32" + tail, tbool32),
   ]
   for family_name, bool_type in families:
      binop_reduce(family_name, output_size, bool_type, src_type,
                   prereduce_expr, reduce_expr, final_expr)
619
# fadd/fsub/fmul honour a round-toward-zero execution mode when folding:
# 64-bit operands use the _mesa_double_*_rtz helpers, and narrower operands
# are computed in double and rounded with _mesa_double_to_float_rtz.
binop("fadd", tfloat, _2src_commutative + associative,"""
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_add_rtz(src0, src1);
   else
      dst = _mesa_double_to_float_rtz((double)src0 + (double)src1);
} else {
   dst = src0 + src1;
}
""")
# Adding as uint64_t keeps signed overflow out of the C expression.
binop("iadd", tint, _2src_commutative + associative, "(uint64_t)src0 + (uint64_t)src1")
# Saturating adds/subtracts clamp to the bit_size range instead of wrapping.
# NOTE(review): the overflow tests rely on src0 + src1 wrapping.  If the
# generated C declares 8/16-bit operands with narrow types, integer promotion
# means the comparison never fires for those sizes — confirm.
binop("iadd_sat", tint, _2src_commutative, """
      src1 > 0 ?
         (src0 + src1 < src0 ? u_intN_max(bit_size) : src0 + src1) :
         (src0 < src0 + src1 ? u_intN_min(bit_size) : src0 + src1)
""")
binop("uadd_sat", tuint, _2src_commutative,
      "(src0 + src1) < src0 ? u_uintN_max(sizeof(src0) * 8) : (src0 + src1)")
binop("isub_sat", tint, "", """
      src1 < 0 ?
         (src0 - src1 < src0 ? u_intN_max(bit_size) : src0 - src1) :
         (src0 < src0 - src1 ? u_intN_min(bit_size) : src0 - src1)
""")
binop("usub_sat", tuint, "", "src0 < src1 ? 0 : src0 - src1")

binop("fsub", tfloat, "", """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_sub_rtz(src0, src1);
   else
      dst = _mesa_double_to_float_rtz((double)src0 - (double)src1);
} else {
   dst = src0 - src1;
}
""")
binop("isub", tint, "", "src0 - src1")
# Absolute difference |src0 - src1| as an unsigned value; uabs_isub treats the
# sources as signed and uabs_usub treats them as unsigned.
binop_convert("uabs_isub", tuint, tint, "", """
              src1 > src0 ? (uint64_t) src1 - (uint64_t) src0
                          : (uint64_t) src0 - (uint64_t) src1
""")
binop("uabs_usub", tuint, "", "(src1 > src0) ? (src1 - src0) : (src0 - src1)")

binop("fmul", tfloat, _2src_commutative + associative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_mul_rtz(src0, src1);
   else
      dst = _mesa_double_to_float_rtz((double)src0 * (double)src1);
} else {
   dst = src0 * src1;
}
""")
# low bit_size bits of signed/unsigned integer multiply
binop("imul", tint, _2src_commutative + associative, """
   /* Use 64-bit multiplies to prevent overflow of signed arithmetic */
   dst = (uint64_t)src0 * (uint64_t)src1;
""")
677
# Full 64-bit product of two 32-bit sources (signed and unsigned variants).
binop_convert("imul_2x32_64", tint64, tint32, _2src_commutative,
              "(int64_t)src0 * (int64_t)src1")
binop_convert("umul_2x32_64", tuint64, tuint32, _2src_commutative,
              "(uint64_t)src0 * (uint64_t)src1")

# High bit_size bits of the signed product.  The 64-bit case needs a 128-bit
# product, computed on 32-bit limbs with ubm_mul_u32arr.
binop("imul_high", tint, _2src_commutative, """
if (bit_size == 64) {
   /* We need to do a full 128-bit x 128-bit multiply in order for the sign
    * extension to work properly.  The casts are kind-of annoying but needed
    * to prevent compiler warnings.
    */
   uint32_t src0_u32[4] = {
      src0,
      (int64_t)src0 >> 32,
      (int64_t)src0 >> 63,
      (int64_t)src0 >> 63,
   };
   uint32_t src1_u32[4] = {
      src1,
      (int64_t)src1 >> 32,
      (int64_t)src1 >> 63,
      (int64_t)src1 >> 63,
   };
   uint32_t prod_u32[4];
   ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
   dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
} else {
   /* First, sign-extend to 64-bit, then convert to unsigned to prevent
    * potential overflow of signed multiply */
   dst = ((uint64_t)(int64_t)src0 * (uint64_t)(int64_t)src1) >> bit_size;
}
""")

# High bit_size bits of the unsigned product (128-bit product for 64-bit).
binop("umul_high", tuint, _2src_commutative, """
if (bit_size == 64) {
   /* The casts are kind-of annoying but needed to prevent compiler warnings. */
   uint32_t src0_u32[2] = { src0, (uint64_t)src0 >> 32 };
   uint32_t src1_u32[2] = { src1, (uint64_t)src1 >> 32 };
   uint32_t prod_u32[4];
   ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
   dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
} else {
   dst = ((uint64_t)src0 * (uint64_t)src1) >> bit_size;
}
""")

# Unsigned multiply of the low halves of the sources: the mask keeps
# bit_size / 2 bits, so for these 32-bit operands this is a 16x16 -> 32-bit
# product.
binop("umul_low", tuint32, _2src_commutative, """
uint64_t mask = (1 << (bit_size / 2)) - 1;
dst = ((uint64_t)src0 & mask) * ((uint64_t)src1 & mask);
""")
732
# Multiply 32-bits with low 16-bits.
# src1 is truncated to its low 16 bits (sign-extended for imul_32x16,
# zero-extended for umul_32x16) before the multiply.
binop("imul_32x16", tint32, "", "src0 * (int16_t) src1")
binop("umul_32x16", tuint32, "", "src0 * (uint16_t) src1")

# Integer division by zero folds to 0 rather than evaluating src0 / 0.
# NOTE(review): idiv does not special-case INT_MIN / -1, which overflows in C.
binop("fdiv", tfloat, "", "src0 / src1")
binop("idiv", tint, "", "src1 == 0 ? 0 : (src0 / src1)")
binop("udiv", tuint, "", "src1 == 0 ? 0 : (src0 / src1)")

# returns a boolean representing the carry resulting from the addition of
# the two unsigned arguments.

binop_convert("uadd_carry", tuint, tuint, _2src_commutative, "src0 + src1 < src0")

# returns a boolean representing the borrow resulting from the subtraction
# of the two unsigned arguments.

binop_convert("usub_borrow", tuint, tuint, "", "src0 < src1")
750
# hadd: (a + b) >> 1 (without overflow)
# x + y = x - (x & ~y) + (x & ~y) + y - (~x & y) + (~x & y)
#       =      (x & y) + (x & ~y) +      (x & y) + (~x & y)
#       = 2 *  (x & y) + (x & ~y) +                (~x & y)
#       =     ((x & y) << 1) + (x ^ y)
#
# Since we know that the bottom bit of (x & y) << 1 is zero,
#
# (x + y) >> 1 = (((x & y) << 1) + (x ^ y)) >> 1
#              =   (x & y) +      ((x ^ y)  >> 1)
#
# The right-hand form never computes x + y, so it cannot overflow.
binop("ihadd", tint, _2src_commutative, "(src0 & src1) + ((src0 ^ src1) >> 1)")
binop("uhadd", tuint, _2src_commutative, "(src0 & src1) + ((src0 ^ src1) >> 1)")
763
# rhadd: (a + b + 1) >> 1 (without overflow)
# x + y + 1 = x + (~x & y) - (~x & y) + y + (x & ~y) - (x & ~y) + 1
#           =      (x | y) - (~x & y) +      (x | y) - (x & ~y) + 1
#           = 2 *  (x | y) - ((~x & y) +               (x & ~y)) + 1
#           =     ((x | y) << 1) - (x ^ y) + 1
#
# Since we know that the bottom bit of (x | y) << 1 is zero,
#
# (x + y + 1) >> 1 = (x | y) + ((-(x ^ y) + 1) >> 1)
#                  = (x | y) -  ((x ^ y)      >> 1)
#
# (The last step uses (1 - d) >> 1 == -(d >> 1) for a flooring shift.)
binop("irhadd", tint, _2src_commutative, "(src0 | src1) - ((src0 ^ src1) >> 1)")
binop("urhadd", tuint, _2src_commutative, "(src0 | src1) - ((src0 ^ src1) >> 1)")
776
# Unsigned modulus; a zero divisor folds to 0.
binop("umod", tuint, "", "src1 == 0 ? 0 : src0 % src1")

# For signed integers, there are several different possible definitions of
# "modulus" or "remainder".  We follow the conventions used by LLVM and
# SPIR-V.  The irem opcode implements the standard C/C++ signed "%"
# operation while the imod opcode implements the more mathematical
# "modulus" operation.  For details on the difference, see
#
# http://mathforum.org/library/drmath/view/52343.html
#
# In short: irem's result takes the sign of the dividend (C "%" truncates),
# while imod adds src1 to a non-zero C remainder whenever the operand signs
# differ, so its result takes the sign of the divisor.  Both fold a zero
# divisor to 0.  fmod and frem are the float analogues, rounding the
# quotient with floorf() and truncf() respectively.
# NOTE(review): floorf()/truncf() are single-precision even for 64-bit
# sources, so large 64-bit quotients may lose precision -- confirm intended.

binop("irem", tint, "", "src1 == 0 ? 0 : src0 % src1")
binop("imod", tint, "",
      "src1 == 0 ? 0 : ((src0 % src1 == 0 || (src0 >= 0) == (src1 >= 0)) ?"
      "                 src0 % src1 : src0 % src1 + src1)")
binop("fmod", tfloat, "", "src0 - src1 * floorf(src0 / src1)")
binop("frem", tfloat, "", "src0 - src1 * truncf(src0 / src1)")
793
794#
795# Comparisons
796#
797
798
# these integer-aware comparisons return a boolean (0 or ~0)
#
# The "f" forms compare as floats, "i" as signed integers and "u" as
# unsigned integers.  Only the equality tests are marked commutative.
# fneu uses C "!=", which is true when either operand is NaN (unordered
# not-equal).

binop_compare_all_sizes("flt", tfloat, "", "src0 < src1")
binop_compare_all_sizes("fge", tfloat, "", "src0 >= src1")
binop_compare_all_sizes("feq", tfloat, _2src_commutative, "src0 == src1")
binop_compare_all_sizes("fneu", tfloat, _2src_commutative, "src0 != src1")
binop_compare_all_sizes("ilt", tint, "", "src0 < src1")
binop_compare_all_sizes("ige", tint, "", "src0 >= src1")
binop_compare_all_sizes("ieq", tint, _2src_commutative, "src0 == src1")
binop_compare_all_sizes("ine", tint, _2src_commutative, "src0 != src1")
binop_compare_all_sizes("ult", tuint, "", "src0 < src1")
binop_compare_all_sizes("uge", tuint, "", "src0 >= src1")
811
# integer-aware GLSL-style comparisons that compare floats and ints
#
# Each compares its vector sources component-wise and reduces the results:
# ball_* ANDs the per-component equalities, bany_* ORs the inequalities.

binop_reduce_all_sizes("ball_fequal",  1, tfloat, "{src0} == {src1}",
                       "{src0} && {src1}", "{src}")
binop_reduce_all_sizes("bany_fnequal", 1, tfloat, "{src0} != {src1}",
                       "{src0} || {src1}", "{src}")
binop_reduce_all_sizes("ball_iequal",  1, tint, "{src0} == {src1}",
                       "{src0} && {src1}", "{src}")
binop_reduce_all_sizes("bany_inequal", 1, tint, "{src0} != {src1}",
                       "{src0} || {src1}", "{src}")

# non-integer-aware GLSL-style comparisons that return 0.0 or 1.0
# (same AND/OR reductions as above, with the final boolean mapped to a float)

binop_reduce("fall_equal",  1, tfloat32, tfloat32, "{src0} == {src1}",
             "{src0} && {src1}", "{src} ? 1.0f : 0.0f")
binop_reduce("fany_nequal", 1, tfloat32, tfloat32, "{src0} != {src1}",
             "{src0} || {src1}", "{src} ? 1.0f : 0.0f")
829
# These comparisons for integer-less hardware return 1.0 and 0.0 for true
# and false respectively
# NOTE(review): sge is declared with tfloat while slt/seq/sne use tfloat32;
# confirm that this asymmetry is intentional.

binop("slt", tfloat32, "", "(src0 < src1) ? 1.0f : 0.0f") # Set on Less Than
binop("sge", tfloat, "", "(src0 >= src1) ? 1.0f : 0.0f") # Set on Greater or Equal
binop("seq", tfloat32, _2src_commutative, "(src0 == src1) ? 1.0f : 0.0f") # Set on Equal
binop("sne", tfloat32, _2src_commutative, "(src0 != src1) ? 1.0f : 0.0f") # Set on Not Equal
837
# SPIRV shifts are undefined for shift-operands >= bitsize,
# but SM5 shifts are defined to use only the least significant bits.
# The NIR definition is according to the SM5 specification.
#
# The 32-bit shift count is masked to the bit size of src0.  ishl shifts
# src0 as uint64_t, so left-shifting a negative value is not undefined
# behavior in C.  ishr shifts the signed value (an arithmetic shift on the
# usual compilers), and ushr shifts the unsigned value (a logical shift).
opcode("ishl", 0, tint, [0, 0], [tint, tuint32], False, "",
       "(uint64_t)src0 << (src1 & (sizeof(src0) * 8 - 1))")
opcode("ishr", 0, tint, [0, 0], [tint, tuint32], False, "",
       "src0 >> (src1 & (sizeof(src0) * 8 - 1))")
opcode("ushr", 0, tuint, [0, 0], [tuint, tuint32], False, "",
       "src0 >> (src1 & (sizeof(src0) * 8 - 1))")
847
# Bitwise rotate left (urol) / right (uror) by src1 modulo the bit size.
# The complementary shift uses (-src1 & rotate_mask) so it also stays below
# the bit size.  When the rotation amount is a multiple of the bit size,
# both shifts are 0 and the result is src0.
opcode("urol", 0, tuint, [0, 0], [tuint, tuint32], False, "", """
   uint32_t rotate_mask = sizeof(src0) * 8 - 1;
   dst = (src0 << (src1 & rotate_mask)) |
         (src0 >> (-src1 & rotate_mask));
""")
opcode("uror", 0, tuint, [0, 0], [tuint, tuint32], False, "", """
   uint32_t rotate_mask = sizeof(src0) * 8 - 1;
   dst = (src0 >> (src1 & rotate_mask)) |
         (src0 << (-src1 & rotate_mask));
""")
858
# bitwise logic operators
#
# These are also used as boolean and, or, xor for hardware supporting
# integers.  All three are commutative and associative.


binop("iand", tuint, _2src_commutative + associative, "src0 & src1")
binop("ior", tuint, _2src_commutative + associative, "src0 | src1")
binop("ixor", tuint, _2src_commutative + associative, "src0 ^ src1")
868
869
# Dot products: multiply component-wise, then sum the products.  The
# _replicated variant is declared with an output size of 4.
# NOTE(review): presumably it writes the same scalar result to all four
# components -- confirm in binop_reduce.
binop_reduce("fdot", 1, tfloat, tfloat, "{src0} * {src1}", "{src0} + {src1}",
             "{src}")

binop_reduce("fdot", 4, tfloat, tfloat,
             "{src0} * {src1}", "{src0} + {src1}", "{src}",
             suffix="_replicated")

# fdph: homogeneous dot product, i.e. dot(src0.xyz, src1.xyz) + src1.w.
# fdph_replicated has an output size of 4.
opcode("fdph", 1, tfloat, [3, 4], [tfloat, tfloat], False, "",
       "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")
opcode("fdph_replicated", 4, tfloat, [3, 4], [tfloat, tfloat], False, "",
       "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")
881
# Minimum/maximum.  fmin/fmax fold via the C library functions, which return
# the other operand when exactly one operand is NaN.  The integer forms
# compare as signed (i) or unsigned (u) values.
binop("fmin", tfloat, _2src_commutative + associative, "fmin(src0, src1)")
binop("imin", tint, _2src_commutative + associative, "src1 > src0 ? src0 : src1")
binop("umin", tuint, _2src_commutative + associative, "src1 > src0 ? src0 : src1")
binop("fmax", tfloat, _2src_commutative + associative, "fmax(src0, src1)")
binop("imax", tint, _2src_commutative + associative, "src1 > src0 ? src1 : src0")
binop("umax", tuint, _2src_commutative + associative, "src1 > src0 ? src1 : src0")
888
# Raise src0 to the power src1.  64-bit operands are folded with the
# double-precision pow(); all narrower float sizes use single-precision
# powf().  (Selecting powf() for 64-bit would silently drop the result to
# single precision.)
binop("fpow", tfloat, "", "bit_size == 64 ? pow(src0, src1) : powf(src0, src1)")
890
# Convert the first component of each source to half precision with
# pack_half_1x16 and pack them into one 32-bit value: src0 in the low 16
# bits, src1 in the high 16 bits.
binop_horiz("pack_half_2x16_split", 1, tuint32, 1, tfloat32, 1, tfloat32,
            "pack_half_1x16(src0.x) | (pack_half_1x16(src1.x) << 16)")

# Pack two 32-bit values into one 64-bit value (src0 is the low half).
binop_convert("pack_64_2x32_split", tuint64, tuint32, "",
              "src0 | ((uint64_t)src1 << 32)")

# Pack two 16-bit values into one 32-bit value (src0 is the low half).
binop_convert("pack_32_2x16_split", tuint32, tuint16, "",
              "src0 | ((uint32_t)src1 << 16)")

# Pack four 8-bit values into one 32-bit value, src0 in the lowest byte.
opcode("pack_32_4x8_split", 0, tuint32, [0, 0, 0, 0], [tuint8, tuint8, tuint8, tuint8],
       False, "",
       "src0 | ((uint32_t)src1 << 8) | ((uint32_t)src2 << 16) | ((uint32_t)src3 << 24)")
903
# bfm implements the behavior of the first operation of the SM5 "bfi" assembly
# and that of the "bfi1" i965 instruction. That is, the bits and offset values
# are from the low five bits of src0 and src1, respectively.
# The result is a mask of `bits` consecutive one bits starting at bit `offset`
# (an empty mask when bits == 0).
binop_convert("bfm", tuint32, tint32, "", """
int bits = src0 & 0x1F;
int offset = src1 & 0x1F;
dst = ((1u << bits) - 1) << offset;
""")

# ldexp: src0 * 2^src1, folded with ldexp() for 64-bit and ldexpf() otherwise.
# NOTE(review): isnormal() is also false for zero, infinity and NaN, so those
# results are replaced with a zero carrying src0's sign as well, not only
# denormals.
opcode("ldexp", 0, tfloat, [0, 0], [tfloat, tint32], False, "", """
dst = (bit_size == 64) ? ldexp(src0, src1) : ldexpf(src0, src1);
/* flush denormals to zero. */
if (!isnormal(dst))
   dst = copysignf(0.0f, src0);
""")
919
# Combines the first component of each input to make a 2-component vector.

binop_horiz("vec2", 2, tuint, 1, tuint, 1, tuint, """
dst.x = src0.x;
dst.y = src1.x;
""")

# Byte extraction: byte number src1 of src0, zero-extended (u8) or
# sign-extended (i8).
binop("extract_u8", tuint, "", "(uint8_t)(src0 >> (src1 * 8))")
binop("extract_i8", tint, "", "(int8_t)(src0 >> (src1 * 8))")

# Word extraction: 16-bit word number src1 of src0, zero-extended (u16) or
# sign-extended (i16).
binop("extract_u16", tuint, "", "(uint16_t)(src0 >> (src1 * 16))")
binop("extract_i16", tint, "", "(int16_t)(src0 >> (src1 * 16))")

# Byte/word insertion: the low byte/word of src0 placed at byte/word index
# src1, with every other bit zero.
binop("insert_u8", tuint, "", "(src0 & 0xff) << (src1 * 8)")
binop("insert_u16", tuint, "", "(src0 & 0xffff) << (src1 * 16)")
938
939
def triop(name, ty, alg_props, const_expr):
   """Register a three-source opcode whose result and sources all use ``ty``.

   Output and source sizes are all 0 and the opcode is never a conversion;
   ``alg_props`` and ``const_expr`` are forwarded to opcode() unchanged.
   """
   opcode(name, 0, ty, [0] * 3, [ty] * 3, False, alg_props, const_expr)
def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr):
   """Register a three-source opcode with explicit output/source sizes.

   The result and all three sources are typed ``tuint``; the opcode has no
   algebraic properties and is not a conversion.
   """
   src_sizes = [src1_size, src2_size, src3_size]
   src_types = [tuint] * len(src_sizes)
   opcode(name, output_size, tuint, src_sizes, src_types, False, "",
          const_expr)
946
# Fused multiply-add: src0 * src1 + src2.  When nir_is_rounding_mode_rtz()
# reports round-toward-zero for this bit size, the _mesa_*_fma_rtz helpers are
# used (sizes other than 32/64 are computed in double and converted back with
# RTZ).  Otherwise the C library fmaf() (32-bit) or fma() (other sizes) is used.
triop("ffma", tfloat, _2src_commutative, """
if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
   if (bit_size == 64)
      dst = _mesa_double_fma_rtz(src0, src1, src2);
   else if (bit_size == 32)
      dst = _mesa_float_fma_rtz(src0, src1, src2);
   else
      dst = _mesa_double_to_float_rtz(_mesa_double_fma_rtz(src0, src1, src2));
} else {
   if (bit_size == 32)
      dst = fmaf(src0, src1, src2);
   else
      dst = fma(src0, src1, src2);
}
""")

# Linear interpolation from src0 to src1 by the factor src2.
triop("flrp", tfloat, "", "src0 * (1 - src2) + src1 * src2")

# Ternary addition
triop("iadd3", tint, _2src_commutative + associative, "src0 + src1 + src2")
967
# Conditional Select
#
# A vector conditional select instruction (like ?:, but operating per-
# component on vectors). There are two versions, one for floating point
# bools (0.0 vs 1.0) and one for integer bools (0 vs ~0).

triop("fcsel", tfloat32, "", "(src0 != 0.0f) ? src1 : src2")

# Select on a boolean condition (1-, 8-, 16- and 32-bit booleans
# respectively).  The data sources and the result use the unsized tuint type.
opcode("bcsel", 0, tuint, [0, 0, 0],
       [tbool1, tuint, tuint], False, "", "src0 ? src1 : src2")
opcode("b8csel", 0, tuint, [0, 0, 0],
       [tbool8, tuint, tuint], False, "", "src0 ? src1 : src2")
opcode("b16csel", 0, tuint, [0, 0, 0],
       [tbool16, tuint, tuint], False, "", "src0 ? src1 : src2")
opcode("b32csel", 0, tuint, [0, 0, 0],
       [tbool32, tuint, tuint], False, "", "src0 ? src1 : src2")

# Select src1 when the int32 condition is > 0 (gt) or >= 0 (ge), else src2.
# The condition is compared against an integer zero; comparing an int32
# against 0.0f only forced a pointless int-to-float conversion (the result
# is the same for every int32 value).
triop("i32csel_gt", tint32, "", "(src0 > 0) ? src1 : src2")
triop("i32csel_ge", tint32, "", "(src0 >= 0) ? src1 : src2")

# Float counterparts: select src1 when the float condition is > 0.0 (gt) or
# >= 0.0 (ge), else src2.
triop("fcsel_gt", tfloat32, "", "(src0 > 0.0f) ? src1 : src2")
triop("fcsel_ge", tfloat32, "", "(src0 >= 0.0f) ? src1 : src2")
990
# SM5 bfi assembly
#
# Insert the low bits of src1 into src2 at the positions selected by the
# bitmask src0: src1 is shifted left to the mask's lowest set bit, then the
# masked bits of src2 are replaced.  A zero mask returns src2 unchanged
# (the zero check also keeps the shift loop from running forever).
triop("bfi", tuint32, "", """
unsigned mask = src0, insert = src1, base = src2;
if (mask == 0) {
   dst = base;
} else {
   unsigned tmp = mask;
   while (!(tmp & 1)) {
      tmp >>= 1;
      insert <<= 1;
   }
   dst = (base & ~mask) | (insert & mask);
}
""")
1005
1006
# Per-bit select: bits of src1 where src0 is set, bits of src2 elsewhere.
triop("bitfield_select", tuint, "", "(src0 & src1) | (~src0 & src2)")

# SM5 ubfe/ibfe assembly: only the 5 least significant bits of offset and bits are used.
# Extract `bits` bits of src0 starting at bit `offset`.  ubfe zero-extends the
# field; ibfe sign-extends it through the right shift of a signed int.
# bits == 0 yields 0, and when offset + bits >= 32 the result is simply
# src0 >> offset.
opcode("ubfe", 0, tuint32,
       [0, 0, 0], [tuint32, tuint32, tuint32], False, "", """
unsigned base = src0;
unsigned offset = src1 & 0x1F;
unsigned bits = src2 & 0x1F;
if (bits == 0) {
   dst = 0;
} else if (offset + bits < 32) {
   dst = (base << (32 - bits - offset)) >> (32 - bits);
} else {
   dst = base >> offset;
}
""")
opcode("ibfe", 0, tint32,
       [0, 0, 0], [tint32, tuint32, tuint32], False, "", """
int base = src0;
unsigned offset = src1 & 0x1F;
unsigned bits = src2 & 0x1F;
if (bits == 0) {
   dst = 0;
} else if (offset + bits < 32) {
   dst = (base << (32 - bits - offset)) >> (32 - bits);
} else {
   dst = base >> offset;
}
""")
1036
# GLSL bitfieldExtract()
#
# Unlike ubfe/ibfe, offset and bits are signed and are not masked.  Negative
# values or offset + bits > 32 are undefined per the spec and fold to 0, and
# bits == 0 also yields 0.  The unsigned form zero-extends the field and the
# signed form sign-extends it.
opcode("ubitfield_extract", 0, tuint32,
       [0, 0, 0], [tuint32, tint32, tint32], False, "", """
unsigned base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
   dst = 0;
} else if (bits < 0 || offset < 0 || offset + bits > 32) {
   dst = 0; /* undefined per the spec */
} else {
   dst = (base >> offset) & ((1ull << bits) - 1);
}
""")
opcode("ibitfield_extract", 0, tint32,
       [0, 0, 0], [tint32, tint32, tint32], False, "", """
int base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
   dst = 0;
} else if (offset < 0 || bits < 0 || offset + bits > 32) {
   dst = 0;
} else {
   dst = (base << (32 - offset - bits)) >> (32 - bits); /* use sign-extending shift */
}
""")
1062
# Sum of absolute differences with accumulation.
# (Equivalent to AMD's v_sad_u8 instruction.)
# The first two sources contain packed 8-bit unsigned integers, the instruction
# will calculate the absolute difference of these, and then add them together.
# There is also a third source which is a 32-bit unsigned integer and added to the result.
triop_horiz("sad_u8x4", 1, 1, 1, 1, """
uint8_t s0_b0 = (src0.x & 0x000000ff) >> 0;
uint8_t s0_b1 = (src0.x & 0x0000ff00) >> 8;
uint8_t s0_b2 = (src0.x & 0x00ff0000) >> 16;
uint8_t s0_b3 = (src0.x & 0xff000000) >> 24;

uint8_t s1_b0 = (src1.x & 0x000000ff) >> 0;
uint8_t s1_b1 = (src1.x & 0x0000ff00) >> 8;
uint8_t s1_b2 = (src1.x & 0x00ff0000) >> 16;
uint8_t s1_b3 = (src1.x & 0xff000000) >> 24;

dst.x = src2.x +
        (s0_b0 > s1_b0 ? (s0_b0 - s1_b0) : (s1_b0 - s0_b0)) +
        (s0_b1 > s1_b1 ? (s0_b1 - s1_b1) : (s1_b1 - s0_b1)) +
        (s0_b2 > s1_b2 ? (s0_b2 - s1_b2) : (s1_b2 - s0_b2)) +
        (s0_b3 > s1_b3 ? (s0_b3 - s1_b3) : (s1_b3 - s0_b3));
""")

# Combines the first component of each input to make a 3-component vector.

triop_horiz("vec3", 3, 1, 1, 1, """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
""")
1093
def quadop_horiz(name, output_size, src1_size, src2_size, src3_size,
                 src4_size, const_expr):
   """Register a four-source opcode with explicit output/source sizes.

   The result and all four sources are typed ``tuint``; the opcode has no
   algebraic properties and is not a conversion.
   """
   src_sizes = [src1_size, src2_size, src3_size, src4_size]
   src_types = [tuint] * len(src_sizes)
   opcode(name, output_size, tuint, src_sizes, src_types, False, "",
          const_expr)
1100
# Replace `bits` bits of src0, starting at bit `offset`, with the low bits of
# src1 (presumably the GLSL bitfieldInsert() counterpart of
# ubitfield_extract).  bits == 0 returns src0 unchanged; negative offset/bits
# or offset + bits > 32 fold to 0.  The mask is built in 64 bits so that
# bits == 32 does not shift a 32-bit value by its full width.
opcode("bitfield_insert", 0, tuint32, [0, 0, 0, 0],
       [tuint32, tuint32, tint32, tint32], False, "", """
unsigned base = src0, insert = src1;
int offset = src2, bits = src3;
if (bits == 0) {
   dst = base;
} else if (offset < 0 || bits < 0 || bits + offset > 32) {
   dst = 0;
} else {
   unsigned mask = ((1ull << bits) - 1) << offset;
   dst = (base & ~mask) | ((insert << offset) & mask);
}
""")
1114
# Combine the first component of each input into a 4-, 5-, 8- or
# 16-component vector.  Components past w are addressed as e, f, g, ... p.
quadop_horiz("vec4", 4, 1, 1, 1, 1, """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
""")

opcode("vec5", 5, tuint,
       [1] * 5, [tuint] * 5,
       False, "", """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
dst.e = src4.x;
""")

opcode("vec8", 8, tuint,
       [1] * 8, [tuint] * 8,
       False, "", """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
dst.e = src4.x;
dst.f = src5.x;
dst.g = src6.x;
dst.h = src7.x;
""")

opcode("vec16", 16, tuint,
       [1] * 16, [tuint] * 16,
       False, "", """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
dst.e = src4.x;
dst.f = src5.x;
dst.g = src6.x;
dst.h = src7.x;
dst.i = src8.x;
dst.j = src9.x;
dst.k = src10.x;
dst.l = src11.x;
dst.m = src12.x;
dst.n = src13.x;
dst.o = src14.x;
dst.p = src15.x;
""")
1165
# An integer multiply instruction for address calculation.  This is
# similar to imul, except that the results are undefined in case of
# overflow.  Overflow is defined according to the size of the variable
# being dereferenced.
#
# This relaxed definition, compared to imul, allows an optimization
# pass to propagate bounds (ie, from a load/store intrinsic) to the
# sources, such that lower precision integer multiplies can be used.
# This is useful on hw that has 24b or perhaps 16b integer multiply
# instructions.
binop("amul", tint, _2src_commutative + associative, "src0 * src1")

# ir3-specific instruction that maps directly to mul-add shift high mix,
# (IMADSH_MIX16 i.e. ah * bl << 16 + c). It is used for lowering integer
# multiplication (imul) on Freedreno backend.
opcode("imadsh_mix16", 0, tint32,
       [0, 0, 0], [tint32, tint32, tint32], False, "", """
dst = ((((src0 & 0xffff0000) >> 16) * (src1 & 0x0000ffff)) << 16) + src2;
""")

# ir3-specific instruction that maps directly to ir3 mad.s24.
#
# 24b multiply into 32b result (with sign extension) plus 32b int.
# The << 8 / >> 8 pair sign-extends the low 24 bits of each source.
triop("imad24_ir3", tint32, _2src_commutative,
      "(((int32_t)src0 << 8) >> 8) * (((int32_t)src1 << 8) >> 8) + src2")
1191
# r600-specific instruction that evaluates unnormalized cube texture coordinates
# and face index
# The actual texture coordinates are evaluated from this according to
#    dst.yx / abs(dst.z) + 1.5
# dst.z is twice the major-axis coordinate, and dst.w receives the face index
# 0..5 for +X, -X, +Y, -Y, +Z, -Z.
# NOTE(review): dst.w is not initialized before the face tests; if no branch
# matches (e.g. NaN input) it is left unset.
unop_horiz("cube_r600", 4, tfloat32, 3, tfloat32, """
   dst.x = dst.y = dst.z = 0.0;
   float absX = fabsf(src0.x);
   float absY = fabsf(src0.y);
   float absZ = fabsf(src0.z);

   if (absX >= absY && absX >= absZ) { dst.z = 2 * src0.x; }
   if (absY >= absX && absY >= absZ) { dst.z = 2 * src0.y; }
   if (absZ >= absX && absZ >= absY) { dst.z = 2 * src0.z; }

   if (src0.x >= 0 && absX >= absY && absX >= absZ) {
      dst.y = -src0.z; dst.x = -src0.y; dst.w = 0;
   }
   if (src0.x < 0 && absX >= absY && absX >= absZ) {
      dst.y = src0.z; dst.x = -src0.y; dst.w = 1;
   }
   if (src0.y >= 0 && absY >= absX && absY >= absZ) {
      dst.y = src0.x; dst.x = src0.z; dst.w = 2;
   }
   if (src0.y < 0 && absY >= absX && absY >= absZ) {
      dst.y = src0.x; dst.x = -src0.z; dst.w = 3;
   }
   if (src0.z >= 0 && absZ >= absX && absZ >= absY) {
      dst.y = src0.x; dst.x = -src0.y; dst.w = 4;
   }
   if (src0.z < 0 && absZ >= absX && absZ >= absY) {
      dst.y = -src0.x; dst.x = -src0.y; dst.w = 5;
   }
""")

# r600 specific sin and cos
# these trigonometric functions need some lowering because the supported
# input values are expected to be normalized by dividing by (2 * pi)
unop("fsin_r600", tfloat32, "sinf(6.2831853 * src0)")
unop("fcos_r600", tfloat32, "cosf(6.2831853 * src0)")

# AGX specific sin with input expressed in quadrants. Used in the lowering for
# fsin/fcos. This corresponds to a sequence of 3 ALU ops in the backend (where
# the angle is further decomposed by quadrant, sinc is computed, and the angle
# is multiplied back for sin). Lowering fsin/fcos to fsin_agx requires some
# additional ALU that NIR may be able to optimize.
unop("fsin_agx", tfloat, "sinf(src0 * (6.2831853/4.0))")
1238
# 24b multiply into 32b result (with sign extension).
# The << 8 / >> 8 pairs keep the low 24 bits of each source; the signed
# forms sign-extend them and the unsigned forms zero-extend them.
binop("imul24", tint32, _2src_commutative + associative,
      "(((int32_t)src0 << 8) >> 8) * (((int32_t)src1 << 8) >> 8)")

# unsigned 24b multiply into 32b result plus 32b int
triop("umad24", tuint32, _2src_commutative,
      "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8) + src2")

# unsigned 24b multiply into 32b result uint
# NOTE(review): umul24 is declared tint32 while umad24/umul24_relaxed use
# tuint32; confirm that this is intentional.
binop("umul24", tint32, _2src_commutative + associative,
      "(((uint32_t)src0 << 8) >> 8) * (((uint32_t)src1 << 8) >> 8)")

# relaxed versions of the above, which assume input is in the 24bit range (no clamping)
binop("imul24_relaxed", tint32, _2src_commutative + associative, "src0 * src1")
triop("umad24_relaxed", tuint32, _2src_commutative, "src0 * src1 + src2")
binop("umul24_relaxed", tuint32, _2src_commutative + associative, "src0 * src1")

# Float classification via C isnormal()/isfinite().  fisfinite32 returns its
# result as an int32 rather than a 1-bit boolean.
unop_convert("fisnormal", tbool1, tfloat, "isnormal(src0)")
unop_convert("fisfinite", tbool1, tfloat, "isfinite(src0)")
unop_convert("fisfinite32", tint32, tfloat, "isfinite(src0)")
1259
# vc4-specific opcodes
#
# Each treats src0/src1 as four packed 8-bit unsigned channels and processes
# them channel by channel.

# Saturated vector add for 4 8bit ints (each channel clamps at 0xff).
binop("usadd_4x8_vc4", tint32, _2src_commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   dst |= MIN2(((src0 >> i) & 0xff) + ((src1 >> i) & 0xff), 0xff) << i;
}
""")

# Saturated vector subtract for 4 8bit ints (channels where src0 <= src1
# become 0).
binop("ussub_4x8_vc4", tint32, "", """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   int src0_chan = (src0 >> i) & 0xff;
   int src1_chan = (src1 >> i) & 0xff;
   if (src0_chan > src1_chan)
      dst |= (src0_chan - src1_chan) << i;
}
""")

# vector min for 4 8bit ints.
binop("umin_4x8_vc4", tint32, _2src_commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   dst |= MIN2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
}
""")

# vector max for 4 8bit ints.
binop("umax_4x8_vc4", tint32, _2src_commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   dst |= MAX2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
}
""")

# unorm multiply: (a * b) / 255, truncated, per channel.
binop("umul_unorm_4x8_vc4", tint32, _2src_commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
   int src0_chan = (src0 >> i) & 0xff;
   int src1_chan = (src1 >> i) & 0xff;
   dst |= ((src0_chan * src1_chan) / 255) << i;
}
""")
1306
# Mali-specific opcodes
# fsat_signed_mali clamps to [-1.0, 1.0]; fclamp_pos_mali clamps below at 0.0.
unop("fsat_signed_mali", tfloat, ("fmin(fmax(src0, -1.0), 1.0)"))
unop("fclamp_pos_mali", tfloat, ("fmax(src0, 0.0)"))

# Magnitude equal to fddx/y, sign undefined. Derivative of a constant is zero.
unop("fddx_must_abs_mali", tfloat, "0.0")
unop("fddy_must_abs_mali", tfloat, "0.0")

# DXIL specific double [un]pack
# DXIL doesn't support generic [un]pack instructions, so we want those
# lowered to bit ops. HLSL doesn't support 64bit bitcasts to/from
# double, only [un]pack. Technically DXIL does, but considering they
# can't be generated from HLSL, we want to match what would be coming from DXC.
# This is essentially just the standard [un]pack, except that it doesn't get
# lowered so we can handle it in the backend and turn it into MakeDouble/SplitDouble
unop_horiz("pack_double_2x32_dxil", 1, tuint64, 2, tuint32,
           "dst.x = src0.x | ((uint64_t)src0.y << 32);")
unop_horiz("unpack_double_2x32_dxil", 2, tuint32, 1, tuint64,
           "dst.x = src0.x; dst.y = src0.x >> 32;")
1326
# src0 and src1 are i8vec4 packed in an int32, and src2 is an int32.  The int8
# components are sign-extended to 32-bits, and a dot-product is performed on
# the resulting vectors.  src2 is added to the result of the dot-product.
# Byte 0 (the least significant byte) is the x component.
opcode("sdot_4x8_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const int32_t v0x = (int8_t)(src0      );
   const int32_t v0y = (int8_t)(src0 >>  8);
   const int32_t v0z = (int8_t)(src0 >> 16);
   const int32_t v0w = (int8_t)(src0 >> 24);
   const int32_t v1x = (int8_t)(src1      );
   const int32_t v1y = (int8_t)(src1 >>  8);
   const int32_t v1z = (int8_t)(src1 >> 16);
   const int32_t v1w = (int8_t)(src1 >> 24);

   dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;
""")

# Like sdot_4x8_iadd, but unsigned: the components are zero-extended.
opcode("udot_4x8_uadd", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32],
       False, _2src_commutative, """
   const uint32_t v0x = (uint8_t)(src0      );
   const uint32_t v0y = (uint8_t)(src0 >>  8);
   const uint32_t v0z = (uint8_t)(src0 >> 16);
   const uint32_t v0w = (uint8_t)(src0 >> 24);
   const uint32_t v1x = (uint8_t)(src1      );
   const uint32_t v1y = (uint8_t)(src1 >>  8);
   const uint32_t v1z = (uint8_t)(src1 >> 16);
   const uint32_t v1w = (uint8_t)(src1 >> 24);

   dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;
""")
1358
# src0 is i8vec4 packed in an int32, src1 is u8vec4 packed in an int32, and
# src2 is an int32.  The 8-bit components are extended to 32-bits, and a
# dot-product is performed on the resulting vectors.  src2 is added to the
# result of the dot-product.
#
# NOTE: Unlike many of the other dp4a opcodes, the mixed signs of source 0
# and source 1 mean that this opcode is not 2-source commutative.
opcode("sudot_4x8_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, "", """
   const int32_t v0x = (int8_t)(src0      );
   const int32_t v0y = (int8_t)(src0 >>  8);
   const int32_t v0z = (int8_t)(src0 >> 16);
   const int32_t v0w = (int8_t)(src0 >> 24);
   const uint32_t v1x = (uint8_t)(src1      );
   const uint32_t v1y = (uint8_t)(src1 >>  8);
   const uint32_t v1z = (uint8_t)(src1 >> 16);
   const uint32_t v1w = (uint8_t)(src1 >> 24);

   dst = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;
""")
1379
# Like sdot_4x8_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
# The sum is accumulated in 64 bits so the clamp sees the unwrapped value.
opcode("sdot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const int64_t v0x = (int8_t)(src0      );
   const int64_t v0y = (int8_t)(src0 >>  8);
   const int64_t v0z = (int8_t)(src0 >> 16);
   const int64_t v0w = (int8_t)(src0 >> 24);
   const int64_t v1x = (int8_t)(src1      );
   const int64_t v1y = (int8_t)(src1 >>  8);
   const int64_t v1z = (int8_t)(src1 >> 16);
   const int64_t v1w = (int8_t)(src1 >> 24);

   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;

   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")

# Like udot_4x8_uadd, but the result is clamped to the range [0, 0xffffffff].
# NOTE(review): the output and src2 are declared tint32 although the
# arithmetic is unsigned; confirm the declared types are intended.
opcode("udot_4x8_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const uint64_t v0x = (uint8_t)(src0      );
   const uint64_t v0y = (uint8_t)(src0 >>  8);
   const uint64_t v0z = (uint8_t)(src0 >> 16);
   const uint64_t v0w = (uint8_t)(src0 >> 24);
   const uint64_t v1x = (uint8_t)(src1      );
   const uint64_t v1y = (uint8_t)(src1 >>  8);
   const uint64_t v1z = (uint8_t)(src1 >> 16);
   const uint64_t v1w = (uint8_t)(src1 >> 24);

   const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;

   dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp;
""")

# Like sudot_4x8_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
#
# NOTE: Unlike many of the other dp4a opcodes, the mixed signs of source 0
# and source 1 mean that this opcode is not 2-source commutative.
opcode("sudot_4x8_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, "", """
   const int64_t v0x = (int8_t)(src0      );
   const int64_t v0y = (int8_t)(src0 >>  8);
   const int64_t v0z = (int8_t)(src0 >> 16);
   const int64_t v0w = (int8_t)(src0 >> 24);
   const uint64_t v1x = (uint8_t)(src1      );
   const uint64_t v1y = (uint8_t)(src1 >>  8);
   const uint64_t v1z = (uint8_t)(src1 >> 16);
   const uint64_t v1w = (uint8_t)(src1 >> 24);

   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + (v0z * v1z) + (v0w * v1w) + src2;

   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")
1433
# src0 and src1 are i16vec2 packed in an int32, and src2 is an int32.  The int16
# components are sign-extended to 32-bits, and a dot-product is performed on
# the resulting vectors.  src2 is added to the result of the dot-product.
# The low 16 bits hold the x component.
opcode("sdot_2x16_iadd", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const int32_t v0x = (int16_t)(src0      );
   const int32_t v0y = (int16_t)(src0 >> 16);
   const int32_t v1x = (int16_t)(src1      );
   const int32_t v1y = (int16_t)(src1 >> 16);

   dst = (v0x * v1x) + (v0y * v1y) + src2;
""")

# Like sdot_2x16_iadd, but unsigned: the components are zero-extended.
opcode("udot_2x16_uadd", 0, tuint32, [0, 0, 0], [tuint32, tuint32, tuint32],
       False, _2src_commutative, """
   const uint32_t v0x = (uint16_t)(src0      );
   const uint32_t v0y = (uint16_t)(src0 >> 16);
   const uint32_t v1x = (uint16_t)(src1      );
   const uint32_t v1y = (uint16_t)(src1 >> 16);

   dst = (v0x * v1x) + (v0y * v1y) + src2;
""")
1457
# Like sdot_2x16_iadd, but the result is clamped to the range [-0x80000000, 0x7fffffff].
# The sum is accumulated in 64 bits so the clamp sees the unwrapped value.
opcode("sdot_2x16_iadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const int64_t v0x = (int16_t)(src0      );
   const int64_t v0y = (int16_t)(src0 >> 16);
   const int64_t v1x = (int16_t)(src1      );
   const int64_t v1y = (int16_t)(src1 >> 16);

   const int64_t tmp = (v0x * v1x) + (v0y * v1y) + src2;

   dst = tmp >= INT32_MAX ? INT32_MAX : (tmp <= INT32_MIN ? INT32_MIN : tmp);
""")

# Like udot_2x16_uadd, but the result is clamped to the range [0, 0xffffffff].
# NOTE(review): the output and src2 are declared tint32 although the
# arithmetic is unsigned; confirm the declared types are intended.
opcode("udot_2x16_uadd_sat", 0, tint32, [0, 0, 0], [tuint32, tuint32, tint32],
       False, _2src_commutative, """
   const uint64_t v0x = (uint16_t)(src0      );
   const uint64_t v0y = (uint16_t)(src0 >> 16);
   const uint64_t v1x = (uint16_t)(src1      );
   const uint64_t v1y = (uint16_t)(src1 >> 16);

   const uint64_t tmp = (v0x * v1x) + (v0y * v1y) + src2;

   dst = tmp >= UINT32_MAX ? UINT32_MAX : tmp;
""")
1483