1#
2# Copyright (C) 2014 Intel Corporation
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22#
23# Authors:
24#    Jason Ekstrand (jason@jlekstrand.net)
25
26import ast
27from collections import defaultdict
28import itertools
29import struct
30import sys
31import mako.template
32import re
33import traceback
34
35from nir_opcodes import opcodes, type_sizes
36
37# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
38nir_search_max_comm_ops = 8
39
40# These opcodes are only employed by nir_search.  This provides a mapping from
41# opcode to destination type.
42conv_opcode_types = {
43    'i2f' : 'float',
44    'u2f' : 'float',
45    'f2f' : 'float',
46    'f2u' : 'uint',
47    'f2i' : 'int',
48    'u2u' : 'uint',
49    'i2i' : 'int',
50    'b2f' : 'float',
51    'b2i' : 'int',
52    'i2b' : 'bool',
53    'f2b' : 'bool',
54}
55
def get_c_opcode(op):
   if op in conv_opcode_types:
      return 'nir_search_op_' + op
   else:
      return 'nir_op_' + op
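# A quick sketch of the mapping above (illustrative only; 'fadd' stands in for
# any ordinary NIR opcode):
#   get_c_opcode('i2f')  -> 'nir_search_op_i2f'  (generic, size-less conversion)
#   get_c_opcode('fadd') -> 'nir_op_fadd'        (regular opcode from nir_opcodes)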
61
62_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
63
64def type_bits(type_str):
65   m = _type_re.match(type_str)
66   assert m.group('type')
67
68   if m.group('bits') is None:
69      return 0
70   else:
71      return int(m.group('bits'))
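# For example, with the type strings used by nir_opcodes:
#   type_bits('int32') -> 32
#   type_bits('float') -> 0   (no explicit size, i.e. unsized)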
72
73# Represents a set of variables, each with a unique id
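#
# A rough usage sketch (names here are only illustrative):
#   varset = VarSet()
#   varset['a']    # -> 0 (first use allocates the next id)
#   varset['b']    # -> 1
#   varset['a']    # -> 0 (repeated uses return the same id)
#   varset.lock()  # after this, any unknown name trips the assertion below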
74class VarSet(object):
75   def __init__(self):
76      self.names = {}
77      self.ids = itertools.count()
      self.immutable = False
79
80   def __getitem__(self, name):
81      if name not in self.names:
82         assert not self.immutable, "Unknown replacement variable: " + name
83         self.names[name] = next(self.ids)
84
85      return self.names[name]
86
87   def lock(self):
88      self.immutable = True
89
90class Value(object):
91   @staticmethod
92   def create(val, name_base, varset):
93      if isinstance(val, bytes):
94         val = val.decode('utf-8')
95
96      if isinstance(val, tuple):
97         return Expression(val, name_base, varset)
98      elif isinstance(val, Expression):
99         return val
100      elif isinstance(val, str):
101         return Variable(val, name_base, varset)
102      elif isinstance(val, (bool, float, int)):
103         return Constant(val, name_base)
104
105   def __init__(self, val, name, type_str):
106      self.in_val = str(val)
107      self.name = name
108      self.type_str = type_str
109
110   def __str__(self):
111      return self.in_val
112
113   def get_bit_size(self):
114      """Get the physical bit-size that has been chosen for this value, or if
115      there is none, the canonical value which currently represents this
116      bit-size class. Variables will be preferred, i.e. if there are any
117      variables in the equivalence class, the canonical value will be a
118      variable. We do this since we'll need to know which variable each value
119      is equivalent to when constructing the replacement expression. This is
120      the "find" part of the union-find algorithm.
121      """
122      bit_size = self
123
124      while isinstance(bit_size, Value):
125         if bit_size._bit_size is None:
126            break
127         bit_size = bit_size._bit_size
128
129      if bit_size is not self:
130         self._bit_size = bit_size
131      return bit_size
132
133   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
135      before calling this, or just "other" if it's a concrete bit-size. This is
136      the "union" part of the union-find algorithm.
137      """
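      # As a sketch: if x._bit_size points at y and y._bit_size is the
      # concrete size 32, then x.get_bit_size() returns 32 (path-compressing
      # x._bit_size to 32), and x.set_bit_size(y) is then a no-op because both
      # already resolve to the same bit-size class.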
138
139      self_bit_size = self.get_bit_size()
140      other_bit_size = other if isinstance(other, int) else other.get_bit_size()
141
142      if self_bit_size == other_bit_size:
143         return
144
145      self_bit_size._bit_size = other_bit_size
146
147   @property
148   def type_enum(self):
149      return "nir_search_value_" + self.type_str
150
151   @property
152   def c_type(self):
153      return "nir_search_" + self.type_str
154
155   def __c_name(self, cache):
156      if cache is not None and self.name in cache:
157         return cache[self.name]
158      else:
159         return self.name
160
161   def c_value_ptr(self, cache):
162      return "&{0}.value".format(self.__c_name(cache))
163
164   def c_ptr(self, cache):
165      return "&{0}".format(self.__c_name(cache))
166
167   @property
168   def c_bit_size(self):
169      bit_size = self.get_bit_size()
170      if isinstance(bit_size, int):
171         return bit_size
172      elif isinstance(bit_size, Variable):
173         return -bit_size.index - 1
174      else:
         # If the bit-size class is neither a variable nor an actual bit-size,
         # then:
         # - If it's in the search expression, we don't need to check anything.
         # - If it's in the replace expression, either it's ambiguous (in which
         #   case we'd reject it), or it equals the bit-size of the search value.
         # We represent these cases with a 0 bit-size.
180         return 0
181
182   __template = mako.template.Template("""{
183   { ${val.type_enum}, ${val.c_bit_size} },
184% if isinstance(val, Constant):
185   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
186% elif isinstance(val, Variable):
187   ${val.index}, /* ${val.var_name} */
188   ${'true' if val.is_constant else 'false'},
189   ${val.type() or 'nir_type_invalid' },
190   ${val.cond if val.cond else 'NULL'},
191   ${val.swizzle()},
192% elif isinstance(val, Expression):
193   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
194   ${val.comm_expr_idx}, ${val.comm_exprs},
195   ${val.c_opcode()},
196   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
197   ${val.cond if val.cond else 'NULL'},
198% endif
199};""")
200
201   def render(self, cache):
202      struct_init = self.__template.render(val=self, cache=cache,
203                                           Constant=Constant,
204                                           Variable=Variable,
205                                           Expression=Expression)
206      if cache is not None and struct_init in cache:
207         # If it's in the cache, register a name remap in the cache and render
208         # only a comment saying it's been remapped
209         cache[self.name] = cache[struct_init]
210         return "/* {} -> {} in the cache */\n".format(self.name,
211                                                       cache[struct_init])
212      else:
213         if cache is not None:
214            cache[struct_init] = self.name
215         return "static const {} {} = {}\n".format(self.c_type, self.name,
216                                                   struct_init)
217
218_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
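# For example, the constant string "0x100@16" splits into the literal 0x100
# (run through ast.literal_eval below, so 256) and an explicit 16-bit size,
# while a plain "1.0" is an unsized float constant.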
219
220class Constant(Value):
221   def __init__(self, val, name):
222      Value.__init__(self, val, name, "constant")
223
224      if isinstance(val, (str)):
225         m = _constant_re.match(val)
226         self.value = ast.literal_eval(m.group('value'))
227         self._bit_size = int(m.group('bits')) if m.group('bits') else None
228      else:
229         self.value = val
230         self._bit_size = None
231
232      if isinstance(self.value, bool):
233         assert self._bit_size is None or self._bit_size == 1
234         self._bit_size = 1
235
236   def hex(self):
237      if isinstance(self.value, (bool)):
238         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
239      if isinstance(self.value, int):
240         return hex(self.value)
241      elif isinstance(self.value, float):
242         return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
243      else:
244         assert False
245
246   def type(self):
247      if isinstance(self.value, (bool)):
248         return "nir_type_bool"
249      elif isinstance(self.value, int):
250         return "nir_type_int"
251      elif isinstance(self.value, float):
252         return "nir_type_float"
253
254   def equivalent(self, other):
255      """Check that two constants are equivalent.
256
      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
259      will break BitSizeValidator.
260
261      """
262      if not isinstance(other, type(self)):
263         return False
264
265      return self.value == other.value
266
# The $ at the end forces the match to fail (tripping the assertion below) if
# any part of the string doesn't match one of the field patterns.
269_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
270                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
271                          r"(?P<cond>\([^\)]+\))?"
272                          r"(?P<swiz>\.[xyzw]+)?"
273                          r"$")
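# For instance, "#a@32(is_used_once).xy" parses as: constant-only marker '#',
# name 'a', bit size 32, the condition '(is_used_once)' (assumed here to name
# a helper from nir_search_helpers.h), and swizzle '.xy'.  A bare "b" is just
# an unconstrained variable.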
274
275class Variable(Value):
276   def __init__(self, val, name, varset):
277      Value.__init__(self, val, name, "variable")
278
279      m = _var_name_re.match(val)
280      assert m and m.group('name') is not None, \
281            "Malformed variable name \"{}\".".format(val)
282
283      self.var_name = m.group('name')
284
285      # Prevent common cases where someone puts quotes around a literal
286      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
288      assert self.var_name.isalpha()
289      assert self.var_name != 'True'
290      assert self.var_name != 'False'
291
292      self.is_constant = m.group('const') is not None
293      self.cond = m.group('cond')
294      self.required_type = m.group('type')
295      self._bit_size = int(m.group('bits')) if m.group('bits') else None
296      self.swiz = m.group('swiz')
297
298      if self.required_type == 'bool':
299         if self._bit_size is not None:
300            assert self._bit_size in type_sizes(self.required_type)
301         else:
302            self._bit_size = 1
303
304      if self.required_type is not None:
305         assert self.required_type in ('float', 'bool', 'int', 'uint')
306
307      self.index = varset[self.var_name]
308
309   def type(self):
310      if self.required_type == 'bool':
311         return "nir_type_bool"
312      elif self.required_type in ('int', 'uint'):
313         return "nir_type_int"
314      elif self.required_type == 'float':
315         return "nir_type_float"
316
317   def equivalent(self, other):
318      """Check that two variables are equivalent.
319
      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
322      will break BitSizeValidator.
323
324      """
325      if not isinstance(other, type(self)):
326         return False
327
328      return self.index == other.index
329
330   def swizzle(self):
331      if self.swiz is not None:
332         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
333                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
334                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
335                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
336                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
337         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
338      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'
339
340_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
341                        r"(?P<cond>\([^\)]+\))?")
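# For example, "~fadd@32(is_used_once)" parses as: inexact marker '~', opcode
# 'fadd', a required 32-bit size, and the condition '(is_used_once)'; "!fmul"
# would instead mark the expression as exact.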
342
343class Expression(Value):
344   def __init__(self, expr, name_base, varset):
345      Value.__init__(self, expr, name_base, "expression")
346      assert isinstance(expr, tuple)
347
348      m = _opcode_re.match(expr[0])
349      assert m and m.group('opcode') is not None
350
351      self.opcode = m.group('opcode')
352      self._bit_size = int(m.group('bits')) if m.group('bits') else None
353      self.inexact = m.group('inexact') is not None
354      self.exact = m.group('exact') is not None
355      self.cond = m.group('cond')
356
357      assert not self.inexact or not self.exact, \
358            'Expression cannot be both exact and inexact.'
359
      # "many-comm-expr" isn't really a condition.  It's a notification to the
361      # generator that this pattern is known to have too many commutative
362      # expressions, and an error should not be generated for this case.
363      self.many_commutative_expressions = False
364      if self.cond and self.cond.find("many-comm-expr") >= 0:
365         # Split the condition into a comma-separated list.  Remove
366         # "many-comm-expr".  If there is anything left, put it back together.
367         c = self.cond[1:-1].split(",")
368         c.remove("many-comm-expr")
369
370         self.cond = "({})".format(",".join(c)) if c else None
371         self.many_commutative_expressions = True
372
373      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
374                       for (i, src) in enumerate(expr[1:]) ]
375
376      # nir_search_expression::srcs is hard-coded to 4
377      assert len(self.sources) <= 4
378
379      if self.opcode in conv_opcode_types:
380         assert self._bit_size is None, \
381                'Expression cannot use an unsized conversion opcode with ' \
382                'an explicit size; that\'s silly.'
383
384      self.__index_comm_exprs(0)
385
386   def equivalent(self, other):
      """Check that two expressions are equivalent.
388
      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
391      will break BitSizeValidator.
392
393      This implementation does not check for equivalence due to commutativity,
394      but it could.
395
396      """
397      if not isinstance(other, type(self)):
398         return False
399
400      if len(self.sources) != len(other.sources):
401         return False
402
403      if self.opcode != other.opcode:
404         return False
405
406      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
407
408   def __index_comm_exprs(self, base_idx):
409      """Recursively count and index commutative expressions
410      """
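      # For example (a sketch), in ('fadd', ('fmul', a, b), c) the outer fadd
      # is commutative and gets comm_expr_idx 0, the nested fmul gets index 1,
      # and the root's comm_exprs ends up as 2, which is what AlgebraicPass
      # later compares against nir_search_max_comm_ops.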
411      self.comm_exprs = 0
412
413      # A note about the explicit "len(self.sources)" check. The list of
414      # sources comes from user input, and that input might be bad.  Check
415      # that the expected second source exists before accessing it. Without
416      # this check, a unit test that does "('iadd', 'a')" will crash.
417      if self.opcode not in conv_opcode_types and \
418         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
419         len(self.sources) >= 2 and \
420         not self.sources[0].equivalent(self.sources[1]):
421         self.comm_expr_idx = base_idx
422         self.comm_exprs += 1
423      else:
424         self.comm_expr_idx = -1
425
426      for s in self.sources:
427         if isinstance(s, Expression):
428            s.__index_comm_exprs(base_idx + self.comm_exprs)
429            self.comm_exprs += s.comm_exprs
430
431      return self.comm_exprs
432
433   def c_opcode(self):
434      return get_c_opcode(self.opcode)
435
436   def render(self, cache):
437      srcs = "\n".join(src.render(cache) for src in self.sources)
438      return srcs + super(Expression, self).render(cache)
439
440class BitSizeValidator(object):
441   """A class for validating bit sizes of expressions.
442
443   NIR supports multiple bit-sizes on expressions in order to handle things
444   such as fp64.  The source and destination of every ALU operation is
445   assigned a type and that type may or may not specify a bit size.  Sources
446   and destinations whose type does not specify a bit size are considered
447   "unsized" and automatically take on the bit size of the corresponding
448   register or SSA value.  NIR has two simple rules for bit sizes that are
449   validated by nir_validator:
450
451    1) A given SSA def or register has a single bit size that is respected by
452       everything that reads from it or writes to it.
453
454    2) The bit sizes of all unsized inputs/outputs on any given ALU
455       instruction must match.  They need not match the sized inputs or
456       outputs but they must match each other.
457
458   In order to keep nir_algebraic relatively simple and easy-to-use,
459   nir_search supports a type of bit-size inference based on the two rules
460   above.  This is similar to type inference in many common programming
461   languages.  If, for instance, you are constructing an add operation and you
462   know the second source is 16-bit, then you know that the other source and
463   the destination must also be 16-bit.  There are, however, cases where this
464   inference can be ambiguous or contradictory.  Consider, for instance, the
465   following transformation:
466
467   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
468
469   This transformation can potentially cause a problem because usub_borrow is
470   well-defined for any bit-size of integer.  However, b2i always generates a
471   32-bit result so it could end up replacing a 64-bit expression with one
472   that takes two 64-bit values and produces a 32-bit value.  As another
473   example, consider this expression:
474
475   (('bcsel', a, b, 0), ('iand', a, b))
476
477   In this case, in the search expression a must be 32-bit but b can
478   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.
480
481   This class solves that problem by providing a validation layer that proves
482   that a given search-and-replace operation is 100% well-defined before we
483   generate any code.  This ensures that bugs are caught at compile time
484   rather than at run time.
485
486   Each value maintains a "bit-size class", which is either an actual bit size
487   or an equivalence class with other values that must have the same bit size.
488   The validator works by combining bit-size classes with each other according
489   to the NIR rules outlined above, checking that there are no inconsistencies.
490   When doing this for the replacement expression, we make sure to never change
491   the equivalence class of any of the search values. We could make the example
492   transforms above work by doing some extra run-time checking of the search
493   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize with anything.
   The former prevents
501
502   (('bcsel', a, b, 0), ('iand', a, b))
503
504   from being allowed, since we'd have to merge the bitsizes for a and b due to
505   the 'iand', while the latter prevents
506
507   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
508
509   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32, the bitsize of the replace
   expression. It also prevents something like:
512
513   (('b2i', ('i2b', a)), ('ineq', a, 0))
514
515   since the bitsize of 'b2i', which can be anything, can't be specialized to
516   the bitsize of a.
517
518   After doing all this, we check that every subexpression of the replacement
519   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
522   needed in nir_search_value so that we know what to do when building the
523   replacement expression.
524   """
525
526   def __init__(self, varset):
527      self._var_classes = [None] * len(varset.names)
528
529   def compare_bitsizes(self, a, b):
530      """Determines which bitsize class is a specialization of the other, or
531      whether neither is. When we merge two different bitsizes, the
532      less-specialized bitsize always points to the more-specialized one, so
533      that calling get_bit_size() always gets you the most specialized bitsize.
534      The specialization partial order is given by:
535      - Physical bitsizes are always the most specialized, and a different
536        bitsize can never specialize another.
537      - In the search expression, variables can always be specialized to each
538        other and to physical bitsizes. In the replace expression, we disallow
539        this to avoid adding extra constraints to the search expression that
540        the user didn't specify.
541      - Expressions and constants without a bitsize can always be specialized to
542        each other and variables, but not the other way around.
543
      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
      and None if they are not comparable (neither a <= b nor b <= a).
546      """
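      # A few concrete cases of that partial order (a sketch; "var" stands for
      # any Variable):
      #   compare_bitsizes(32, 32)  -> 0
      #   compare_bitsizes(32, 64)  -> None (different physical sizes)
      #   compare_bitsizes(32, var) -> -1 in the search expression, None in
      #                                the replace expression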
547      if isinstance(a, int):
548         if isinstance(b, int):
549            return 0 if a == b else None
550         elif isinstance(b, Variable):
551            return -1 if self.is_search else None
552         else:
553            return -1
554      elif isinstance(a, Variable):
555         if isinstance(b, int):
556            return 1 if self.is_search else None
557         elif isinstance(b, Variable):
558            return 0 if self.is_search or a.index == b.index else None
559         else:
560            return -1
561      else:
562         if isinstance(b, int):
563            return 1
564         elif isinstance(b, Variable):
565            return 1
566         else:
567            return 0
568
569   def unify_bit_size(self, a, b, error_msg):
570      """Record that a must have the same bit-size as b. If both
571      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of a and b to get a message and raise an error.
573      In the replace expression, disallow merging variables with other
574      variables and physical bit-sizes as well.
575      """
576      a_bit_size = a.get_bit_size()
577      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
578
579      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
580
581      assert cmp_result is not None, \
582         error_msg(a_bit_size, b_bit_size)
583
584      if cmp_result < 0:
585         b_bit_size.set_bit_size(a)
586      elif not isinstance(a_bit_size, int):
587         a_bit_size.set_bit_size(b)
588
589   def merge_variables(self, val):
590      """Perform the first part of type inference by merging all the different
591      uses of the same variable. We always do this as if we're in the search
592      expression, even if we're actually not, since otherwise we'd get errors
593      if the search expression specified some constraint but the replace
594      expression didn't, because we'd be merging a variable and a constant.
595      """
596      if isinstance(val, Variable):
597         if self._var_classes[val.index] is None:
598            self._var_classes[val.index] = val
599         else:
600            other = self._var_classes[val.index]
601            self.unify_bit_size(other, val,
602                  lambda other_bit_size, bit_size:
603                     'Variable {} has conflicting bit size requirements: ' \
604                     'it must have bit size {} and {}'.format(
605                        val.var_name, other_bit_size, bit_size))
606      elif isinstance(val, Expression):
607         for src in val.sources:
608            self.merge_variables(src)
609
610   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
612      type inference on bitsizes. This will detect if there are any conflicting
613      requirements, and unify variables so that we know which variables must
614      have the same bitsize. If we're operating on the replace expression, we
615      will refuse to merge different variables together or merge a variable
616      with a constant, in order to prevent surprises due to rules unexpectedly
617      not matching at runtime.
618      """
619      if not isinstance(val, Expression):
620         return
621
622      # Generic conversion ops are special in that they have a single unsized
623      # source and an unsized destination and the two don't have to match.
624      # This means there's no validation or unioning to do here besides the
625      # len(val.sources) check.
626      if val.opcode in conv_opcode_types:
627         assert len(val.sources) == 1, \
628            "Expression {} has {} sources, expected 1".format(
629               val, len(val.sources))
630         self.validate_value(val.sources[0])
631         return
632
633      nir_op = opcodes[val.opcode]
634      assert len(val.sources) == nir_op.num_inputs, \
635         "Expression {} has {} sources, expected {}".format(
636            val, len(val.sources), nir_op.num_inputs)
637
638      for src in val.sources:
639         self.validate_value(src)
640
641      dst_type_bits = type_bits(nir_op.output_type)
642
643      # First, unify all the sources. That way, an error coming up because two
644      # sources have an incompatible bit-size won't produce an error message
645      # involving the destination.
646      first_unsized_src = None
647      for src_type, src in zip(nir_op.input_types, val.sources):
648         src_type_bits = type_bits(src_type)
649         if src_type_bits == 0:
650            if first_unsized_src is None:
651               first_unsized_src = src
652               continue
653
654            if self.is_search:
655               self.unify_bit_size(first_unsized_src, src,
656                  lambda first_unsized_src_bit_size, src_bit_size:
657                     'Source {} of {} must have bit size {}, while source {} ' \
658                     'must have incompatible bit size {}'.format(
659                        first_unsized_src, val, first_unsized_src_bit_size,
660                        src, src_bit_size))
661            else:
662               self.unify_bit_size(first_unsized_src, src,
663                  lambda first_unsized_src_bit_size, src_bit_size:
664                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
665                     'of {} may not have the same bit size when building the ' \
666                     'replacement expression.'.format(
667                        first_unsized_src, first_unsized_src_bit_size, src,
668                        src_bit_size, val))
669         else:
670            if self.is_search:
671               self.unify_bit_size(src, src_type_bits,
672                  lambda src_bit_size, unused:
673                     '{} must have {} bits, but as a source of nir_op_{} '\
674                     'it must have {} bits'.format(
675                        src, src_bit_size, nir_op.name, src_type_bits))
676            else:
677               self.unify_bit_size(src, src_type_bits,
678                  lambda src_bit_size, unused:
679                     '{} has the bit size of {}, but as a source of ' \
680                     'nir_op_{} it must have {} bits, which may not be the ' \
681                     'same'.format(
682                        src, src_bit_size, nir_op.name, src_type_bits))
683
684      if dst_type_bits == 0:
685         if first_unsized_src is not None:
686            if self.is_search:
687               self.unify_bit_size(val, first_unsized_src,
688                  lambda val_bit_size, src_bit_size:
689                     '{} must have the bit size of {}, while its source {} ' \
690                     'must have incompatible bit size {}'.format(
691                        val, val_bit_size, first_unsized_src, src_bit_size))
692            else:
693               self.unify_bit_size(val, first_unsized_src,
694                  lambda val_bit_size, src_bit_size:
695                     '{} must have {} bits, but its source {} ' \
696                     '(bit size of {}) may not have that bit size ' \
697                     'when building the replacement.'.format(
698                        val, val_bit_size, first_unsized_src, src_bit_size))
699      else:
700         self.unify_bit_size(val, dst_type_bits,
701            lambda dst_bit_size, unused:
702               '{} must have {} bits, but as a destination of nir_op_{} ' \
703               'it must have {} bits'.format(
704                  val, dst_bit_size, nir_op.name, dst_type_bits))
705
706   def validate_replace(self, val, search):
707      bit_size = val.get_bit_size()
708      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
709            bit_size == search.get_bit_size(), \
710            'Ambiguous bit size for replacement value {}: ' \
711            'it cannot be deduced from a variable, a fixed bit size ' \
712            'somewhere, or the search expression.'.format(val)
713
714      if isinstance(val, Expression):
715         for src in val.sources:
716            self.validate_replace(src, search)
717
718   def validate(self, search, replace):
719      self.is_search = True
720      self.merge_variables(search)
721      self.merge_variables(replace)
722      self.validate_value(search)
723
724      self.is_search = False
725      self.validate_value(replace)
726
727      # Check that search is always more specialized than replace. Note that
728      # we're doing this in replace mode, disallowing merging variables.
729      search_bit_size = search.get_bit_size()
730      replace_bit_size = replace.get_bit_size()
731      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
732
733      assert cmp_result is not None and cmp_result <= 0, \
734         'The search expression bit size {} and replace expression ' \
735         'bit size {} may not be the same'.format(
736               search_bit_size, replace_bit_size)
737
738      replace.set_bit_size(search)
739
740      self.validate_replace(replace, search)
741
742_optimization_ids = itertools.count()
743
744condition_list = ['true']
745
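# Each transform handed to SearchAndReplace below is a (search, replace) or
# (search, replace, condition) tuple.  A hypothetical example (the condition
# string is only illustrative):
#
#   (('fadd', 'a', ('fneg', 'b')), ('fsub', 'a', 'b'), '!options->lower_fsub')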
746class SearchAndReplace(object):
747   def __init__(self, transform):
748      self.id = next(_optimization_ids)
749
750      search = transform[0]
751      replace = transform[1]
752      if len(transform) > 2:
753         self.condition = transform[2]
754      else:
755         self.condition = 'true'
756
757      if self.condition not in condition_list:
758         condition_list.append(self.condition)
759      self.condition_index = condition_list.index(self.condition)
760
761      varset = VarSet()
762      if isinstance(search, Expression):
763         self.search = search
764      else:
765         self.search = Expression(search, "search{0}".format(self.id), varset)
766
767      varset.lock()
768
769      if isinstance(replace, Value):
770         self.replace = replace
771      else:
772         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
773
774      BitSizeValidator(varset).validate(self.search, self.replace)
775
776class TreeAutomaton(object):
777   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFAs and DFAs, where the transition function determines the
780   state of the parent node based on the state of its children. We construct a
781   deterministic automaton to match patterns, using a similar algorithm to the
782   classical NFA to DFA construction. At the moment, it only matches opcodes
783   and constants (without checking the actual value), leaving more detailed
784   checking to the search function which actually checks the leaves. The
785   automaton acts as a quick filter for the search function, requiring only n
786   + 1 table lookups for each n-source operation. The implementation is based
787   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
788   In the language of that reference, this is a frontier-to-root deterministic
789   automaton using only symbol filtering. The filtering is crucial to reduce
790   both the time taken to generate the tables and the size of the tables.
791   """
792   def __init__(self, transforms):
793      self.patterns = [t.search for t in transforms]
794      self._compute_items()
795      self._build_table()
796      #print('num items: {}'.format(len(set(self.items.values()))))
797      #print('num states: {}'.format(len(self.states)))
798      #for state, patterns in zip(self.states, self.patterns):
799      #   print('{}: num patterns: {}'.format(state, len(patterns)))
800
801   class IndexMap(object):
      """An indexed list of objects, where one can either look up an object by
      index or quickly find the index associated with an object using a hash
804      table. Compared to a list, it has a constant time index(). Compared to a
805      set, it provides a stable iteration order.
806      """
807      def __init__(self, iterable=()):
808         self.objects = []
809         self.map = {}
810         for obj in iterable:
811            self.add(obj)
812
813      def __getitem__(self, i):
814         return self.objects[i]
815
816      def __contains__(self, obj):
817         return obj in self.map
818
819      def __len__(self):
820         return len(self.objects)
821
822      def __iter__(self):
823         return iter(self.objects)
824
825      def clear(self):
826         self.objects = []
827         self.map.clear()
828
829      def index(self, obj):
830         return self.map[obj]
831
832      def add(self, obj):
833         if obj in self.map:
834            return self.map[obj]
835         else:
836            index = len(self.objects)
837            self.objects.append(obj)
838            self.map[obj] = index
839            return index
840
841      def __repr__(self):
842         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
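   # A quick sketch of IndexMap behaviour (purely illustrative):
   #   m = TreeAutomaton.IndexMap(['fadd'])
   #   m.add('fmul')    # -> 1 (new entry appended)
   #   m.add('fadd')    # -> 0 (already present, same index returned)
   #   m.index('fmul')  # -> 1; iteration order stays ['fadd', 'fmul']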
843
844   class Item(object):
845      """This represents an "item" in the language of "Tree Automatons." This
846      is just a subtree of some pattern, which represents a potential partial
847      match at runtime. We deduplicate them, so that identical subtrees of
848      different patterns share the same object, and store some extra
849      information needed for the main algorithm as well.
850      """
851      def __init__(self, opcode, children):
852         self.opcode = opcode
853         self.children = children
854         # These are the indices of patterns for which this item is the root node.
855         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed up
857         # filtering.
858         self.parent_ops = set()
859
860      def __str__(self):
861         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
862
863      def __repr__(self):
864         return str(self)
865
866   def _compute_items(self):
867      """Build a set of all possible items, deduplicating them."""
868      # This is a map from (opcode, sources) to item.
869      self.items = {}
870
871      # The set of all opcodes used by the patterns. Used later to avoid
872      # building and emitting all the tables for opcodes that aren't used.
873      self.opcodes = self.IndexMap()
874
875      def get_item(opcode, children, pattern=None):
876         commutative = len(children) >= 2 \
877               and "2src_commutative" in opcodes[opcode].algebraic_properties
878         item = self.items.setdefault((opcode, children),
879                                      self.Item(opcode, children))
880         if commutative:
881            self.items[opcode, (children[1], children[0]) + children[2:]] = item
882         if pattern is not None:
883            item.patterns.append(pattern)
884         return item
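      # Note that for a commutative opcode get_item() above also registers the
      # swapped key, e.g. ('fadd', (x, y)) and ('fadd', (y, x)) map to the same
      # Item, so a runtime match in either source order reaches the same state.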
885
886      self.wildcard = get_item("__wildcard", ())
887      self.const = get_item("__const", ())
888
889      def process_subpattern(src, pattern=None):
890         if isinstance(src, Constant):
891            # Note: we throw away the actual constant value!
892            return self.const
893         elif isinstance(src, Variable):
894            if src.is_constant:
895               return self.const
896            else:
897               # Note: we throw away which variable it is here! This special
898               # item is equivalent to nu in "Tree Automatons."
899               return self.wildcard
900         else:
901            assert isinstance(src, Expression)
902            opcode = src.opcode
903            stripped = opcode.rstrip('0123456789')
904            if stripped in conv_opcode_types:
905               # Matches that use conversion opcodes with a specific type,
906               # like f2b1, are tricky.  Either we construct the automaton to
907               # match specific NIR opcodes like nir_op_f2b1, in which case we
908               # need to create separate items for each possible NIR opcode
909               # for patterns that have a generic opcode like f2b, or we
910               # construct it to match the search opcode, in which case we
911               # need to map f2b1 to f2b when constructing the automaton. Here
912               # we do the latter.
913               opcode = stripped
914            self.opcodes.add(opcode)
915            children = tuple(process_subpattern(c) for c in src.sources)
916            item = get_item(opcode, children, pattern)
917            for i, child in enumerate(children):
918               child.parent_ops.add(opcode)
919            return item
920
921      for i, pattern in enumerate(self.patterns):
922         process_subpattern(pattern, i)
923
924   def _build_table(self):
925      """This is the core algorithm which builds up the transition table. It
926      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
927      Comp_a and Filt_{a,i} using integers to identify match sets." It
928      simultaneously builds up a list of all possible "match sets" or
929      "states", where each match set represents the set of Item's that match a
930      given instruction, and builds up the transition table between states.
931      """
932      # Map from opcode + filtered state indices to transitioned state.
933      self.table = defaultdict(dict)
934      # Bijection from state to index. q in the original algorithm is
935      # len(self.states)
936      self.states = self.IndexMap()
937      # List of pattern matches for each state index.
938      self.state_patterns = []
939      # Map from state index to filtered state index for each opcode.
940      self.filter = defaultdict(list)
941      # Bijections from filtered state to filtered state index for each
942      # opcode, called the "representor sets" in the original algorithm.
943      # q_{a,j} in the original algorithm is len(self.rep[op]).
944      self.rep = defaultdict(self.IndexMap)
945
      # Everything in self.states with an index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly filtered states for each opcode, for which worklist_indices
949      # serves a similar purpose. worklist_index corresponds to p in the
950      # original algorithm, while worklist_indices is p_{a,j} (although since
951      # we only filter by opcode/symbol, it's really just p_a).
952      self.worklist_index = 0
953      worklist_indices = defaultdict(lambda: 0)
954
955      # This is the set of opcodes for which the filtered worklist is non-empty.
956      # It's used to avoid scanning opcodes for which there is nothing to
957      # process when building the transition table. It corresponds to new_a in
958      # the original algorithm.
959      new_opcodes = self.IndexMap()
960
961      # Process states on the global worklist, filtering them for each opcode,
962      # updating the filter tables, and updating the filtered worklists if any
963      # new filtered states are found. Similar to ComputeRepresenterSets() in
964      # the original algorithm, although that only processes a single state.
965      def process_new_states():
966         while self.worklist_index < len(self.states):
967            state = self.states[self.worklist_index]
968
969            # Calculate pattern matches for this state. Each pattern is
970            # assigned to a unique item, so we don't have to worry about
971            # deduplicating them here. However, we do have to sort them so
972            # that they're visited at runtime in the order they're specified
973            # in the source.
974            patterns = list(sorted(p for item in state for p in item.patterns))
975            assert len(self.state_patterns) == self.worklist_index
976            self.state_patterns.append(patterns)
977
978            # calculate filter table for this state, and update filtered
979            # worklists.
980            for op in self.opcodes:
981               filt = self.filter[op]
982               rep = self.rep[op]
983               filtered = frozenset(item for item in state if \
984                  op in item.parent_ops)
985               if filtered in rep:
986                  rep_index = rep.index(filtered)
987               else:
988                  rep_index = rep.add(filtered)
989                  new_opcodes.add(op)
990               assert len(filt) == self.worklist_index
991               filt.append(rep_index)
992            self.worklist_index += 1
993
994      # There are two start states: one which can only match as a wildcard,
995      # and one which can match as a wildcard or constant. These will be the
996      # states of intrinsics/other instructions and load_const instructions,
997      # respectively. The indices of these must match the definitions of
998      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
999      # initialize things correctly.
1000      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const, self.wildcard)))
1002      process_new_states()
1003
1004      while len(new_opcodes) > 0:
1005         for op in new_opcodes:
1006            rep = self.rep[op]
1007            table = self.table[op]
1008            op_worklist_index = worklist_indices[op]
1009            if op in conv_opcode_types:
1010               num_srcs = 1
1011            else:
1012               num_srcs = opcodes[op].num_inputs
1013
1014            # Iterate over all possible source combinations where at least one
1015            # is on the worklist.
1016            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
1017               if all(src_idx < op_worklist_index for src_idx in src_indices):
1018                  continue
1019
1020               srcs = tuple(rep[src_idx] for src_idx in src_indices)
1021
1022               # Try all possible pairings of source items and add the
1023               # corresponding parent items. This is Comp_a from the paper.
1024               parent = set(self.items[op, item_srcs] for item_srcs in
1025                  itertools.product(*srcs) if (op, item_srcs) in self.items)
1026
1027               # We could always start matching something else with a
1028               # wildcard. This is Cl from the paper.
1029               parent.add(self.wildcard)
1030
1031               table[src_indices] = self.states.add(frozenset(parent))
1032            worklist_indices[op] = len(rep)
1033         new_opcodes.clear()
1034         process_new_states()
1035
1036_algebraic_pass_template = mako.template.Template("""
1037#include "nir.h"
1038#include "nir_builder.h"
1039#include "nir_search.h"
1040#include "nir_search_helpers.h"
1041
1042/* What follows is NIR algebraic transform code for the following ${len(xforms)}
1043 * transforms:
1044% for xform in xforms:
1045 *    ${xform.search} => ${xform.replace}
1046% endfor
1047 */
1048
1049<% cache = {} %>
1050% for xform in xforms:
1051   ${xform.search.render(cache)}
1052   ${xform.replace.render(cache)}
1053% endfor
1054
1055% for state_id, state_xforms in enumerate(automaton.state_patterns):
1056% if state_xforms: # avoid emitting a 0-length array for MSVC
1057static const struct transform ${pass_name}_state${state_id}_xforms[] = {
1058% for i in state_xforms:
1059  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
1060% endfor
1061};
1062% endif
1063% endfor
1064
1065static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
1066% for op in automaton.opcodes:
1067   [${get_c_opcode(op)}] = {
1068      .filter = (uint16_t []) {
1069      % for e in automaton.filter[op]:
1070         ${e},
1071      % endfor
1072      },
1073      <%
1074        num_filtered = len(automaton.rep[op])
1075      %>
1076      .num_filtered_states = ${num_filtered},
1077      .table = (uint16_t []) {
1078      <%
1079        num_srcs = len(next(iter(automaton.table[op])))
1080      %>
1081      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
1082         ${automaton.table[op][indices]},
1083      % endfor
1084      },
1085   },
1086% endfor
1087};
1088
1089const struct transform *${pass_name}_transforms[] = {
1090% for i in range(len(automaton.state_patterns)):
1091   % if automaton.state_patterns[i]:
1092   ${pass_name}_state${i}_xforms,
1093   % else:
1094   NULL,
1095   % endif
1096% endfor
1097};
1098
1099const uint16_t ${pass_name}_transform_counts[] = {
1100% for i in range(len(automaton.state_patterns)):
1101   % if automaton.state_patterns[i]:
1102   (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
1103   % else:
1104   0,
1105   % endif
1106% endfor
1107};
1108
1109bool
1110${pass_name}(nir_shader *shader)
1111{
1112   bool progress = false;
1113   bool condition_flags[${len(condition_list)}];
1114   const nir_shader_compiler_options *options = shader->options;
1115   const shader_info *info = &shader->info;
1116   (void) options;
1117   (void) info;
1118
1119   % for index, condition in enumerate(condition_list):
1120   condition_flags[${index}] = ${condition};
1121   % endfor
1122
1123   nir_foreach_function(function, shader) {
1124      if (function->impl) {
1125         progress |= nir_algebraic_impl(function->impl, condition_flags,
1126                                        ${pass_name}_transforms,
1127                                        ${pass_name}_transform_counts,
1128                                        ${pass_name}_table);
1129      }
1130   }
1131
1132   return progress;
1133}
1134""")
1135
1136
1137class AlgebraicPass(object):
1138   def __init__(self, pass_name, transforms):
1139      self.xforms = []
1140      self.opcode_xforms = defaultdict(lambda : [])
1141      self.pass_name = pass_name
1142
1143      error = False
1144
1145      for xform in transforms:
1146         if not isinstance(xform, SearchAndReplace):
1147            try:
1148               xform = SearchAndReplace(xform)
            except Exception:
1150               print("Failed to parse transformation:", file=sys.stderr)
1151               print("  " + str(xform), file=sys.stderr)
1152               traceback.print_exc(file=sys.stderr)
1153               print('', file=sys.stderr)
1154               error = True
1155               continue
1156
1157         self.xforms.append(xform)
1158         if xform.search.opcode in conv_opcode_types:
1159            dst_type = conv_opcode_types[xform.search.opcode]
1160            for size in type_sizes(dst_type):
1161               sized_opcode = xform.search.opcode + str(size)
1162               self.opcode_xforms[sized_opcode].append(xform)
1163         else:
1164            self.opcode_xforms[xform.search.opcode].append(xform)
1165
1166         # Check to make sure the search pattern does not unexpectedly contain
1167         # more commutative expressions than match_expression (nir_search.c)
1168         # can handle.
1169         comm_exprs = xform.search.comm_exprs
1170
1171         if xform.search.many_commutative_expressions:
1172            if comm_exprs <= nir_search_max_comm_ops:
               print("Transform expected to have too many commutative " \
                     "expressions but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
1176                     file=sys.stderr)
1177               print("  " + str(xform), file=sys.stderr)
1178               traceback.print_exc(file=sys.stderr)
1179               print('', file=sys.stderr)
1180               error = True
1181         else:
1182            if comm_exprs > nir_search_max_comm_ops:
1183               print("Transformation with too many commutative expressions " \
1184                     "({} > {}).  Modify pattern or annotate with " \
1185                     "\"many-comm-expr\".".format(comm_exprs,
1186                                                  nir_search_max_comm_ops),
1187                     file=sys.stderr)
1188               print("  " + str(xform.search), file=sys.stderr)
1189               print("{}".format(xform.search.cond), file=sys.stderr)
1190               error = True
1191
1192      self.automaton = TreeAutomaton(self.xforms)
1193
1194      if error:
1195         sys.exit(1)
1196
1197
1198   def render(self):
1199      return _algebraic_pass_template.render(pass_name=self.pass_name,
1200                                             xforms=self.xforms,
1201                                             opcode_xforms=self.opcode_xforms,
1202                                             condition_list=condition_list,
1203                                             automaton=self.automaton,
1204                                             get_c_opcode=get_c_opcode,
1205                                             itertools=itertools)
1206