1# 2# Copyright (C) 2014 Intel Corporation 3# 4# Permission is hereby granted, free of charge, to any person obtaining a 5# copy of this software and associated documentation files (the "Software"), 6# to deal in the Software without restriction, including without limitation 7# the rights to use, copy, modify, merge, publish, distribute, sublicense, 8# and/or sell copies of the Software, and to permit persons to whom the 9# Software is furnished to do so, subject to the following conditions: 10# 11# The above copyright notice and this permission notice (including the next 12# paragraph) shall be included in all copies or substantial portions of the 13# Software. 14# 15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21# IN THE SOFTWARE. 22# 23# Authors: 24# Jason Ekstrand (jason@jlekstrand.net) 25 26import ast 27from collections import defaultdict 28import itertools 29import struct 30import sys 31import mako.template 32import re 33import traceback 34 35from nir_opcodes import opcodes, type_sizes 36 37# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c 38nir_search_max_comm_ops = 8 39 40# These opcodes are only employed by nir_search. This provides a mapping from 41# opcode to destination type. 
# Map from a nir_search-only conversion opcode to the nir_alu_type name of its
# destination.  These opcodes have no fixed destination bit size.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_cond_index(conds, cond):
    """Return the index of cond in the dict conds, assigning a fresh index if
    it has not been seen before.  A false-y cond (no condition) maps to -1.
    """
    if cond:
        if cond in conds:
            return conds[cond]
        else:
            cond_index = len(conds)
            conds[cond] = cond_index
            return cond_index
    else:
        return -1

def get_c_opcode(op):
    """Return the C enum name for op: search-only conversion opcodes become
    nir_search_op_*, everything else a real nir_op_*.
    """
    if op in conv_opcode_types:
        return 'nir_search_op_' + op
    else:
        return 'nir_op_' + op

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
    """Return the bit size encoded in a type string such as "float32", or 0 if
    the type is unsized (e.g. just "int").
    """
    m = _type_re.match(type_str)
    assert m.group('type')

    if m.group('bits') is None:
        return 0
    else:
        return int(m.group('bits'))

# Represents a set of variables, each with a unique id
class VarSet(object):
    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False;

    def __getitem__(self, name):
        # Looking up an unknown name allocates a new id, unless the set has
        # been locked (i.e. we are in the replace expression, where every
        # variable must already appear in the search expression).
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        self.immutable = True

class SearchExpression(object):
    """Thin wrapper around the tuple form of a search pattern: the first tuple
    element is the opcode string and the rest are the sources.
    """
    def __init__(self, expr):
        self.opcode = expr[0]
        self.sources = expr[1:]
        self.ignore_exact = False

    @staticmethod
    def create(val):
        # Accept either a raw tuple or an already-wrapped SearchExpression.
        if isinstance(val, tuple):
            return SearchExpression(val)
        else:
            assert(isinstance(val, SearchExpression))
            return val

    def __repr__(self):
        l = [self.opcode, *self.sources]
        if self.ignore_exact:
            l.append('ignore_exact')
        return repr((*l,))

class Value(object):
    @staticmethod
    def create(val, name_base, varset, algebraic_pass):
        """Factory that turns the Python literal forms used in transform
        tuples into the corresponding Value subclass: tuples become
        Expressions, strings become Variables, and bool/float/int become
        Constants.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple) or isinstance(val, SearchExpression):
            return Expression(val, name_base, varset, algebraic_pass)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, str):
            return Variable(val, name_base, varset, algebraic_pass)
        elif isinstance(val, (bool, float, int)):
            return Constant(val, name_base)

    def __init__(self, val, name, type_str):
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class.  Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable.  We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression.  This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        # Walk the chain of class representatives until we hit either a
        # physical bit size (an int) or a value with no parent.
        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Path compression: point directly at the representative we found.
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() returned
        before calling this, or just "other" if it's a concrete bit-size.  This
        is the "union" part of the union-find algorithm.
        """

        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        # C enum tag for this value kind (constant/variable/expression).
        return "nir_search_value_" + self.type_str

    @property
    def c_bit_size(self):
        """Encode the resolved bit-size class for the generated C tables:
        a positive value is a physical bit size, a negative value encodes a
        variable index (-index - 1), and 0 means "nothing to check".
        """
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            #   case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    # mako template emitting one C initializer for this value; the branch taken
    # depends on which Value subclass is being rendered.
    __template = mako.template.Template("""   { .${val.type_str} = {
      { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
      ${val.index}, /* ${val.var_name} */
      ${'true' if val.is_constant else 'false'},
      ${val.type() or 'nir_type_invalid' },
      ${val.cond_index},
      ${val.swizzle()},
% elif isinstance(val, Expression):
      ${'true' if val.inexact else 'false'},
      ${'true' if val.exact else 'false'},
      ${'true' if val.ignore_exact else 'false'},
      ${val.c_opcode()},
      ${val.comm_expr_idx}, ${val.comm_exprs},
      { ${', '.join(src.array_index for src in val.sources)} },
      ${val.cond_index},
% endif
   } },
""")

    def render(self, cache):
        """Render this value's C initializer, deduplicating identical
        initializers through the cache so that each unique struct is emitted
        (and indexed) only once.
        """
        struct_init = self.__template.render(val=self,
                                             Constant=Constant,
                                             Variable=Variable,
                                             Expression=Expression)
        if struct_init in cache:
            # If it's in the cache, register a name remap in the cache and render
            # only a comment saying it's been remapped
            self.array_index = cache[struct_init]
            return "   /* {} -> {} in the cache */\n".format(self.name,
                                                             cache[struct_init])
        else:
            self.array_index = str(cache["next_index"])
            cache[struct_init] = self.array_index
            cache["next_index"] += 1
            return struct_init

_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        if isinstance(val, (str)):
            # String form may carry an explicit bit size, e.g. "0x3f@32".
            m = _constant_re.match(val)
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        if isinstance(self.value, bool):
            # Booleans are always 1-bit in nir_search.
            assert self._bit_size is None or self._bit_size == 1
            self._bit_size = 1

    def hex(self):
        """Return the C literal spelling of this constant's raw bits."""
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, int):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Reinterpret the double's bit pattern as a 64-bit integer.
            return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
        else:
            assert False

    def type(self):
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, int):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two constants are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for the __eq__
        will break BitSizeValidator.

        """
        if not isinstance(other, type(self)):
            return False

        return self.value == other.value

# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?"
                          r"$")

class Variable(Value):
    def __init__(self, val, name, varset, algebraic_pass):
        Value.__init__(self, val, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

        self.var_name = m.group('name')

        # Prevent common cases where someone puts quotes around a literal
        # constant.  If we want to support names that have numeric or
        # punctuation characters, we can make the first assertion more flexible.
        assert self.var_name.isalpha()
        assert self.var_name != 'True'
        assert self.var_name != 'False'

        # '#' prefix means the variable must match a constant at runtime.
        self.is_constant = m.group('const') is not None
        self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
        self.required_type = m.group('type')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.swiz = m.group('swiz')

        if self.required_type == 'bool':
            if self._bit_size is not None:
                assert self._bit_size in type_sizes(self.required_type)
            else:
                # Unsized bools default to 1-bit.
                self._bit_size = 1

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        # Variables with the same name within a transform share an index.
        self.index = varset[self.var_name]

    def type(self):
        if self.required_type == 'bool':
            return "nir_type_bool"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two variables are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for the __eq__
        will break BitSizeValidator.

        """
        if not isinstance(other, type(self)):
            return False

        return self.index == other.index

    def swizzle(self):
        """Return the C initializer for this variable's swizzle, defaulting to
        the identity swizzle when none was specified.
        """
        if self.swiz is not None:
            swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                        'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                        'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                        'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                        'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
            return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
        return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
    def __init__(self, expr, name_base, varset, algebraic_pass):
        Value.__init__(self, expr, name_base, "expression")

        expr = SearchExpression.create(expr)

        m = _opcode_re.match(expr.opcode)
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.inexact = m.group('inexact') is not None
        self.exact = m.group('exact') is not None
        self.ignore_exact = expr.ignore_exact
        self.cond = m.group('cond')

        assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

        # "many-comm-expr" isn't really a condition.  It's notification to the
        # generator that this pattern is known to have too many commutative
        # expressions, and an error should not be generated for this case.
        self.many_commutative_expressions = False
        if self.cond and self.cond.find("many-comm-expr") >= 0:
            # Split the condition into a comma-separated list.  Remove
            # "many-comm-expr".  If there is anything left, put it back together.
            c = self.cond[1:-1].split(",")
            c.remove("many-comm-expr")
            assert(len(c) <= 1)

            self.cond = c[0] if c else None
            self.many_commutative_expressions = True

        # Deduplicate references to the condition functions for the expressions
        # and save the index for the order they were added.
        self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)

        self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
                         for (i, src) in enumerate(expr.sources) ]

        # nir_search_expression::srcs is hard-coded to 4
        assert len(self.sources) <= 4

        if self.opcode in conv_opcode_types:
            assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

        self.__index_comm_exprs(0)

    def equivalent(self, other):
        """Check that two expressions are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for the __eq__
        will break BitSizeValidator.

        This implementation does not check for equivalence due to commutativity,
        but it could.

        """
        if not isinstance(other, type(self)):
            return False

        if len(self.sources) != len(other.sources):
            return False

        if self.opcode != other.opcode:
            return False

        return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

    def __index_comm_exprs(self, base_idx):
        """Recursively count and index commutative expressions
        """
        self.comm_exprs = 0

        # A note about the explicit "len(self.sources)" check.  The list of
        # sources comes from user input, and that input might be bad.  Check
        # that the expected second source exists before accessing it.  Without
        # this check, a unit test that does "('iadd', 'a')" will crash.
        if self.opcode not in conv_opcode_types and \
           "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
           len(self.sources) >= 2 and \
           not self.sources[0].equivalent(self.sources[1]):
            self.comm_expr_idx = base_idx
            self.comm_exprs += 1
        else:
            self.comm_expr_idx = -1

        for s in self.sources:
            if isinstance(s, Expression):
                s.__index_comm_exprs(base_idx + self.comm_exprs)
                self.comm_exprs += s.comm_exprs

        return self.comm_exprs

    def c_opcode(self):
        return get_c_opcode(self.opcode)

    def render(self, cache):
        # Sources must be rendered (and assigned array indices) before the
        # expression that refers to them.
        srcs = "".join(src.render(cache) for src in self.sources)
        return srcs + super(Expression, self).render(cache)

class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.
    If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values.  We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises.  Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything.  The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression.  It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression.  Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # One union-find class slot per distinct variable index.
        self._var_classes = [None] * len(varset.names)

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is.  When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes.  In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b.  If both
        have been assigned conflicting physical bit-sizes, call "error_msg" with
        the bit-sizes of self and other to get a message and raise an error.
        In the replace expression, disallow merging variables with other
        variables and physical bit-sizes as well.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
            error_msg(a_bit_size, b_bit_size)

        # Point the less-specialized class at the more-specialized one.
        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable.  We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                    lambda other_bit_size, bit_size:
                        'Variable {} has conflicting bit size requirements: ' \
                        'it must have bit size {} and {}'.format(
                            val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate the an expression by performing classic Hindley-Milner
        type inference on bitsizes.  This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize.  If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        # Generic conversion ops are special in that they have a single unsized
        # source and an unsized destination and the two don't have to match.
        # This means there's no validation or unioning to do here besides the
        # len(val.sources) check.
        if val.opcode in conv_opcode_types:
            assert len(val.sources) == 1, \
                "Expression {} has {} sources, expected 1".format(
                    val, len(val.sources))
            self.validate_value(val.sources[0])
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources.  That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} ' \
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                            'of {} may not have the same bit size when building the ' \
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '\
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of ' \
                            'nir_op_{} it must have {} bits, which may not be the ' \
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        # Now unify the destination with either the first unsized source or the
        # opcode's fixed destination size.
        if dst_type_bits == 0:
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} ' \
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} ' \
                            '(bit size of {}) may not have that bit size ' \
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} ' \
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        # Every replacement value must have a bit size deducible from a fixed
        # size, a variable, or the search expression's root.
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        # Entry point: run inference on the search expression first, then the
        # replacement, then check the two roots against each other.
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace.  Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        # Tie the replacement's root bit-size class to the search root so the
        # generated code knows what size to build.
        replace.set_bit_size(search)

        self.validate_replace(replace, search)

_optimization_ids = itertools.count()

condition_list = ['true']

class SearchAndReplace(object):
    def __init__(self, transform, algebraic_pass):
        # Unique id per transform; used to name the generated C arrays.
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        # Optional third tuple element is a C condition expression.
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)

        # Lock the varset so the replacement may only use variables that
        # already appeared in the search expression.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)

        BitSizeValidator(varset).validate(self.search, self.replace)

class TreeAutomaton(object):
    """This class calculates a bottom-up tree automaton to quickly search for
    the left-hand sides of transforms.  Tree automatons are a generalization of
    classical NFA's and DFA's, where the transition function determines the
    state of the parent node based on the state of its children.  We construct a
    deterministic automaton to match patterns, using a similar algorithm to the
    classical NFA to DFA construction.  At the moment, it only matches opcodes
    and constants (without checking the actual value), leaving more detailed
    checking to the search function which actually checks the leaves.  The
    automaton acts as a quick filter for the search function, requiring only n
    + 1 table lookups for each n-source operation.  The implementation is based
    on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
    In the language of that reference, this is a frontier-to-root deterministic
    automaton using only symbol filtering.  The filtering is crucial to reduce
    both the time taken to generate the tables and the size of the tables.
    """
    def __init__(self, transforms):
        self.patterns = [t.search for t in transforms]
        self._compute_items()
        self._build_table()
        #print('num items: {}'.format(len(set(self.items.values()))))
        #print('num states: {}'.format(len(self.states)))
        #for state, patterns in zip(self.states, self.patterns):
        #   print('{}: num patterns: {}'.format(state, len(patterns)))

    class IndexMap(object):
        """An indexed list of objects, where one can either lookup an object by
        index or find the index associated to an object quickly using a hash
        table.  Compared to a list, it has a constant time index().  Compared to a
        set, it provides a stable iteration order.
        """
        def __init__(self, iterable=()):
            self.objects = []
            self.map = {}
            for obj in iterable:
                self.add(obj)

        def __getitem__(self, i):
            return self.objects[i]

        def __contains__(self, obj):
            return obj in self.map

        def __len__(self):
            return len(self.objects)

        def __iter__(self):
            return iter(self.objects)

        def clear(self):
            self.objects = []
            self.map.clear()

        def index(self, obj):
            return self.map[obj]

        def add(self, obj):
            # Idempotent: re-adding an existing object returns its index.
            if obj in self.map:
                return self.map[obj]
            else:
                index = len(self.objects)
                self.objects.append(obj)
                self.map[obj] = index
                return index

        def __repr__(self):
            return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

    class Item(object):
        """This represents an "item" in the language of "Tree Automatons."  This
        is just a subtree of some pattern, which represents a potential partial
        match at runtime.  We deduplicate them, so that identical subtrees of
        different patterns share the same object, and store some extra
        information needed for the main algorithm as well.
        """
        def __init__(self, opcode, children):
            self.opcode = opcode
            self.children = children
            # These are the indices of patterns for which this item is the root node.
            self.patterns = []
            # This the set of opcodes for parents of this item.  Used to speed up
            # filtering.
            self.parent_ops = set()

        def __str__(self):
            return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

        def __repr__(self):
            return str(self)

    def _compute_items(self):
        """Build a set of all possible items, deduplicating them."""
        # This is a map from (opcode, sources) to item.
        self.items = {}

        # The set of all opcodes used by the patterns.  Used later to avoid
        # building and emitting all the tables for opcodes that aren't used.
        self.opcodes = self.IndexMap()

        def get_item(opcode, children, pattern=None):
            commutative = len(children) >= 2 \
                and "2src_commutative" in opcodes[opcode].algebraic_properties
            item = self.items.setdefault((opcode, children),
                                         self.Item(opcode, children))
            if commutative:
                # Register the swapped-source form as the same item.
                self.items[opcode, (children[1], children[0]) + children[2:]] = item
            if pattern is not None:
                item.patterns.append(pattern)
            return item

        self.wildcard = get_item("__wildcard", ())
        self.const = get_item("__const", ())

        def process_subpattern(src, pattern=None):
            if isinstance(src, Constant):
                # Note: we throw away the actual constant value!
                return self.const
            elif isinstance(src, Variable):
                if src.is_constant:
                    return self.const
                else:
                    # Note: we throw away which variable it is here! This special
                    # item is equivalent to nu in "Tree Automatons."
                    return self.wildcard
            else:
                assert isinstance(src, Expression)
                opcode = src.opcode
                stripped = opcode.rstrip('0123456789')
                if stripped in conv_opcode_types:
                    # Matches that use conversion opcodes with a specific type,
                    # like f2b1, are tricky.  Either we construct the automaton to
                    # match specific NIR opcodes like nir_op_f2b1, in which case we
                    # need to create separate items for each possible NIR opcode
                    # for patterns that have a generic opcode like f2b, or we
                    # construct it to match the search opcode, in which case we
                    # need to map f2b1 to f2b when constructing the automaton.  Here
                    # we do the latter.
                    opcode = stripped
                self.opcodes.add(opcode)
                children = tuple(process_subpattern(c) for c in src.sources)
                item = get_item(opcode, children, pattern)
                for i, child in enumerate(children):
                    child.parent_ops.add(opcode)
                return item

        for i, pattern in enumerate(self.patterns):
            process_subpattern(pattern, i)

    def _build_table(self):
        """This is the core algorithm which builds up the transition table.  It
        is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
        Comp_a and Filt_{a,i} using integers to identify match sets."  It
        simultaneously builds up a list of all possible "match sets" or
        "states", where each match set represents the set of Item's that match a
        given instruction, and builds up the transition table between states.
        """
        # Map from opcode + filtered state indices to transitioned state.
        self.table = defaultdict(dict)
        # Bijection from state to index.  q in the original algorithm is
        # len(self.states)
        self.states = self.IndexMap()
        # Lists of pattern matches separated by None
        self.state_patterns = [None]
        # Offset in the ->transforms table for each state index
        self.state_pattern_offsets = []
        # Map from state index to filtered state index for each opcode.
        self.filter = defaultdict(list)
        # Bijections from filtered state to filtered state index for each
        # opcode, called the "representor sets" in the original algorithm.
        # q_{a,j} in the original algorithm is len(self.rep[op]).
        self.rep = defaultdict(self.IndexMap)

        # Everything in self.states with a index at least worklist_index is part
        # of the worklist of newly created states.  There is also a worklist of
        # newly filtered states for each opcode, for which worklist_indices
        # serves a similar purpose.
worklist_index corresponds to p in the 977 # original algorithm, while worklist_indices is p_{a,j} (although since 978 # we only filter by opcode/symbol, it's really just p_a). 979 self.worklist_index = 0 980 worklist_indices = defaultdict(lambda: 0) 981 982 # This is the set of opcodes for which the filtered worklist is non-empty. 983 # It's used to avoid scanning opcodes for which there is nothing to 984 # process when building the transition table. It corresponds to new_a in 985 # the original algorithm. 986 new_opcodes = self.IndexMap() 987 988 # Process states on the global worklist, filtering them for each opcode, 989 # updating the filter tables, and updating the filtered worklists if any 990 # new filtered states are found. Similar to ComputeRepresenterSets() in 991 # the original algorithm, although that only processes a single state. 992 def process_new_states(): 993 while self.worklist_index < len(self.states): 994 state = self.states[self.worklist_index] 995 # Calculate pattern matches for this state. Each pattern is 996 # assigned to a unique item, so we don't have to worry about 997 # deduplicating them here. However, we do have to sort them so 998 # that they're visited at runtime in the order they're specified 999 # in the source. 1000 patterns = list(sorted(p for item in state for p in item.patterns)) 1001 1002 if patterns: 1003 # Add our patterns to the global table. 1004 self.state_pattern_offsets.append(len(self.state_patterns)) 1005 self.state_patterns.extend(patterns) 1006 self.state_patterns.append(None) 1007 else: 1008 # Point to the initial sentinel in the global table. 1009 self.state_pattern_offsets.append(0) 1010 1011 # calculate filter table for this state, and update filtered 1012 # worklists. 
1013 for op in self.opcodes: 1014 filt = self.filter[op] 1015 rep = self.rep[op] 1016 filtered = frozenset(item for item in state if \ 1017 op in item.parent_ops) 1018 if filtered in rep: 1019 rep_index = rep.index(filtered) 1020 else: 1021 rep_index = rep.add(filtered) 1022 new_opcodes.add(op) 1023 assert len(filt) == self.worklist_index 1024 filt.append(rep_index) 1025 self.worklist_index += 1 1026 1027 # There are two start states: one which can only match as a wildcard, 1028 # and one which can match as a wildcard or constant. These will be the 1029 # states of intrinsics/other instructions and load_const instructions, 1030 # respectively. The indices of these must match the definitions of 1031 # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 1032 # initialize things correctly. 1033 self.states.add(frozenset((self.wildcard,))) 1034 self.states.add(frozenset((self.const,self.wildcard))) 1035 process_new_states() 1036 1037 while len(new_opcodes) > 0: 1038 for op in new_opcodes: 1039 rep = self.rep[op] 1040 table = self.table[op] 1041 op_worklist_index = worklist_indices[op] 1042 if op in conv_opcode_types: 1043 num_srcs = 1 1044 else: 1045 num_srcs = opcodes[op].num_inputs 1046 1047 # Iterate over all possible source combinations where at least one 1048 # is on the worklist. 1049 for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 1050 if all(src_idx < op_worklist_index for src_idx in src_indices): 1051 continue 1052 1053 srcs = tuple(rep[src_idx] for src_idx in src_indices) 1054 1055 # Try all possible pairings of source items and add the 1056 # corresponding parent items. This is Comp_a from the paper. 1057 parent = set(self.items[op, item_srcs] for item_srcs in 1058 itertools.product(*srcs) if (op, item_srcs) in self.items) 1059 1060 # We could always start matching something else with a 1061 # wildcard. This is Cl from the paper. 
1062 parent.add(self.wildcard) 1063 1064 table[src_indices] = self.states.add(frozenset(parent)) 1065 worklist_indices[op] = len(rep) 1066 new_opcodes.clear() 1067 process_new_states() 1068 1069_algebraic_pass_template = mako.template.Template(""" 1070#include "nir.h" 1071#include "nir_builder.h" 1072#include "nir_search.h" 1073#include "nir_search_helpers.h" 1074 1075/* What follows is NIR algebraic transform code for the following ${len(xforms)} 1076 * transforms: 1077% for xform in xforms: 1078 * ${xform.search} => ${xform.replace} 1079% endfor 1080 */ 1081 1082<% cache = {"next_index": 0} %> 1083static const nir_search_value_union ${pass_name}_values[] = { 1084% for xform in xforms: 1085 /* ${xform.search} => ${xform.replace} */ 1086${xform.search.render(cache)} 1087${xform.replace.render(cache)} 1088% endfor 1089}; 1090 1091% if expression_cond: 1092static const nir_search_expression_cond ${pass_name}_expression_cond[] = { 1093% for cond in expression_cond: 1094 ${cond[0]}, 1095% endfor 1096}; 1097% endif 1098 1099% if variable_cond: 1100static const nir_search_variable_cond ${pass_name}_variable_cond[] = { 1101% for cond in variable_cond: 1102 ${cond[0]}, 1103% endfor 1104}; 1105% endif 1106 1107static const struct transform ${pass_name}_transforms[] = { 1108% for i in automaton.state_patterns: 1109% if i is not None: 1110 { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} }, 1111% else: 1112 { ~0, ~0, ~0 }, /* Sentinel */ 1113 1114% endif 1115% endfor 1116}; 1117 1118static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = { 1119% for op in automaton.opcodes: 1120 [${get_c_opcode(op)}] = { 1121% if all(e == 0 for e in automaton.filter[op]): 1122 .filter = NULL, 1123% else: 1124 .filter = (const uint16_t []) { 1125 % for e in automaton.filter[op]: 1126 ${e}, 1127 % endfor 1128 }, 1129% endif 1130 <% 1131 num_filtered = len(automaton.rep[op]) 1132 %> 1133 .num_filtered_states = 
${num_filtered}, 1134 .table = (const uint16_t []) { 1135 <% 1136 num_srcs = len(next(iter(automaton.table[op]))) 1137 %> 1138 % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1139 ${automaton.table[op][indices]}, 1140 % endfor 1141 }, 1142 }, 1143% endfor 1144}; 1145 1146/* Mapping from state index to offset in transforms (0 being no transforms) */ 1147static const uint16_t ${pass_name}_transform_offsets[] = { 1148% for offset in automaton.state_pattern_offsets: 1149 ${offset}, 1150% endfor 1151}; 1152 1153static const nir_algebraic_table ${pass_name}_table = { 1154 .transforms = ${pass_name}_transforms, 1155 .transform_offsets = ${pass_name}_transform_offsets, 1156 .pass_op_table = ${pass_name}_pass_op_table, 1157 .values = ${pass_name}_values, 1158 .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" }, 1159 .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" }, 1160}; 1161 1162bool 1163${pass_name}(nir_shader *shader) 1164{ 1165 bool progress = false; 1166 bool condition_flags[${len(condition_list)}]; 1167 const nir_shader_compiler_options *options = shader->options; 1168 const shader_info *info = &shader->info; 1169 (void) options; 1170 (void) info; 1171 1172 STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values)); 1173 % for index, condition in enumerate(condition_list): 1174 condition_flags[${index}] = ${condition}; 1175 % endfor 1176 1177 nir_foreach_function(function, shader) { 1178 if (function->impl) { 1179 progress |= nir_algebraic_impl(function->impl, condition_flags, 1180 &${pass_name}_table); 1181 } 1182 } 1183 1184 return progress; 1185} 1186""") 1187 1188 1189class AlgebraicPass(object): 1190 def __init__(self, pass_name, transforms): 1191 self.xforms = [] 1192 self.opcode_xforms = defaultdict(lambda : []) 1193 self.pass_name = pass_name 1194 self.expression_cond = {} 1195 self.variable_cond = {} 1196 1197 error = False 1198 1199 for xform in 
transforms: 1200 if not isinstance(xform, SearchAndReplace): 1201 try: 1202 xform = SearchAndReplace(xform, self) 1203 except: 1204 print("Failed to parse transformation:", file=sys.stderr) 1205 print(" " + str(xform), file=sys.stderr) 1206 traceback.print_exc(file=sys.stderr) 1207 print('', file=sys.stderr) 1208 error = True 1209 continue 1210 1211 self.xforms.append(xform) 1212 if xform.search.opcode in conv_opcode_types: 1213 dst_type = conv_opcode_types[xform.search.opcode] 1214 for size in type_sizes(dst_type): 1215 sized_opcode = xform.search.opcode + str(size) 1216 self.opcode_xforms[sized_opcode].append(xform) 1217 else: 1218 self.opcode_xforms[xform.search.opcode].append(xform) 1219 1220 # Check to make sure the search pattern does not unexpectedly contain 1221 # more commutative expressions than match_expression (nir_search.c) 1222 # can handle. 1223 comm_exprs = xform.search.comm_exprs 1224 1225 if xform.search.many_commutative_expressions: 1226 if comm_exprs <= nir_search_max_comm_ops: 1227 print("Transform expected to have too many commutative " \ 1228 "expression but did not " \ 1229 "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1230 file=sys.stderr) 1231 print(" " + str(xform), file=sys.stderr) 1232 traceback.print_exc(file=sys.stderr) 1233 print('', file=sys.stderr) 1234 error = True 1235 else: 1236 if comm_exprs > nir_search_max_comm_ops: 1237 print("Transformation with too many commutative expressions " \ 1238 "({} > {}). 
Modify pattern or annotate with " \ 1239 "\"many-comm-expr\".".format(comm_exprs, 1240 nir_search_max_comm_ops), 1241 file=sys.stderr) 1242 print(" " + str(xform.search), file=sys.stderr) 1243 print("{}".format(xform.search.cond), file=sys.stderr) 1244 error = True 1245 1246 self.automaton = TreeAutomaton(self.xforms) 1247 1248 if error: 1249 sys.exit(1) 1250 1251 1252 def render(self): 1253 return _algebraic_pass_template.render(pass_name=self.pass_name, 1254 xforms=self.xforms, 1255 opcode_xforms=self.opcode_xforms, 1256 condition_list=condition_list, 1257 automaton=self.automaton, 1258 expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), 1259 variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), 1260 get_c_opcode=get_c_opcode, 1261 itertools=itertools) 1262 1263# The replacement expression isn't necessarily exact if the search expression is exact. 1264def ignore_exact(*expr): 1265 expr = SearchExpression.create(expr) 1266 expr.ignore_exact = True 1267 return expr 1268