1#
2# Copyright (C) 2014 Intel Corporation
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22#
23# Authors:
24#    Jason Ekstrand (jason@jlekstrand.net)
25
26import ast
27from collections import defaultdict
28import itertools
29import struct
30import sys
31import mako.template
32import re
33import traceback
34
35from nir_opcodes import opcodes, type_sizes
36
37# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
38nir_search_max_comm_ops = 8
39
40# These opcodes are only employed by nir_search.  This provides a mapping from
41# opcode to destination type.
42conv_opcode_types = {
43    'i2f' : 'float',
44    'u2f' : 'float',
45    'f2f' : 'float',
46    'f2u' : 'uint',
47    'f2i' : 'int',
48    'u2u' : 'uint',
49    'i2i' : 'int',
50    'b2f' : 'float',
51    'b2i' : 'int',
52    'i2b' : 'bool',
53    'f2b' : 'bool',
54}
55
56def get_cond_index(conds, cond):
   if cond:
      if cond in conds:
         return conds[cond]
      else:
         cond_index = len(conds)
         conds[cond] = cond_index
         return cond_index
   else:
      return -1
66
67def get_c_opcode(op):
   if op in conv_opcode_types:
      return 'nir_search_op_' + op
   else:
      return 'nir_op_' + op
72
73_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
74
75def type_bits(type_str):
76   m = _type_re.match(type_str)
77   assert m.group('type')
78
79   if m.group('bits') is None:
80      return 0
81   else:
82      return int(m.group('bits'))
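
# A few quick illustrations of the helper above, derived directly from the
# regex:
#
#    type_bits("uint32") == 32
#    type_bits("float16") == 16
#    type_bits("int") == 0      # no explicit size means "unsized"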
83
84# Represents a set of variables, each with a unique id
85class VarSet(object):
86   def __init__(self):
87      self.names = {}
88      self.ids = itertools.count()
      self.immutable = False
90
91   def __getitem__(self, name):
92      if name not in self.names:
93         assert not self.immutable, "Unknown replacement variable: " + name
94         self.names[name] = next(self.ids)
95
96      return self.names[name]
97
98   def lock(self):
99      self.immutable = True
100
101class SearchExpression(object):
102   def __init__(self, expr):
103      self.opcode = expr[0]
104      self.sources = expr[1:]
105      self.ignore_exact = False
106
107   @staticmethod
108   def create(val):
109      if isinstance(val, tuple):
110         return SearchExpression(val)
111      else:
112         assert(isinstance(val, SearchExpression))
113         return val
114
115   def __repr__(self):
116      l = [self.opcode, *self.sources]
117      if self.ignore_exact:
118         l.append('ignore_exact')
119      return repr((*l,))
120
121class Value(object):
122   @staticmethod
123   def create(val, name_base, varset, algebraic_pass):
124      if isinstance(val, bytes):
125         val = val.decode('utf-8')
126
127      if isinstance(val, tuple) or isinstance(val, SearchExpression):
128         return Expression(val, name_base, varset, algebraic_pass)
129      elif isinstance(val, Expression):
130         return val
131      elif isinstance(val, str):
132         return Variable(val, name_base, varset, algebraic_pass)
133      elif isinstance(val, (bool, float, int)):
134         return Constant(val, name_base)
135
136   def __init__(self, val, name, type_str):
137      self.in_val = str(val)
138      self.name = name
139      self.type_str = type_str
140
141   def __str__(self):
142      return self.in_val
143
144   def get_bit_size(self):
145      """Get the physical bit-size that has been chosen for this value, or if
146      there is none, the canonical value which currently represents this
147      bit-size class. Variables will be preferred, i.e. if there are any
148      variables in the equivalence class, the canonical value will be a
149      variable. We do this since we'll need to know which variable each value
150      is equivalent to when constructing the replacement expression. This is
151      the "find" part of the union-find algorithm.
152      """
153      bit_size = self
154
155      while isinstance(bit_size, Value):
156         if bit_size._bit_size is None:
157            break
158         bit_size = bit_size._bit_size
159
160      if bit_size is not self:
161         self._bit_size = bit_size
162      return bit_size
163
164   def set_bit_size(self, other):
165      """Make self.get_bit_size() return what other.get_bit_size() return
166      before calling this, or just "other" if it's a concrete bit-size. This is
167      the "union" part of the union-find algorithm.
168      """
169
170      self_bit_size = self.get_bit_size()
171      other_bit_size = other if isinstance(other, int) else other.get_bit_size()
172
173      if self_bit_size == other_bit_size:
174         return
175
176      self_bit_size._bit_size = other_bit_size
177
178   @property
179   def type_enum(self):
180      return "nir_search_value_" + self.type_str
181
182   @property
183   def c_bit_size(self):
184      bit_size = self.get_bit_size()
185      if isinstance(bit_size, int):
186         return bit_size
187      elif isinstance(bit_size, Variable):
188         return -bit_size.index - 1
189      else:
190         # If the bit-size class is neither a variable, nor an actual bit-size, then
191         # - If it's in the search expression, we don't need to check anything
192         # - If it's in the replace expression, either it's ambiguous (in which
193         # case we'd reject it), or it equals the bit-size of the search value
194         # We represent these cases with a 0 bit-size.
195         return 0
196
197   __template = mako.template.Template("""   { .${val.type_str} = {
198      { ${val.type_enum}, ${val.c_bit_size} },
199% if isinstance(val, Constant):
200      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
201% elif isinstance(val, Variable):
202      ${val.index}, /* ${val.var_name} */
203      ${'true' if val.is_constant else 'false'},
204      ${val.type() or 'nir_type_invalid' },
205      ${val.cond_index},
206      ${val.swizzle()},
207% elif isinstance(val, Expression):
208      ${'true' if val.inexact else 'false'},
209      ${'true' if val.exact else 'false'},
210      ${'true' if val.ignore_exact else 'false'},
211      ${val.c_opcode()},
212      ${val.comm_expr_idx}, ${val.comm_exprs},
213      { ${', '.join(src.array_index for src in val.sources)} },
214      ${val.cond_index},
215% endif
216   } },
217""")
218
219   def render(self, cache):
220      struct_init = self.__template.render(val=self,
221                                           Constant=Constant,
222                                           Variable=Variable,
223                                           Expression=Expression)
224      if struct_init in cache:
225         # If it's in the cache, register a name remap in the cache and render
226         # only a comment saying it's been remapped
227         self.array_index = cache[struct_init]
228         return "   /* {} -> {} in the cache */\n".format(self.name,
229                                                       cache[struct_init])
230      else:
231         self.array_index = str(cache["next_index"])
232         cache[struct_init] = self.array_index
233         cache["next_index"] += 1
234         return struct_init
235
236_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
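
# For example, the constant string "0x3@16" is the value 3 with an explicit
# 16-bit size, while a bare "1.0" is an unsized float constant whose bit size
# is determined later by bit-size inference.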
237
238class Constant(Value):
239   def __init__(self, val, name):
240      Value.__init__(self, val, name, "constant")
241
      if isinstance(val, str):
243         m = _constant_re.match(val)
244         self.value = ast.literal_eval(m.group('value'))
245         self._bit_size = int(m.group('bits')) if m.group('bits') else None
246      else:
247         self.value = val
248         self._bit_size = None
249
250      if isinstance(self.value, bool):
251         assert self._bit_size is None or self._bit_size == 1
252         self._bit_size = 1
253
254   def hex(self):
      if isinstance(self.value, bool):
256         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
257      if isinstance(self.value, int):
258         return hex(self.value)
259      elif isinstance(self.value, float):
260         return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
261      else:
262         assert False
263
264   def type(self):
      if isinstance(self.value, bool):
266         return "nir_type_bool"
267      elif isinstance(self.value, int):
268         return "nir_type_int"
269      elif isinstance(self.value, float):
270         return "nir_type_float"
271
272   def equivalent(self, other):
273      """Check that two constants are equivalent.
274
      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      will break BitSizeValidator.
278
279      """
280      if not isinstance(other, type(self)):
281         return False
282
283      return self.value == other.value
284
285# The $ at the end forces there to be an error if any part of the string
286# doesn't match one of the field patterns.
287_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
288                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
289                          r"(?P<cond>\([^\)]+\))?"
290                          r"(?P<swiz>\.[xyzw]+)?"
291                          r"$")
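
# As an illustration of the syntax matched above, the variable string
# "#a@32(is_used_once).x" names a variable "a" that must be a constant (#),
# must be 32 bits wide (@32), must satisfy the is_used_once condition from
# nir_search_helpers, and is swizzled to its x component.  Every field except
# the name is optional.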
292
293class Variable(Value):
294   def __init__(self, val, name, varset, algebraic_pass):
295      Value.__init__(self, val, name, "variable")
296
297      m = _var_name_re.match(val)
298      assert m and m.group('name') is not None, \
299            "Malformed variable name \"{}\".".format(val)
300
301      self.var_name = m.group('name')
302
303      # Prevent common cases where someone puts quotes around a literal
304      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
306      assert self.var_name.isalpha()
307      assert self.var_name != 'True'
308      assert self.var_name != 'False'
309
310      self.is_constant = m.group('const') is not None
311      self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
312      self.required_type = m.group('type')
313      self._bit_size = int(m.group('bits')) if m.group('bits') else None
314      self.swiz = m.group('swiz')
315
316      if self.required_type == 'bool':
317         if self._bit_size is not None:
318            assert self._bit_size in type_sizes(self.required_type)
319         else:
320            self._bit_size = 1
321
322      if self.required_type is not None:
323         assert self.required_type in ('float', 'bool', 'int', 'uint')
324
325      self.index = varset[self.var_name]
326
327   def type(self):
328      if self.required_type == 'bool':
329         return "nir_type_bool"
330      elif self.required_type in ('int', 'uint'):
331         return "nir_type_int"
332      elif self.required_type == 'float':
333         return "nir_type_float"
334
335   def equivalent(self, other):
336      """Check that two variables are equivalent.
337
      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      will break BitSizeValidator.
341
342      """
343      if not isinstance(other, type(self)):
344         return False
345
346      return self.index == other.index
347
348   def swizzle(self):
349      if self.swiz is not None:
350         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
351                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
352                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
353                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
354                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
355         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
356      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'
357
358_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
359                        r"(?P<cond>\([^\)]+\))?")
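
# As an illustration of the opcode syntax matched above, "~fadd@32(is_used_once)"
# describes a 32-bit fadd that also satisfies the is_used_once helper condition,
# where "~" marks the pattern as inexact (it is not applied to instructions
# marked exact) and the alternative "!" prefix instead requires an exact
# instruction.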
360
361class Expression(Value):
362   def __init__(self, expr, name_base, varset, algebraic_pass):
363      Value.__init__(self, expr, name_base, "expression")
364
365      expr = SearchExpression.create(expr)
366
367      m = _opcode_re.match(expr.opcode)
368      assert m and m.group('opcode') is not None
369
370      self.opcode = m.group('opcode')
371      self._bit_size = int(m.group('bits')) if m.group('bits') else None
372      self.inexact = m.group('inexact') is not None
373      self.exact = m.group('exact') is not None
374      self.ignore_exact = expr.ignore_exact
375      self.cond = m.group('cond')
376
377      assert not self.inexact or not self.exact, \
378            'Expression cannot be both exact and inexact.'
379
380      # "many-comm-expr" isn't really a condition.  It's notification to the
381      # generator that this pattern is known to have too many commutative
382      # expressions, and an error should not be generated for this case.
383      self.many_commutative_expressions = False
384      if self.cond and self.cond.find("many-comm-expr") >= 0:
385         # Split the condition into a comma-separated list.  Remove
386         # "many-comm-expr".  If there is anything left, put it back together.
387         c = self.cond[1:-1].split(",")
388         c.remove("many-comm-expr")
389         assert(len(c) <= 1)
390
391         self.cond = c[0] if c else None
392         self.many_commutative_expressions = True
393
394      # Deduplicate references to the condition functions for the expressions
395      # and save the index for the order they were added.
396      self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)
397
398      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
399                       for (i, src) in enumerate(expr.sources) ]
400
401      # nir_search_expression::srcs is hard-coded to 4
402      assert len(self.sources) <= 4
403
404      if self.opcode in conv_opcode_types:
405         assert self._bit_size is None, \
406                'Expression cannot use an unsized conversion opcode with ' \
407                'an explicit size; that\'s silly.'
408
409      self.__index_comm_exprs(0)
410
411   def equivalent(self, other):
412      """Check that two variables are equivalent.
413
414      This is check is much weaker than equality.  One generally cannot be
415      used in place of the other.  Using this implementation for the __eq__
416      will break BitSizeValidator.
417
418      This implementation does not check for equivalence due to commutativity,
419      but it could.
420
421      """
422      if not isinstance(other, type(self)):
423         return False
424
425      if len(self.sources) != len(other.sources):
426         return False
427
428      if self.opcode != other.opcode:
429         return False
430
431      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
432
433   def __index_comm_exprs(self, base_idx):
434      """Recursively count and index commutative expressions
435      """
436      self.comm_exprs = 0
437
438      # A note about the explicit "len(self.sources)" check. The list of
439      # sources comes from user input, and that input might be bad.  Check
440      # that the expected second source exists before accessing it. Without
441      # this check, a unit test that does "('iadd', 'a')" will crash.
442      if self.opcode not in conv_opcode_types and \
443         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
444         len(self.sources) >= 2 and \
445         not self.sources[0].equivalent(self.sources[1]):
446         self.comm_expr_idx = base_idx
447         self.comm_exprs += 1
448      else:
449         self.comm_expr_idx = -1
450
451      for s in self.sources:
452         if isinstance(s, Expression):
453            s.__index_comm_exprs(base_idx + self.comm_exprs)
454            self.comm_exprs += s.comm_exprs
455
456      return self.comm_exprs
457
458   def c_opcode(self):
459      return get_c_opcode(self.opcode)
460
461   def render(self, cache):
462      srcs = "".join(src.render(cache) for src in self.sources)
463      return srcs + super(Expression, self).render(cache)
464
465class BitSizeValidator(object):
466   """A class for validating bit sizes of expressions.
467
468   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation are
   assigned a type and that type may or may not specify a bit size.  Sources
471   and destinations whose type does not specify a bit size are considered
472   "unsized" and automatically take on the bit size of the corresponding
473   register or SSA value.  NIR has two simple rules for bit sizes that are
474   validated by nir_validator:
475
476    1) A given SSA def or register has a single bit size that is respected by
477       everything that reads from it or writes to it.
478
479    2) The bit sizes of all unsized inputs/outputs on any given ALU
480       instruction must match.  They need not match the sized inputs or
481       outputs but they must match each other.
482
483   In order to keep nir_algebraic relatively simple and easy-to-use,
484   nir_search supports a type of bit-size inference based on the two rules
485   above.  This is similar to type inference in many common programming
486   languages.  If, for instance, you are constructing an add operation and you
487   know the second source is 16-bit, then you know that the other source and
488   the destination must also be 16-bit.  There are, however, cases where this
489   inference can be ambiguous or contradictory.  Consider, for instance, the
490   following transformation:
491
492   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
493
494   This transformation can potentially cause a problem because usub_borrow is
495   well-defined for any bit-size of integer.  However, b2i always generates a
496   32-bit result so it could end up replacing a 64-bit expression with one
497   that takes two 64-bit values and produces a 32-bit value.  As another
498   example, consider this expression:
499
500   (('bcsel', a, b, 0), ('iand', a, b))
501
502   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.
505
506   This class solves that problem by providing a validation layer that proves
507   that a given search-and-replace operation is 100% well-defined before we
508   generate any code.  This ensures that bugs are caught at compile time
509   rather than at run time.
510
511   Each value maintains a "bit-size class", which is either an actual bit size
512   or an equivalence class with other values that must have the same bit size.
513   The validator works by combining bit-size classes with each other according
514   to the NIR rules outlined above, checking that there are no inconsistencies.
515   When doing this for the replacement expression, we make sure to never change
516   the equivalence class of any of the search values. We could make the example
517   transforms above work by doing some extra run-time checking of the search
518   expression, but we make the user specify those constraints themselves, to
519   avoid any surprises. Since the replacement bitsizes can only be connected to
520   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize with anything.
   The former prevents
526
527   (('bcsel', a, b, 0), ('iand', a, b))
528
529   from being allowed, since we'd have to merge the bitsizes for a and b due to
530   the 'iand', while the latter prevents
531
532   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
533
534   from being allowed, since the search expression has the bit size of a and b,
535   which can't be specialized to 32 which is the bitsize of the replace
536   expression. It also prevents something like:
537
538   (('b2i', ('i2b', a)), ('ineq', a, 0))
539
540   since the bitsize of 'b2i', which can be anything, can't be specialized to
541   the bitsize of a.
542
   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
547   needed in nir_search_value so that we know what to do when building the
548   replacement expression.
549   """
550
551   def __init__(self, varset):
552      self._var_classes = [None] * len(varset.names)
553
554   def compare_bitsizes(self, a, b):
555      """Determines which bitsize class is a specialization of the other, or
556      whether neither is. When we merge two different bitsizes, the
557      less-specialized bitsize always points to the more-specialized one, so
558      that calling get_bit_size() always gets you the most specialized bitsize.
559      The specialization partial order is given by:
560      - Physical bitsizes are always the most specialized, and a different
561        bitsize can never specialize another.
562      - In the search expression, variables can always be specialized to each
563        other and to physical bitsizes. In the replace expression, we disallow
564        this to avoid adding extra constraints to the search expression that
565        the user didn't specify.
566      - Expressions and constants without a bitsize can always be specialized to
567        each other and variables, but not the other way around.
568
569        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
570        and None if they are not comparable (neither a <= b nor b <= a).
571      """
572      if isinstance(a, int):
573         if isinstance(b, int):
574            return 0 if a == b else None
575         elif isinstance(b, Variable):
576            return -1 if self.is_search else None
577         else:
578            return -1
579      elif isinstance(a, Variable):
580         if isinstance(b, int):
581            return 1 if self.is_search else None
582         elif isinstance(b, Variable):
583            return 0 if self.is_search or a.index == b.index else None
584         else:
585            return -1
586      else:
587         if isinstance(b, int):
588            return 1
589         elif isinstance(b, Variable):
590            return 1
591         else:
592            return 0
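
   # As a concrete reading of the partial order above: while processing the
   # search expression, compare_bitsizes(32, some_variable) is -1 (the variable
   # can still be specialized to 32 bits), whereas compare_bitsizes(32, 64) is
   # None because two different physical bit sizes can never be unified.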
593
594   def unify_bit_size(self, a, b, error_msg):
595      """Record that a must have the same bit-size as b. If both
596      have been assigned conflicting physical bit-sizes, call "error_msg" with
597      the bit-sizes of self and other to get a message and raise an error.
598      In the replace expression, disallow merging variables with other
599      variables and physical bit-sizes as well.
600      """
601      a_bit_size = a.get_bit_size()
602      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
603
604      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
605
606      assert cmp_result is not None, \
607         error_msg(a_bit_size, b_bit_size)
608
609      if cmp_result < 0:
610         b_bit_size.set_bit_size(a)
611      elif not isinstance(a_bit_size, int):
612         a_bit_size.set_bit_size(b)
613
614   def merge_variables(self, val):
615      """Perform the first part of type inference by merging all the different
616      uses of the same variable. We always do this as if we're in the search
617      expression, even if we're actually not, since otherwise we'd get errors
618      if the search expression specified some constraint but the replace
619      expression didn't, because we'd be merging a variable and a constant.
620      """
621      if isinstance(val, Variable):
622         if self._var_classes[val.index] is None:
623            self._var_classes[val.index] = val
624         else:
625            other = self._var_classes[val.index]
626            self.unify_bit_size(other, val,
627                  lambda other_bit_size, bit_size:
628                     'Variable {} has conflicting bit size requirements: ' \
629                     'it must have bit size {} and {}'.format(
630                        val.var_name, other_bit_size, bit_size))
631      elif isinstance(val, Expression):
632         for src in val.sources:
633            self.merge_variables(src)
634
635   def validate_value(self, val):
636      """Validate the an expression by performing classic Hindley-Milner
637      type inference on bitsizes. This will detect if there are any conflicting
638      requirements, and unify variables so that we know which variables must
639      have the same bitsize. If we're operating on the replace expression, we
640      will refuse to merge different variables together or merge a variable
641      with a constant, in order to prevent surprises due to rules unexpectedly
642      not matching at runtime.
643      """
644      if not isinstance(val, Expression):
645         return
646
647      # Generic conversion ops are special in that they have a single unsized
648      # source and an unsized destination and the two don't have to match.
649      # This means there's no validation or unioning to do here besides the
650      # len(val.sources) check.
651      if val.opcode in conv_opcode_types:
652         assert len(val.sources) == 1, \
653            "Expression {} has {} sources, expected 1".format(
654               val, len(val.sources))
655         self.validate_value(val.sources[0])
656         return
657
658      nir_op = opcodes[val.opcode]
659      assert len(val.sources) == nir_op.num_inputs, \
660         "Expression {} has {} sources, expected {}".format(
661            val, len(val.sources), nir_op.num_inputs)
662
663      for src in val.sources:
664         self.validate_value(src)
665
666      dst_type_bits = type_bits(nir_op.output_type)
667
668      # First, unify all the sources. That way, an error coming up because two
669      # sources have an incompatible bit-size won't produce an error message
670      # involving the destination.
671      first_unsized_src = None
672      for src_type, src in zip(nir_op.input_types, val.sources):
673         src_type_bits = type_bits(src_type)
674         if src_type_bits == 0:
675            if first_unsized_src is None:
676               first_unsized_src = src
677               continue
678
679            if self.is_search:
680               self.unify_bit_size(first_unsized_src, src,
681                  lambda first_unsized_src_bit_size, src_bit_size:
682                     'Source {} of {} must have bit size {}, while source {} ' \
683                     'must have incompatible bit size {}'.format(
684                        first_unsized_src, val, first_unsized_src_bit_size,
685                        src, src_bit_size))
686            else:
687               self.unify_bit_size(first_unsized_src, src,
688                  lambda first_unsized_src_bit_size, src_bit_size:
689                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
690                     'of {} may not have the same bit size when building the ' \
691                     'replacement expression.'.format(
692                        first_unsized_src, first_unsized_src_bit_size, src,
693                        src_bit_size, val))
694         else:
695            if self.is_search:
696               self.unify_bit_size(src, src_type_bits,
697                  lambda src_bit_size, unused:
698                     '{} must have {} bits, but as a source of nir_op_{} '\
699                     'it must have {} bits'.format(
700                        src, src_bit_size, nir_op.name, src_type_bits))
701            else:
702               self.unify_bit_size(src, src_type_bits,
703                  lambda src_bit_size, unused:
704                     '{} has the bit size of {}, but as a source of ' \
705                     'nir_op_{} it must have {} bits, which may not be the ' \
706                     'same'.format(
707                        src, src_bit_size, nir_op.name, src_type_bits))
708
709      if dst_type_bits == 0:
710         if first_unsized_src is not None:
711            if self.is_search:
712               self.unify_bit_size(val, first_unsized_src,
713                  lambda val_bit_size, src_bit_size:
714                     '{} must have the bit size of {}, while its source {} ' \
715                     'must have incompatible bit size {}'.format(
716                        val, val_bit_size, first_unsized_src, src_bit_size))
717            else:
718               self.unify_bit_size(val, first_unsized_src,
719                  lambda val_bit_size, src_bit_size:
720                     '{} must have {} bits, but its source {} ' \
721                     '(bit size of {}) may not have that bit size ' \
722                     'when building the replacement.'.format(
723                        val, val_bit_size, first_unsized_src, src_bit_size))
724      else:
725         self.unify_bit_size(val, dst_type_bits,
726            lambda dst_bit_size, unused:
727               '{} must have {} bits, but as a destination of nir_op_{} ' \
728               'it must have {} bits'.format(
729                  val, dst_bit_size, nir_op.name, dst_type_bits))
730
731   def validate_replace(self, val, search):
732      bit_size = val.get_bit_size()
733      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
734            bit_size == search.get_bit_size(), \
735            'Ambiguous bit size for replacement value {}: ' \
736            'it cannot be deduced from a variable, a fixed bit size ' \
737            'somewhere, or the search expression.'.format(val)
738
739      if isinstance(val, Expression):
740         for src in val.sources:
741            self.validate_replace(src, search)
742
743   def validate(self, search, replace):
744      self.is_search = True
745      self.merge_variables(search)
746      self.merge_variables(replace)
747      self.validate_value(search)
748
749      self.is_search = False
750      self.validate_value(replace)
751
752      # Check that search is always more specialized than replace. Note that
753      # we're doing this in replace mode, disallowing merging variables.
754      search_bit_size = search.get_bit_size()
755      replace_bit_size = replace.get_bit_size()
756      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
757
758      assert cmp_result is not None and cmp_result <= 0, \
759         'The search expression bit size {} and replace expression ' \
760         'bit size {} may not be the same'.format(
761               search_bit_size, replace_bit_size)
762
763      replace.set_bit_size(search)
764
765      self.validate_replace(replace, search)
766
767_optimization_ids = itertools.count()
768
769condition_list = ['true']
770
771class SearchAndReplace(object):
772   def __init__(self, transform, algebraic_pass):
773      self.id = next(_optimization_ids)
774
775      search = transform[0]
776      replace = transform[1]
777      if len(transform) > 2:
778         self.condition = transform[2]
779      else:
780         self.condition = 'true'
781
782      if self.condition not in condition_list:
783         condition_list.append(self.condition)
784      self.condition_index = condition_list.index(self.condition)
785
786      varset = VarSet()
787      if isinstance(search, Expression):
788         self.search = search
789      else:
790         self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
791
792      varset.lock()
793
794      if isinstance(replace, Value):
795         self.replace = replace
796      else:
797         self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
798
799      BitSizeValidator(varset).validate(self.search, self.replace)
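
# For illustration, each transform handed to SearchAndReplace (and ultimately to
# AlgebraicPass) is a (search, replace) or (search, replace, condition) tuple.
# With the usual shorthand a = 'a' used by the driver scripts such as
# nir_opt_algebraic.py, typical transforms look like:
#
#    (('iadd', a, 0), a)
#    (('fabs', ('fneg', a)), ('fabs', a))
#
# where the optional condition is a C expression, evaluated against the
# shader's compiler options, that gates when the transform is applied.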
800
801class TreeAutomaton(object):
802   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFAs and DFAs, where the transition function determines the
805   state of the parent node based on the state of its children. We construct a
806   deterministic automaton to match patterns, using a similar algorithm to the
807   classical NFA to DFA construction. At the moment, it only matches opcodes
808   and constants (without checking the actual value), leaving more detailed
809   checking to the search function which actually checks the leaves. The
810   automaton acts as a quick filter for the search function, requiring only n
811   + 1 table lookups for each n-source operation. The implementation is based
812   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
813   In the language of that reference, this is a frontier-to-root deterministic
814   automaton using only symbol filtering. The filtering is crucial to reduce
815   both the time taken to generate the tables and the size of the tables.
816   """
817   def __init__(self, transforms):
818      self.patterns = [t.search for t in transforms]
819      self._compute_items()
820      self._build_table()
821      #print('num items: {}'.format(len(set(self.items.values()))))
822      #print('num states: {}'.format(len(self.states)))
823      #for state, patterns in zip(self.states, self.patterns):
824      #   print('{}: num patterns: {}'.format(state, len(patterns)))
825
826   class IndexMap(object):
827      """An indexed list of objects, where one can either lookup an object by
828      index or find the index associated to an object quickly using a hash
829      table. Compared to a list, it has a constant time index(). Compared to a
830      set, it provides a stable iteration order.
831      """
832      def __init__(self, iterable=()):
833         self.objects = []
834         self.map = {}
835         for obj in iterable:
836            self.add(obj)
837
838      def __getitem__(self, i):
839         return self.objects[i]
840
841      def __contains__(self, obj):
842         return obj in self.map
843
844      def __len__(self):
845         return len(self.objects)
846
847      def __iter__(self):
848         return iter(self.objects)
849
850      def clear(self):
851         self.objects = []
852         self.map.clear()
853
854      def index(self, obj):
855         return self.map[obj]
856
857      def add(self, obj):
858         if obj in self.map:
859            return self.map[obj]
860         else:
861            index = len(self.objects)
862            self.objects.append(obj)
863            self.map[obj] = index
864            return index
865
866      def __repr__(self):
867         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
868
869   class Item(object):
870      """This represents an "item" in the language of "Tree Automatons." This
871      is just a subtree of some pattern, which represents a potential partial
872      match at runtime. We deduplicate them, so that identical subtrees of
873      different patterns share the same object, and store some extra
874      information needed for the main algorithm as well.
875      """
876      def __init__(self, opcode, children):
877         self.opcode = opcode
878         self.children = children
879         # These are the indices of patterns for which this item is the root node.
880         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed up
882         # filtering.
883         self.parent_ops = set()
884
885      def __str__(self):
886         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
887
888      def __repr__(self):
889         return str(self)
890
891   def _compute_items(self):
892      """Build a set of all possible items, deduplicating them."""
893      # This is a map from (opcode, sources) to item.
894      self.items = {}
895
896      # The set of all opcodes used by the patterns. Used later to avoid
897      # building and emitting all the tables for opcodes that aren't used.
898      self.opcodes = self.IndexMap()
899
900      def get_item(opcode, children, pattern=None):
901         commutative = len(children) >= 2 \
902               and "2src_commutative" in opcodes[opcode].algebraic_properties
903         item = self.items.setdefault((opcode, children),
904                                      self.Item(opcode, children))
905         if commutative:
906            self.items[opcode, (children[1], children[0]) + children[2:]] = item
907         if pattern is not None:
908            item.patterns.append(pattern)
909         return item
910
911      self.wildcard = get_item("__wildcard", ())
912      self.const = get_item("__const", ())
913
914      def process_subpattern(src, pattern=None):
915         if isinstance(src, Constant):
916            # Note: we throw away the actual constant value!
917            return self.const
918         elif isinstance(src, Variable):
919            if src.is_constant:
920               return self.const
921            else:
922               # Note: we throw away which variable it is here! This special
923               # item is equivalent to nu in "Tree Automatons."
924               return self.wildcard
925         else:
926            assert isinstance(src, Expression)
927            opcode = src.opcode
928            stripped = opcode.rstrip('0123456789')
929            if stripped in conv_opcode_types:
930               # Matches that use conversion opcodes with a specific type,
931               # like f2b1, are tricky.  Either we construct the automaton to
932               # match specific NIR opcodes like nir_op_f2b1, in which case we
933               # need to create separate items for each possible NIR opcode
934               # for patterns that have a generic opcode like f2b, or we
935               # construct it to match the search opcode, in which case we
936               # need to map f2b1 to f2b when constructing the automaton. Here
937               # we do the latter.
938               opcode = stripped
939            self.opcodes.add(opcode)
940            children = tuple(process_subpattern(c) for c in src.sources)
941            item = get_item(opcode, children, pattern)
942            for i, child in enumerate(children):
943               child.parent_ops.add(opcode)
944            return item
945
946      for i, pattern in enumerate(self.patterns):
947         process_subpattern(pattern, i)
948
949   def _build_table(self):
950      """This is the core algorithm which builds up the transition table. It
951      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
952      Comp_a and Filt_{a,i} using integers to identify match sets." It
953      simultaneously builds up a list of all possible "match sets" or
954      "states", where each match set represents the set of Item's that match a
955      given instruction, and builds up the transition table between states.
956      """
957      # Map from opcode + filtered state indices to transitioned state.
958      self.table = defaultdict(dict)
959      # Bijection from state to index. q in the original algorithm is
960      # len(self.states)
961      self.states = self.IndexMap()
962      # Lists of pattern matches separated by None
963      self.state_patterns = [None]
964      # Offset in the ->transforms table for each state index
965      self.state_pattern_offsets = []
966      # Map from state index to filtered state index for each opcode.
967      self.filter = defaultdict(list)
968      # Bijections from filtered state to filtered state index for each
969      # opcode, called the "representor sets" in the original algorithm.
970      # q_{a,j} in the original algorithm is len(self.rep[op]).
971      self.rep = defaultdict(self.IndexMap)
972
      # Everything in self.states with an index of at least worklist_index is
      # part of the worklist of newly created states. There is also a worklist
      # of newly filtered states for each opcode, for which worklist_indices
976      # serves a similar purpose. worklist_index corresponds to p in the
977      # original algorithm, while worklist_indices is p_{a,j} (although since
978      # we only filter by opcode/symbol, it's really just p_a).
979      self.worklist_index = 0
980      worklist_indices = defaultdict(lambda: 0)
981
982      # This is the set of opcodes for which the filtered worklist is non-empty.
983      # It's used to avoid scanning opcodes for which there is nothing to
984      # process when building the transition table. It corresponds to new_a in
985      # the original algorithm.
986      new_opcodes = self.IndexMap()
987
988      # Process states on the global worklist, filtering them for each opcode,
989      # updating the filter tables, and updating the filtered worklists if any
990      # new filtered states are found. Similar to ComputeRepresenterSets() in
991      # the original algorithm, although that only processes a single state.
992      def process_new_states():
993         while self.worklist_index < len(self.states):
994            state = self.states[self.worklist_index]
995            # Calculate pattern matches for this state. Each pattern is
996            # assigned to a unique item, so we don't have to worry about
997            # deduplicating them here. However, we do have to sort them so
998            # that they're visited at runtime in the order they're specified
999            # in the source.
1000            patterns = list(sorted(p for item in state for p in item.patterns))
1001
1002            if patterns:
1003                # Add our patterns to the global table.
1004                self.state_pattern_offsets.append(len(self.state_patterns))
1005                self.state_patterns.extend(patterns)
1006                self.state_patterns.append(None)
1007            else:
1008                # Point to the initial sentinel in the global table.
1009                self.state_pattern_offsets.append(0)
1010
1011            # calculate filter table for this state, and update filtered
1012            # worklists.
1013            for op in self.opcodes:
1014               filt = self.filter[op]
1015               rep = self.rep[op]
1016               filtered = frozenset(item for item in state if \
1017                  op in item.parent_ops)
1018               if filtered in rep:
1019                  rep_index = rep.index(filtered)
1020               else:
1021                  rep_index = rep.add(filtered)
1022                  new_opcodes.add(op)
1023               assert len(filt) == self.worklist_index
1024               filt.append(rep_index)
1025            self.worklist_index += 1
1026
1027      # There are two start states: one which can only match as a wildcard,
1028      # and one which can match as a wildcard or constant. These will be the
1029      # states of intrinsics/other instructions and load_const instructions,
1030      # respectively. The indices of these must match the definitions of
1031      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
1032      # initialize things correctly.
1033      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const, self.wildcard)))
1035      process_new_states()
1036
1037      while len(new_opcodes) > 0:
1038         for op in new_opcodes:
1039            rep = self.rep[op]
1040            table = self.table[op]
1041            op_worklist_index = worklist_indices[op]
1042            if op in conv_opcode_types:
1043               num_srcs = 1
1044            else:
1045               num_srcs = opcodes[op].num_inputs
1046
1047            # Iterate over all possible source combinations where at least one
1048            # is on the worklist.
1049            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
1050               if all(src_idx < op_worklist_index for src_idx in src_indices):
1051                  continue
1052
1053               srcs = tuple(rep[src_idx] for src_idx in src_indices)
1054
1055               # Try all possible pairings of source items and add the
1056               # corresponding parent items. This is Comp_a from the paper.
1057               parent = set(self.items[op, item_srcs] for item_srcs in
1058                  itertools.product(*srcs) if (op, item_srcs) in self.items)
1059
1060               # We could always start matching something else with a
1061               # wildcard. This is Cl from the paper.
1062               parent.add(self.wildcard)
1063
1064               table[src_indices] = self.states.add(frozenset(parent))
1065            worklist_indices[op] = len(rep)
1066         new_opcodes.clear()
1067         process_new_states()
1068
1069_algebraic_pass_template = mako.template.Template("""
1070#include "nir.h"
1071#include "nir_builder.h"
1072#include "nir_search.h"
1073#include "nir_search_helpers.h"
1074
1075/* What follows is NIR algebraic transform code for the following ${len(xforms)}
1076 * transforms:
1077% for xform in xforms:
1078 *    ${xform.search} => ${xform.replace}
1079% endfor
1080 */
1081
1082<% cache = {"next_index": 0} %>
1083static const nir_search_value_union ${pass_name}_values[] = {
1084% for xform in xforms:
1085   /* ${xform.search} => ${xform.replace} */
1086${xform.search.render(cache)}
1087${xform.replace.render(cache)}
1088% endfor
1089};
1090
1091% if expression_cond:
1092static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
1093% for cond in expression_cond:
1094   ${cond[0]},
1095% endfor
1096};
1097% endif
1098
1099% if variable_cond:
1100static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
1101% for cond in variable_cond:
1102   ${cond[0]},
1103% endfor
1104};
1105% endif
1106
1107static const struct transform ${pass_name}_transforms[] = {
1108% for i in automaton.state_patterns:
1109% if i is not None:
1110   { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
1111% else:
1112   { ~0, ~0, ~0 }, /* Sentinel */
1113
1114% endif
1115% endfor
1116};
1117
1118static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = {
1119% for op in automaton.opcodes:
1120   [${get_c_opcode(op)}] = {
1121% if all(e == 0 for e in automaton.filter[op]):
1122      .filter = NULL,
1123% else:
1124      .filter = (const uint16_t []) {
1125      % for e in automaton.filter[op]:
1126         ${e},
1127      % endfor
1128      },
1129% endif
1130      <%
1131        num_filtered = len(automaton.rep[op])
1132      %>
1133      .num_filtered_states = ${num_filtered},
1134      .table = (const uint16_t []) {
1135      <%
1136        num_srcs = len(next(iter(automaton.table[op])))
1137      %>
1138      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
1139         ${automaton.table[op][indices]},
1140      % endfor
1141      },
1142   },
1143% endfor
1144};
1145
1146/* Mapping from state index to offset in transforms (0 being no transforms) */
1147static const uint16_t ${pass_name}_transform_offsets[] = {
1148% for offset in automaton.state_pattern_offsets:
1149   ${offset},
1150% endfor
1151};
1152
1153static const nir_algebraic_table ${pass_name}_table = {
1154   .transforms = ${pass_name}_transforms,
1155   .transform_offsets = ${pass_name}_transform_offsets,
1156   .pass_op_table = ${pass_name}_pass_op_table,
1157   .values = ${pass_name}_values,
1158   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
1159   .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" },
1160};
1161
1162bool
1163${pass_name}(nir_shader *shader)
1164{
1165   bool progress = false;
1166   bool condition_flags[${len(condition_list)}];
1167   const nir_shader_compiler_options *options = shader->options;
1168   const shader_info *info = &shader->info;
1169   (void) options;
1170   (void) info;
1171
1172   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
1173   % for index, condition in enumerate(condition_list):
1174   condition_flags[${index}] = ${condition};
1175   % endfor
1176
1177   nir_foreach_function(function, shader) {
1178      if (function->impl) {
1179         progress |= nir_algebraic_impl(function->impl, condition_flags,
1180                                        &${pass_name}_table);
1181      }
1182   }
1183
1184   return progress;
1185}
1186""")
1187
1188
1189class AlgebraicPass(object):
1190   def __init__(self, pass_name, transforms):
1191      self.xforms = []
1192      self.opcode_xforms = defaultdict(lambda : [])
1193      self.pass_name = pass_name
1194      self.expression_cond = {}
1195      self.variable_cond = {}
1196
1197      error = False
1198
1199      for xform in transforms:
1200         if not isinstance(xform, SearchAndReplace):
1201            try:
1202               xform = SearchAndReplace(xform, self)
1203            except:
1204               print("Failed to parse transformation:", file=sys.stderr)
1205               print("  " + str(xform), file=sys.stderr)
1206               traceback.print_exc(file=sys.stderr)
1207               print('', file=sys.stderr)
1208               error = True
1209               continue
1210
1211         self.xforms.append(xform)
1212         if xform.search.opcode in conv_opcode_types:
1213            dst_type = conv_opcode_types[xform.search.opcode]
1214            for size in type_sizes(dst_type):
1215               sized_opcode = xform.search.opcode + str(size)
1216               self.opcode_xforms[sized_opcode].append(xform)
1217         else:
1218            self.opcode_xforms[xform.search.opcode].append(xform)
1219
1220         # Check to make sure the search pattern does not unexpectedly contain
1221         # more commutative expressions than match_expression (nir_search.c)
1222         # can handle.
1223         comm_exprs = xform.search.comm_exprs
1224
1225         if xform.search.many_commutative_expressions:
1226            if comm_exprs <= nir_search_max_comm_ops:
               print("Transform expected to have too many commutative " \
                     "expressions but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
1230                     file=sys.stderr)
1231               print("  " + str(xform), file=sys.stderr)
1232               traceback.print_exc(file=sys.stderr)
1233               print('', file=sys.stderr)
1234               error = True
1235         else:
1236            if comm_exprs > nir_search_max_comm_ops:
1237               print("Transformation with too many commutative expressions " \
1238                     "({} > {}).  Modify pattern or annotate with " \
1239                     "\"many-comm-expr\".".format(comm_exprs,
1240                                                  nir_search_max_comm_ops),
1241                     file=sys.stderr)
1242               print("  " + str(xform.search), file=sys.stderr)
1243               print("{}".format(xform.search.cond), file=sys.stderr)
1244               error = True
1245
1246      self.automaton = TreeAutomaton(self.xforms)
1247
1248      if error:
1249         sys.exit(1)
1250
1251
1252   def render(self):
1253      return _algebraic_pass_template.render(pass_name=self.pass_name,
1254                                             xforms=self.xforms,
1255                                             opcode_xforms=self.opcode_xforms,
1256                                             condition_list=condition_list,
1257                                             automaton=self.automaton,
1258                                             expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
1259                                             variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
1260                                             get_c_opcode=get_c_opcode,
1261                                             itertools=itertools)
1262
1263# The replacement expression isn't necessarily exact if the search expression is exact.
1264def ignore_exact(*expr):
1265   expr = SearchExpression.create(expr)
1266   expr.ignore_exact = True
1267   return expr
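
# Typical usage, as a minimal sketch only: the real driver scripts (such as
# nir_opt_algebraic.py) build much larger transform lists and run this module
# at build time to emit C code that is then compiled into the driver.  The
# pass name "nir_opt_simple" below is purely illustrative.
#
#    import nir_algebraic
#
#    a = 'a'
#    optimizations = [
#       (('iadd', a, 0), a),
#    ]
#
#    print(nir_algebraic.AlgebraicPass("nir_opt_simple", optimizations).render())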
1268