1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3
4# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5# See https://llvm.org/LICENSE.txt for license information.
6# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7
8# Script for updating SPIR-V dialect by scraping information from SPIR-V
9# HTML and JSON specs from the Internet.
10#
11# For example, to define the enum attribute for SPIR-V memory model:
12#
13# ./gen_spirv_dialect.py --base_td_path /path/to/SPIRVBase.td \
14#                        --new-enum MemoryModel
15#
16# The 'operand_kinds' dict of spirv.core.grammar.json contains all supported
17# SPIR-V enum classes.
18
19import itertools
20import re
21import requests
22import textwrap
23import yaml
24
# URL of the unified SPIR-V HTML specification; fetched to scrape the
# per-instruction documentation.
SPIRV_HTML_SPEC_URL = 'https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html'
# URL of the SPIR-V core grammar in JSON form; fetched to obtain the
# 'operand_kinds' and 'instructions' lists.
SPIRV_JSON_SPEC_URL = 'https://raw.githubusercontent.com/KhronosGroup/SPIRV-Headers/master/include/spirv/unified1/spirv.core.grammar.json'

# Text on which the op definition file is split into one chunk per op.
AUTOGEN_OP_DEF_SEPARATOR = '\n// -----\n\n'
# Marker text on which the target file is split to locate the autogenerated
# enum section.
AUTOGEN_ENUM_SECTION_MARKER = 'enum section. Generated from SPIR-V spec; DO NOT MODIFY!'
# Marker text on which the target file is split to locate the autogenerated
# opcode section.
AUTOGEN_OPCODE_SECTION_MARKER = (
    'opcode section. Generated from SPIR-V spec; DO NOT MODIFY!')
32
33
def get_spirv_doc_from_html_spec():
  """Scrapes per-instruction documentation out of the SPIR-V HTML spec.

  The spec is downloaded from SPIRV_HTML_SPEC_URL and parsed with
  BeautifulSoup. Every table found in a 'sect3' div under the parent of the
  instructions heading contributes one entry.

  Returns:
    - A dict mapping from the instruction's anchor id (its opname) to the
      documentation text that follows the opname line.
  """
  html = requests.get(SPIRV_HTML_SPEC_URL).content

  from bs4 import BeautifulSoup
  soup = BeautifulSoup(html, 'html.parser')

  instructions_heading = soup.find(
      'h3', {'id': '_a_id_instructions_a_instructions'})
  subsections = instructions_heading.parent.find_all('div', {'class': 'sect3'})

  doc = {}
  for subsection in subsections:
    for table in subsection.find_all('table'):
      first_cell = table.tbody.tr.td.p
      opname = first_cell.a['id']
      # The first line of the cell only repeats the opname; keep what follows.
      remainder = first_cell.text.split('\n', 1)[1]
      doc[opname] = remainder.strip()

  return doc
58
59
def get_spirv_grammar_from_json_spec():
  """Downloads the SPIR-V JSON grammar and returns its two main tables.

  The grammar is fetched from SPIRV_JSON_SPEC_URL and decoded as JSON.

  Returns:
    - The value of the grammar's 'operand_kinds' entry (all operand kinds'
      grammar)
    - The value of the grammar's 'instructions' entry (all instructions'
      grammar)
  """
  import json

  raw_grammar = requests.get(SPIRV_JSON_SPEC_URL).content
  grammar = json.loads(raw_grammar)
  return grammar['operand_kinds'], grammar['instructions']
74
75
def split_list_into_sublists(items):
  """Split the list of items into multiple sublists.

  This is to make sure the string composed from each sublist (its items joined
  with ', ') won't exceed 80 characters. Every item is budgeted as its own
  length plus two characters for the separator.

  An item that is too long to fit the budget on its own is placed in a sublist
  by itself; no empty sublist is ever produced.

  Arguments:
    - items: a list of strings

  Returns:
    - A list of non-empty sublists that together contain all items in their
      original order.
  """
  chunks = []
  chunk = []
  chunk_len = 0

  for item in items:
    item_len = len(item) + 2
    # Only flush a chunk that actually holds something; otherwise an oversized
    # first item would emit a spurious empty sublist (and thus an empty row in
    # the generated TableGen list).
    if chunk and chunk_len + item_len > 80:
      chunks.append(chunk)
      chunk = []
      chunk_len = 0
    chunk.append(item)
    chunk_len += item_len

  if chunk:
    chunks.append(chunk)

  return chunks
101
102
def uniquify_enum_cases(lst):
  """Prunes duplicate enum cases from the list.

  Cases sharing the same value are considered duplicates of each other. For
  each value exactly one case is kept: the one whose symbol sorts first, which
  is typically the symbol without extension suffix. The only exception is
  'HlslSemanticGOOGLE', for which the other symbol with the same value is kept
  instead.

  Arguments:
   - lst: List whose elements are to be uniqued. Each element is a
     (symbol, value) pair. The list itself is left unmodified.

  Returns:
   - A list with all duplicates removed. The elements are sorted according to
     value.
   - A map from each pruned (duplicated) symbol to the symbol of the case that
     was kept for the same value.
  """
  uniqued_cases = []
  duplicated_cases = {}

  # itertools.groupby only merges adjacent elements, so sort by value first.
  # Sort a copy so that the caller's list is not reordered as a side effect.
  cases = sorted(lst, key=lambda x: x[1])

  for _, group in itertools.groupby(cases, key=lambda x: x[1]):
    # For each value, sort according to the enumerant symbol.
    sorted_group = sorted(group, key=lambda x: x[0])
    # Keep the "smallest" case, which is typically the symbol without extension
    # suffix. But we have special cases that we want to fix.
    kept = sorted_group[0]
    if kept[0] == 'HlslSemanticGOOGLE':
      assert len(sorted_group) == 2, 'unexpected new variant for HlslSemantic'
      kept = sorted_group[1]
    # Record the mapping only after the kept case is final so that the kept
    # symbol never shows up as a duplicate of a pruned one.
    for case in sorted_group:
      if case[0] != kept[0]:
        duplicated_cases[case[0]] = kept[0]
    uniqued_cases.append(kept)

  return uniqued_cases, duplicated_cases
139
140
def toposort(dag, sort_fn):
  """Topologically sorts the given dag.

  Nodes are emitted batch by batch: each batch consists of all nodes that have
  no incoming nodes left, ordered by sort_fn. Emitted nodes are then removed
  from the graph and from the incoming sets of the remaining nodes.

  Arguments:
    - dag: a dict mapping from a node to the set of its incoming nodes.
    - sort_fn: a key function for ordering nodes within the same batch.

  Returns:
    A list containing topologically sorted nodes.
  """
  ordered = []
  remaining = dag

  while True:
    ready = {node for node, incoming in remaining.items() if not incoming}
    if not ready:
      break
    ordered.extend(sorted(ready, key=sort_fn))
    remaining = {
        node: incoming - ready
        for node, incoming in remaining.items()
        if node not in ready
    }

  # Anything left over could never be freed of its incoming nodes.
  assert not remaining, 'found cyclic dependency'

  return ordered
171
172
def toposort_capabilities(all_cases, capability_mapping):
  """Returns topologically sorted capability (symbol, value) pairs.

  Cases whose symbol is a key of capability_mapping (i.e., duplicated symbols)
  are left out of the result. Symbols listed under a case's 'capabilities'
  field are canonicalized through capability_mapping and treated as that
  case's incoming nodes. Nodes in the same batch are ordered by value.

  Arguments:
    - all_cases: all capability cases (containing symbol, value, and implied
      capabilities).
    - capability_mapping: mapping from duplicated capability symbols to the
      canonicalized symbol chosen for SPIRVBase.td.

  Returns:
    A list containing topologically sorted capability (symbol, value) pairs.
  """
  value_of = {}
  incoming = {}

  for case in all_cases:
    symbol = case['enumerant']
    # Values are recorded for every symbol, duplicated ones included.
    value_of[symbol] = case['value']
    if symbol in capability_mapping:
      continue
    incoming[symbol] = {
        capability_mapping.get(dep, dep)
        for dep in case.get('capabilities', [])
    }

  ordered = toposort(incoming, lambda symbol: value_of[symbol])
  return [(symbol, value_of[symbol]) for symbol in ordered]
203
204
def get_capability_mapping(operand_kinds):
  """Returns the capability mapping from duplicated cases to canonicalized ones.

  Arguments:
    - operand_kinds: all operand kinds' grammar spec

  Returns:
    - A map mapping from duplicated capability symbols to the canonicalized
      symbol chosen for SPIRVBase.td.
  """
  # Pick the 'Capability' operand kind; if it is listed more than once, the
  # last occurrence wins.
  capability_kinds = [
      kind for kind in operand_kinds if kind['kind'] == 'Capability'
  ]
  cap_kind = capability_kinds[-1] if capability_kinds else {}

  symbol_value_pairs = [
      (case['enumerant'], case['value']) for case in cap_kind['enumerants']
  ]
  return uniquify_enum_cases(symbol_value_pairs)[1]
227
228
def get_availability_spec(enum_case, capability_mapping, for_op, for_cap):
  """Returns the availability specification string for the given enum case.

  Four kinds of requirements are read from the grammar entry: minimal version
  ('version'), maximal version ('lastVersion'), extensions ('extensions'), and
  capabilities ('capabilities'). For ops, every missing requirement is filled
  in with its default, and the whole spec is omitted when it consists of
  defaults only.

  Arguments:
    - enum_case: the enum case to generate availability spec for. It may contain
      'version', 'lastVersion', 'extensions', or 'capabilities'.
    - capability_mapping: mapping from duplicated capability symbols to the
      canonicalized symbol chosen for SPIRVBase.td.
    - for_op: bool value indicating whether this is the availability spec for an
      op itself.
    - for_cap: bool value indicating whether this is the availability spec for
      capabilities themselves.

  Returns:
    - A `let availability = [...];` string (for ops) or a
      `list<Availability> availability = [...];` string (otherwise) if there
      is an availability spec to emit, preceded by a
      `list<I32EnumAttrCase> implies = [...];` line when for_cap is set and
      the case lists capabilities; an empty string if there is nothing to emit.
  """
  assert not (for_op and for_cap), 'cannot set both for_op and for_cap'

  DEFAULT_MIN_VERSION = 'MinVersion<SPV_V_1_0>'
  DEFAULT_MAX_VERSION = 'MaxVersion<SPV_V_1_5>'
  DEFAULT_CAP = 'Capability<[]>'
  DEFAULT_EXT = 'Extension<[]>'

  # Minimal version. The grammar may spell "no version" as the string 'None'.
  min_version = ''
  grammar_min = enum_case.get('version', '')
  if grammar_min and grammar_min != 'None':
    min_version = 'MinVersion<SPV_V_{}>'.format(grammar_min.replace('.', '_'))

  # Maximal version.
  max_version = ''
  grammar_max = enum_case.get('lastVersion', '')
  if grammar_max:
    max_version = 'MaxVersion<SPV_V_{}>'.format(grammar_max.replace('.', '_'))

  # Extensions.
  exts = ''
  grammar_exts = enum_case.get('extensions', [])
  if grammar_exts:
    exts = 'Extension<[{}]>'.format(', '.join(sorted(set(grammar_exts))))
    # We need to strip the minimal version requirement if this symbol is
    # available via an extension, which means *any* SPIR-V version can support
    # it as long as the extension is provided. The grammar's 'version' field
    # under such case should be interpreted as this symbol is introduced as
    # a core symbol since the given version, rather than a minimal version
    # requirement.
    min_version = ''

  # Capabilities.
  caps = ''
  implies = ''
  grammar_caps = enum_case.get('capabilities', [])
  if grammar_caps:
    canonical_caps = {capability_mapping.get(c, c) for c in grammar_caps}
    joined_caps = ', '.join(
        'SPV_C_{}'.format(c) for c in sorted(canonical_caps))
    if for_cap:
      # If this is generating the availability for capabilities, we need to
      # put the capability "requirements" in implies field because now
      # the "capabilities" field in the source grammar means so.
      implies = 'list<I32EnumAttrCase> implies = [{}];'.format(joined_caps)
    else:
      caps = 'Capability<[{}]>'.format(joined_caps)

  # TODO: delete this once ODS can support dialect-specific content
  # and we can use omission to mean no requirements.
  if for_op:
    min_version = min_version or DEFAULT_MIN_VERSION
    max_version = max_version or DEFAULT_MAX_VERSION
    exts = exts or DEFAULT_EXT
    caps = caps or DEFAULT_CAP

  requirements = [min_version, max_version, exts, caps]
  # For ops, because we have a default in SPV_Op class, omit if the spec
  # is the same.
  all_defaults_for_op = for_op and requirements == [
      DEFAULT_MIN_VERSION, DEFAULT_MAX_VERSION, DEFAULT_EXT, DEFAULT_CAP
  ]

  avail = ''
  # Compose availability spec if any of the requirements is not empty.
  if any(requirements) and not all_defaults_for_op:
    joined_spec = ',\n    '.join(filter(None, requirements))
    avail = '{} availability = [\n    {}\n  ];'.format(
        'let' if for_op else 'list<Availability>', joined_spec)

  separator = '\n  ' if implies and avail else ''
  return '{}{}{}'.format(implies, separator, avail)
327
328
def gen_operand_kind_enum_attr(operand_kind, capability_mapping):
  """Generates the TableGen EnumAttr definition for the given operand kind.

  Operand kinds without an 'enumerants' entry yield a pair of empty strings.
  Capability cases are ordered topologically; cases of every other kind are
  ordered by value with duplicates pruned.

  Arguments:
    - operand_kind: the operand kind's grammar spec
    - capability_mapping: mapping from duplicated capability symbols to the
      canonicalized symbol chosen for SPIRVBase.td.

  Returns:
    - The operand kind's name
    - A string containing the TableGen EnumAttr definition
  """
  if 'enumerants' not in operand_kind:
    return '', ''

  kind_name = operand_kind['kind']
  kind_category = 'Bit' if operand_kind['category'] == 'BitEnum' else 'I32'
  # The acronym is made of the upper-case ASCII letters in the kind's name.
  kind_acronym = ''.join(c for c in kind_name if 'A' <= c <= 'Z')
  enumerants = operand_kind['enumerants']
  is_capability = kind_name == 'Capability'

  # Returns the symbol for the given case. Dim is handled specially to avoid
  # having numbers as the start of symbols, which does not play well with C++
  # and the MLIR parser.
  def to_symbol(case_name):
    if kind_name == 'Dim' and case_name in ('1D', '2D', '3D'):
      return 'Dim{}'.format(case_name)
    return case_name

  grammar_of = {case['enumerant']: case for case in enumerants}

  if is_capability:
    # Special treatment for capability cases: we need to sort them topologically
    # because a capability can refer to another via the 'implies' field.
    kind_cases = toposort_capabilities(enumerants, capability_mapping)
  else:
    kind_cases, _ = uniquify_enum_cases(
        [(case['enumerant'], case['value']) for case in enumerants])
  max_len = max([len(name) for name, _ in kind_cases])

  # Generate the definition for each enum case. The colon is right-aligned so
  # that all definitions line up.
  case_fmt = ('def SPV_{acronym}_{case} {colon:>{offset}} '
              '{category}EnumAttrCase<"{symbol}", {value}>{avail}')
  case_defs = []
  for name, value in kind_cases:
    avail = get_availability_spec(grammar_of[name], capability_mapping,
                                  False, is_capability)
    case_defs.append(
        case_fmt.format(
            category=kind_category,
            acronym=kind_acronym,
            case=name,
            symbol=to_symbol(name),
            value=value,
            avail=' {{\n  {}\n}}'.format(avail) if avail else ';',
            colon=':',
            offset=(max_len + 1 - len(name))))

  # Generate the list of enum case names, split into rows, and concatenate
  # the rows into multiple indented lines.
  qualified_names = [
      'SPV_{acronym}_{symbol}'.format(acronym=kind_acronym, symbol=name)
      for name, _ in kind_cases
  ]
  rows = split_list_into_sublists(qualified_names)
  case_names = ',\n'.join(
      '{:6}'.format('') + ', '.join(row) for row in rows)

  # Generate the enum attribute definition
  enum_attr = ('def SPV_{name}Attr :\n'
               '    SPV_{category}EnumAttr<"{name}", "valid SPIR-V {name}", [\n'
               '{cases}\n'
               '    ]>;').format(
                   name=kind_name, category=kind_category, cases=case_names)

  return kind_name, '\n'.join(case_defs) + '\n\n' + enum_attr
406
407
def gen_opcode(instructions):
  """Generates the TableGen definition to map opname to opcode

  One `SPV_OC_<opname>` enum case is emitted per instruction, followed by the
  `SPV_OpcodeAttr` enum attribute listing all of them.

  Arguments:
    - instructions: a list of instructions' grammar, each providing 'opname'
      and 'opcode'

  Returns:
    - A string containing the TableGen SPV_OpCode definition
  """
  opnames = [inst['opname'] for inst in instructions]
  max_len = max([len(opname) for opname in opnames])

  # The colon is right-aligned so that all case definitions line up.
  case_fmt = ('def SPV_OC_{name} {colon:>{offset}} '
              'I32EnumAttrCase<"{name}", {value}>;')
  opcode_str = '\n'.join(
      case_fmt.format(
          name=inst['opname'],
          value=inst['opcode'],
          colon=':',
          offset=(max_len + 1 - len(inst['opname'])))
      for inst in instructions)

  # Lay the case names out in indented rows.
  rows = split_list_into_sublists(
      ['SPV_OC_{name}'.format(name=opname) for opname in opnames])
  opcode_list = ',\n'.join(
      '{:6}'.format('') + ', '.join(row) for row in rows)

  enum_attr = ('def SPV_OpcodeAttr :\n'
               '    SPV_I32EnumAttr<"{name}", "valid SPIR-V instructions", [\n'
               '{lst}\n'
               '    ]>;').format(name='Opcode', lst=opcode_list)
  return opcode_str + '\n\n' + enum_attr
441
def map_cap_to_opnames(instructions):
  """Maps capabilities to instructions enabled by those capabilities

  Instructions without a 'capabilities' entry are filed under the placeholder
  key '0_core_0'.

  Arguments:
    - instructions: a list containing a subset of SPIR-V instructions' grammar
  Returns:
    - A map with keys representing capabilities and values of lists of
    instructions enabled by the corresponding key
  """
  cap_to_inst = {}

  for inst in instructions:
    for cap in inst.get('capabilities', ['0_core_0']):
      cap_to_inst.setdefault(cap, []).append(inst['opname'])

  return cap_to_inst
461
def gen_instr_coverage_report(path, instructions):
  """Dumps to standard output a YAML report of current instruction coverage

  An instruction counts as supported when an `SPV_OC_<opname>` definition for
  it exists in the autogenerated opcode section of the file at `path`. The
  report has one entry per capability that still has at least one unsupported
  instruction.

  Arguments:
    - path: the path to SPIRBase.td
    - instructions: a list containing all SPIR-V instructions' grammar
  """
  with open(path, 'r') as f:
    content = f.read()

  content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
  # The opcode section must be delimited by exactly two markers; otherwise
  # content[1] would not be the section we want to inspect.
  assert len(content) == 3, 'expected exactly one autogen opcode section'

  # Strip the leading 'def SPV_OC_' to recover the opname.
  prefix_len = len('def SPV_OC_')
  existing_opcodes = {
      k[prefix_len:] for k in re.findall(r'def SPV_OC_\w+', content[1])
  }

  existing_instructions = [
      inst for inst in instructions if inst['opname'] in existing_opcodes
  ]
  remaining_instructions = [
      inst for inst in instructions if inst['opname'] not in existing_opcodes
  ]

  rem_cap_to_instr = map_cap_to_opnames(remaining_instructions)
  ex_cap_to_instr = map_cap_to_opnames(existing_instructions)

  report = {}
  for cap, unsupported in rem_cap_to_instr.items():
    supported = ex_cap_to_instr.get(cap, [])
    # `unsupported` is never empty here, so the denominator is non-zero.
    coverage = len(supported) / (len(supported) + len(unsupported))
    report[cap] = {
        'Supported Instructions': supported,
        'Unsupported Instructions': unsupported,
        'Coverage': '{}%'.format(int(coverage * 100)),
    }

  print(yaml.dump(report))
511
def update_td_opcodes(path, instructions, filter_list):
  """Updates SPIRBase.td with new generated opcode cases.

  The autogenerated opcode section of the file is regenerated for the union of
  the opnames already defined there and the ones given in `filter_list`, and
  the file is rewritten in place.

  Arguments:
    - path: the path to SPIRBase.td
    - instructions: a list containing all SPIR-V instructions' grammar
    - filter_list: a list containing new opnames to add. Note that the opnames
      already present in the file are appended to this list.
  """

  with open(path, 'r') as f:
    content = f.read()

  content = content.split(AUTOGEN_OPCODE_SECTION_MARKER)
  assert len(content) == 3

  # Extend opcode list with existing list. Strip the leading 'def SPV_OC_' to
  # recover the opname.
  prefix_len = len('def SPV_OC_')
  existing_opcodes = [
      k[prefix_len:] for k in re.findall(r'def SPV_OC_\w+', content[1])
  ]
  filter_list.extend(existing_opcodes)
  wanted_opnames = set(filter_list)

  # Generate the opcode for all wanted instructions in SPIR-V
  filter_instrs = [
      inst for inst in instructions if inst['opname'] in wanted_opnames
  ]
  # Sort instruction based on opcode
  filter_instrs.sort(key=lambda inst: inst['opcode'])
  opcode = gen_opcode(filter_instrs)

  # Substitute the opcode
  content = content[0] + AUTOGEN_OPCODE_SECTION_MARKER + '\n\n' + \
        opcode + '\n\n// End ' + AUTOGEN_OPCODE_SECTION_MARKER \
        + content[2]

  with open(path, 'w') as f:
    f.write(content)
546
547
def update_td_enum_attrs(path, operand_kinds, filter_list):
  """Updates SPIRBase.td with new generated enum definitions.

  The autogenerated enum section of the file is regenerated for the union of
  the enums already defined there and the ones given in `filter_list`, and the
  file is rewritten in place.

  Arguments:
    - path: the path to SPIRBase.td
    - operand_kinds: a list containing all operand kinds' grammar
    - filter_list: a list containing new enums to add. Note that the enums
      already present in the file are appended to this list.
  """
  with open(path, 'r') as f:
    content = f.read()

  content = content.split(AUTOGEN_ENUM_SECTION_MARKER)
  assert len(content) == 3

  # Extend filter list with existing enum definitions. Strip the leading
  # 'def SPV_' and the trailing 'Attr' to recover the enum name.
  existing_kinds = [
      k[8:-4] for k in re.findall(r'def SPV_\w+Attr', content[1])]
  filter_list.extend(existing_kinds)
  wanted_kinds = set(filter_list)

  capability_mapping = get_capability_mapping(operand_kinds)

  # Generate definitions for all enums in filter list
  defs = [
      gen_operand_kind_enum_attr(kind, capability_mapping)
      for kind in operand_kinds
      if kind['kind'] in wanted_kinds
  ]
  # Sort alphabetically according to enum name
  defs.sort(key=lambda enum: enum[0])
  # Only keep the definitions from now on
  # Put Capability's definition at the very beginning because capability cases
  # will be referenced later
  defs = [enum[1] for enum in defs if enum[0] == 'Capability'
         ] + [enum[1] for enum in defs if enum[0] != 'Capability']

  # Substitute the old section
  content = content[0] + AUTOGEN_ENUM_SECTION_MARKER + '\n\n' + \
      '\n\n'.join(defs) + "\n\n// End " + AUTOGEN_ENUM_SECTION_MARKER  \
      + content[2]

  with open(path, 'w') as f:
    f.write(content)
590
591
def snake_casify(name):
  """Turns the given name to follow snake_case convention.

  Every run of non-word characters (quotes, spaces, punctuation, newlines)
  acts as a word boundary; the words are lower-cased and joined with
  underscores. For example, "'Memory Access'" becomes "memory_access".

  Arguments:
    - name: the name to convert, as a string

  Returns:
    - The snake_case version of the name
  """
  # Turn separators into spaces rather than deleting them; deleting them would
  # glue the words together before they could be split apart.
  words = re.sub(r'\W+', ' ', name).split()
  return '_'.join(word.lower() for word in words)
597
598
def map_spec_operand_to_ods_argument(operand):
  """Maps an operand in SPIR-V JSON spec to an op argument in ODS.

  The argument's name is derived from the operand's 'name' entry if it has a
  non-empty one, and from its lower-cased kind otherwise.

  Arguments:
    - A dict containing the operand's kind, quantifier, and name

  Returns:
    - A string containing both the type and name for the argument
  """
  kind = operand['kind']
  quantifier = operand.get('quantifier', '')

  # These instruction "operands" are for encoding the results; they should
  # not be handled here.
  assert kind != 'IdResultType', 'unexpected to handle "IdResultType" kind'
  assert kind != 'IdResult', 'unexpected to handle "IdResult" kind'

  unimplemented_kinds = (
      'LiteralString',
      'LiteralContextDependentNumber',
      'LiteralExtInstInteger',
      'LiteralSpecConstantOpInteger',
      'PairLiteralIntegerIdRef',
      'PairIdRefLiteralInteger',
      'PairIdRefIdRef',
  )

  if kind == 'IdRef':
    # Any quantifier other than none/optional is treated as variadic.
    arg_type = {
        '': 'SPV_Type',
        '?': 'Optional<SPV_Type>',
    }.get(quantifier, 'Variadic<SPV_Type>')
  elif kind in ('IdMemorySemantics', 'IdScope'):
    # TODO: Need to further constrain 'IdMemorySemantics'
    # and 'IdScope' given that they should be generated from OpConstant.
    assert quantifier == '', ('unexpected to have optional/variadic memory '
                              'semantics or scope <id>')
    # Drop the leading 'Id' to get the name of the matching enum attribute.
    arg_type = 'SPV_' + kind[2:] + 'Attr'
  elif kind == 'LiteralInteger':
    arg_type = {
        '': 'I32Attr',
        '?': 'OptionalAttr<I32Attr>',
    }.get(quantifier, 'OptionalAttr<I32ArrayAttr>')
  elif kind in unimplemented_kinds:
    assert False, '"{}" kind unimplemented'.format(kind)
  else:
    # The rest are all enum operands that we represent with op attributes.
    assert quantifier != '*', 'unexpected to have variadic enum attribute'
    arg_type = 'SPV_{}Attr'.format(kind)
    if quantifier == '?':
      arg_type = 'OptionalAttr<{}>'.format(arg_type)

  spec_name = operand.get('name', '')
  arg_name = snake_casify(spec_name) if spec_name else kind.lower()

  return '{}:${}'.format(arg_type, arg_name)
655
656
def get_description(text, appendix):
  """Generates the description for the given SPIR-V instruction.

  The autogenerated text comes first, followed by an
  `<!-- End of AutoGen section -->` marker line and then the appendix.

  Arguments:
    - text: Textual description of the operation as string.
    - appendix: Additional contents to attach in description as string,
                including IR examples, and others.

  Returns:
    - A string that corresponds to the description of the Tablegen op.
  """
  return '{text}\n\n    <!-- End of AutoGen section -->\n{appendix}\n  '.format(
      text=text, appendix=appendix)
670
671
def get_op_definition(instruction, doc, existing_info, capability_mapping):
  """Generates the TableGen op definition for the given SPIR-V instruction.

  Sections found in existing_info ('inst_category', 'category_args', 'traits',
  'description', 'arguments', 'results', 'extras') take precedence over what
  would be generated from the grammar and the doc. The arguments and results
  sections are only emitted when the instruction category is 'Op'.

  Arguments:
    - instruction: the instruction's SPIR-V JSON grammar
    - doc: the instruction's SPIR-V HTML doc
    - existing_info: a dict containing potential manually specified sections for
      this instruction
    - capability_mapping: mapping from duplicated capability symbols to the
                   canonicalized symbol chosen for SPIRVBase.td

  Returns:
    - A string containing the TableGen op definition
  """
  inst_category = existing_info.get('inst_category', 'Op')

  # Assemble the template: header, then optional arguments/results, then the
  # manually written extras and the closing brace.
  template_parts = [
      'def SPV_{opname}Op : '
      'SPV_{inst_category}<"{opname}"{category_args}[{traits}]> '
      '{{\n  let summary = {summary};\n\n  let description = '
      '[{{\n{description}}}];{availability}\n'
  ]
  if inst_category == 'Op':
    template_parts.append('\n  let arguments = (ins{args});\n\n'
                          '  let results = (outs{results});\n')
  template_parts.append('{extras}'
                        '}}\n')
  template = ''.join(template_parts)

  # Drop the leading 'Op' from the SPIR-V opname.
  opname = instruction['opname'][2:]
  category_args = existing_info.get('category_args', '')

  # The first line of the doc is the summary; the rest is the long text.
  summary, _, text = doc.partition('\n')
  wrapper = textwrap.TextWrapper(
      width=76, initial_indent='    ', subsequent_indent='    ')

  # Format summary. If the summary can fit in the same line, we print it out
  # as a "-quoted string; otherwise, wrap the lines using "[{...}]".
  summary = summary.strip()
  fits_on_one_line = len(summary) + len('  let summary = "";') <= 80
  if fits_on_one_line:
    summary = '"{}"'.format(summary)
  else:
    summary = '[{{\n{}\n  }}]'.format(wrapper.fill(summary))

  # Wrap each non-empty line of the text as its own paragraph.
  paragraphs = [wrapper.fill(line) for line in text.split('\n') if line]
  text = '\n\n'.join(paragraphs)

  operands = instruction.get('operands', [])

  # Op availability
  avail = get_availability_spec(instruction, capability_mapping, True, False)
  if avail:
    avail = '\n\n  {0}'.format(avail)

  # Set op's result
  results = ''
  if operands and operands[0]['kind'] == 'IdResultType':
    results = '\n    SPV_Type:$result\n  '
    operands = operands[1:]
  if 'results' in existing_info:
    results = existing_info['results']

  # Ignore the operand standing for the result <id>
  if operands and operands[0]['kind'] == 'IdResult':
    operands = operands[1:]

  # Set op's arguments
  arguments = existing_info.get('arguments', None)
  if arguments is None:
    arguments = ',\n    '.join(
        [map_spec_operand_to_ods_argument(o) for o in operands])
    if arguments:
      # Prepend and append whitespace for formatting
      arguments = '\n    {}\n  '.format(arguments)

  description = existing_info.get('description', None)
  if description is None:
    # Placeholder appendix to be filled in manually.
    assembly = ('\n    ```\n'
                '    [TODO]\n'
                '    ```mlir\n\n'
                '    #### Example:\n\n'
                '    ```\n'
                '    [TODO]\n'
                '    ```')
    description = get_description(text, assembly)

  return template.format(
      opname=opname,
      category_args=category_args,
      inst_category=inst_category,
      traits=existing_info.get('traits', ''),
      summary=summary,
      description=description,
      availability=avail,
      args=arguments,
      results=results,
      extras=existing_info.get('extras', ''))
772
773
def get_string_between(base, start, end):
  """Extracts a substring with a specified start and end from a string.

  The substring begins right after the first occurrence of `start` and ends
  right before the first occurrence of `end` that follows it. It is returned
  verbatim, i.e., without stripping any trailing characters.

  Arguments:
    - base: string to extract from.
    - start: string to use as the start of the substring.
    - end: string to use as the end of the substring.

  Returns:
    - The substring if found
    - The part of the base after end of the substring. Is the base string itself
      if the substring wasn't found.

  Raises:
    - AssertionError if `start` is found but no `end` follows it.
  """
  _, start_found, after_start = base.partition(start)
  if not start_found:
    return '', base

  substring, end_found, remainder = after_start.partition(end)
  assert end_found, \
         'cannot find end "{end}" while extracting substring '\
         'starting with {start}'.format(start=start, end=end)
  # NOTE: do not rstrip(end) the substring. str.rstrip treats its argument as
  # a set of characters, so that would eat legitimate trailing characters of
  # the substring (e.g. a trailing '}' or newline) that happen to occur in
  # `end`.
  return substring, remainder
795
796
def get_string_between_nested(base, start, end):
  """Extracts a substring with a nested start and end from a string.

  Unlike a plain search for the first `end`, every further `start` met inside
  the substring must be balanced by its own `end` before the substring is
  considered closed.

  Arguments:
    - base: string to extract from.
    - start: string to use as the start of the substring.
    - end: string to use as the end of the substring.

  Returns:
    - The substring if found
    - The part of the base after end of the substring. Is the base string itself
      if the substring wasn't found.
  """
  pieces = base.split(start, 1)
  if len(pieces) != 2:
    return '', pieces[0]

  body = pieces[1]
  # Number of opened-but-not-yet-closed delimiters; the initial `start` that
  # was split on counts as one.
  depth = 1
  pos = 0
  while depth > 0 and pos < len(body):
    if body.startswith(end, pos):
      depth -= 1
      if depth == 0:
        # `pos` now points at the closing delimiter of the outermost pair.
        break
      pos += len(end)
    elif body.startswith(start, pos):
      depth += 1
      pos += len(start)
    else:
      pos += 1

  assert pos < len(body), \
         'cannot find end "{end}" while extracting substring '\
         'starting with "{start}"'.format(start=start, end=end)
  return body[:pos], body[pos + len(end):]
833
834
def extract_td_op_info(op_def):
  """Extracts potentially manually specified sections in op's definition.

  Arguments:
    - op_def: a string containing the op's TableGen definition

  Returns:
    - A dict containing potential manually specified sections, with keys
      'opname', 'inst_category', 'category_args', 'traits', 'description',
      'arguments', 'results', and 'extras'
  """
  # Get opname: strip the leading 'def SPV_' and the trailing 'Op'.
  opname = [o[8:-2] for o in re.findall(r'def SPV_\w+Op', op_def)]
  assert len(opname) == 1, 'more than one ops in the same section!'
  opname = opname[0]

  # Get instruction category: the base class name after the first ':', with
  # its leading 'SPV_' stripped. Falls back to 'Op' if no such name is found.
  inst_category = [
      o[4:] for o in re.findall(r'SPV_\w+Op',
                                op_def.split(':', 1)[1])
  ]
  assert len(inst_category) <= 1, 'more than one ops in the same section!'
  inst_category = inst_category[0] if len(inst_category) == 1 else 'Op'

  # Get category_args: whatever sits between the quoted op name and the
  # opening bracket of the trait list in the template parameters.
  op_tmpl_params, _ = get_string_between_nested(op_def, '<', '>')
  _, rest = get_string_between(op_tmpl_params, '"', '"')
  category_args = rest.split('[', 1)[0]

  # Get traits
  traits, _ = get_string_between_nested(rest, '[', ']')

  # Get description
  description, rest = get_string_between(op_def, 'let description = [{\n',
                                         '}];\n')

  # Get arguments
  args, rest = get_string_between(rest, '  let arguments = (ins', ');\n')

  # Get results
  results, rest = get_string_between(rest, '  let results = (outs', ');\n')

  # Whatever is left after the sections above, minus surrounding whitespace
  # and closing braces, is kept as extras.
  extras = rest.strip(' }\n')
  if extras:
    extras = '\n  {}\n'.format(extras)

  return {
      # Prefix with 'Op' to make it consistent with SPIR-V spec
      'opname': 'Op{}'.format(opname),
      'inst_category': inst_category,
      'category_args': category_args,
      'traits': traits,
      'description': description,
      'arguments': args,
      'results': results,
      'extras': extras
  }
889
890
def update_td_op_definitions(path, instructions, docs, filter_list,
                             inst_category, capability_mapping):
  """Updates SPIRVOps.td with newly generated op definition.

  The file at `path` is read, its op definitions are re-generated, and the
  result is written back to the same path.

  Arguments:
    - path: path to SPIRVOps.td
    - instructions: SPIR-V JSON grammar for all instructions
    - docs: SPIR-V HTML doc for all instructions
    - filter_list: a list containing new opnames to include; it is not
                   modified by this function
    - inst_category: instruction category used for ops that do not have an
                   existing definition in the file
    - capability_mapping: mapping from duplicated capability symbols to the
                   canonicalized symbol chosen for SPIRVBase.td.

  Returns:
    - None. The updated TableGen op definitions are written to `path`.

  Raises:
    - ValueError: if the file does not contain the op definition separator.
    - KeyError: if an opname is neither in the SPIR-V grammar nor already
                defined in the file.
  """
  with open(path, 'r') as f:
    content = f.read()

  # Split the file into chunks, each containing one op.
  chunks = content.split(AUTOGEN_OP_DEF_SEPARATOR)
  if len(chunks) < 2:
    # Without a separator the header and the footer would both be the whole
    # file, and writing them back would duplicate its content.
    raise ValueError(
        'cannot find op definition separator in "{}"'.format(path))
  header = chunks[0]
  footer = chunks[-1]
  ops = chunks[1:-1]

  # For each existing op, extract the manually-written sections out to retain
  # them when re-generating the ops.
  name_op_map = {}  # Map from opname to its existing ODS definition
  op_info_dict = {}
  for op in ops:
    info_dict = extract_td_op_info(op)
    opname = info_dict['opname']
    name_op_map[opname] = op
    op_info_dict[opname] = info_dict

  # Process the requested ops together with the existing ones, in sorted
  # order and without duplicates. Build a new collection instead of appending
  # to the caller's list.
  opnames = sorted(set(filter_list) | set(name_op_map))

  op_defs = []
  for opname in opnames:
    # Find the grammar spec for this op
    instruction = next(
        (inst for inst in instructions if inst['opname'] == opname), None)
    if instruction is None:
      # This is an op added by us; use the existing ODS definition.
      if opname not in name_op_map:
        raise KeyError(
            '"{}" is neither in the SPIR-V grammar nor defined in "{}"'.format(
                opname, path))
      op_defs.append(name_op_map[opname])
      continue
    op_defs.append(
        get_op_definition(
            instruction, docs[opname],
            op_info_dict.get(opname, {'inst_category': inst_category}),
            capability_mapping))

  # Substitute the old op definitions
  content = AUTOGEN_OP_DEF_SEPARATOR.join([header] + op_defs + [footer])

  with open(path, 'w') as f:
    f.write(content)
949
950
if __name__ == '__main__':
  # Command-line entry point: parse the flags, validate them, fetch the SPIR-V
  # grammar, and run whichever update modes were requested.
  import argparse

  cli_parser = argparse.ArgumentParser(
      description='Update SPIR-V dialect definitions using SPIR-V spec')

  cli_parser.add_argument(
      '--base-td-path',
      dest='base_td_path',
      type=str,
      default=None,
      help='Path to SPIRVBase.td')
  cli_parser.add_argument(
      '--op-td-path',
      dest='op_td_path',
      type=str,
      default=None,
      help='Path to SPIRVOps.td')

  cli_parser.add_argument(
      '--new-enum',
      dest='new_enum',
      type=str,
      default=None,
      help='SPIR-V enum to be added to SPIRVBase.td')
  cli_parser.add_argument(
      '--new-opcodes',
      dest='new_opcodes',
      type=str,
      default=None,
      nargs='*',
      help='update SPIR-V opcodes in SPIRVBase.td')
  cli_parser.add_argument(
      '--new-inst',
      dest='new_inst',
      type=str,
      default=None,
      nargs='*',
      help='SPIR-V instruction to be added to ops file')
  cli_parser.add_argument(
      '--inst-category',
      dest='inst_category',
      type=str,
      default='Op',
      help='SPIR-V instruction category used for choosing '\
           'the TableGen base class to define this op')
  # 'store_true' already defaults to False, so no set_defaults() is needed.
  cli_parser.add_argument(
      '--gen-inst-coverage',
      dest='gen_inst_coverage',
      action='store_true',
      help='Generate an instruction coverage report using SPIRVBase.td')

  args = cli_parser.parse_args()

  # Validate the path flags before doing any network access. Use the parser's
  # error reporting rather than assert, which is stripped under `python -O`.
  needs_base_td = (args.new_enum is not None or
                   args.new_opcodes is not None or args.gen_inst_coverage)
  if needs_base_td and args.base_td_path is None:
    cli_parser.error('--base-td-path is required with --new-enum, '
                     '--new-opcodes and --gen-inst-coverage')
  if args.new_inst is not None and args.op_td_path is None:
    cli_parser.error('--op-td-path is required with --new-inst')

  operand_kinds, instructions = get_spirv_grammar_from_json_spec()

  # Define new enum attr
  if args.new_enum is not None:
    # An empty --new-enum value results in an empty filter list.
    filter_list = [args.new_enum] if args.new_enum else []
    update_td_enum_attrs(args.base_td_path, operand_kinds, filter_list)

  # Define new opcode
  if args.new_opcodes is not None:
    update_td_opcodes(args.base_td_path, instructions, args.new_opcodes)

  # Define new op
  if args.new_inst is not None:
    docs = get_spirv_doc_from_html_spec()
    capability_mapping = get_capability_mapping(operand_kinds)
    update_td_op_definitions(args.op_td_path, instructions, docs, args.new_inst,
                             args.inst_category, capability_mapping)
    print('Done. Note that this script just generates a template; '
          'please read the spec and update traits, arguments, and '
          'results accordingly.')

  if args.gen_inst_coverage:
    gen_instr_coverage_report(args.base_td_path, instructions)
1028