1"""
2Copyright (C) 2021 Alyssa Rosenzweig <alyssa@rosenzweig.io>
3
4Permission is hereby granted, free of charge, to any person obtaining a
5copy of this software and associated documentation files (the "Software"),
6to deal in the Software without restriction, including without limitation
7the rights to use, copy, modify, merge, publish, distribute, sublicense,
8and/or sell copies of the Software, and to permit persons to whom the
9Software is furnished to do so, subject to the following conditions:
10
11The above copyright notice and this permission notice (including the next
12paragraph) shall be included in all copies or substantial portions of the
13Software.
14
15THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21SOFTWARE.
22"""
23
# Global registries, filled in while this module executes.
opcodes = dict()     # opcode name -> Opcode, populated by op()
immediates = dict()  # immediate name -> Immediate, populated by immediate()
enums = dict()       # enum name -> value dictionary, populated by enum()
27
class Opcode(object):
   """Plain record describing a single opcode.

   Every constructor argument is stored unchanged on an attribute of the
   same name; nothing is validated or converted here.
   """

   def __init__(self, name, dests, srcs, imms, is_float, can_eliminate, encoding_16, encoding_32):
      # Name, destinations, sources and immediates
      self.name, self.dests, self.srcs = name, dests, srcs
      self.imms = imms

      # Boolean flags
      self.is_float, self.can_eliminate = is_float, can_eliminate

      # Encoding descriptions
      self.encoding_16, self.encoding_32 = encoding_16, encoding_32
38
class Immediate(object):
   """Record pairing an immediate's name with its `ctype` string."""

   def __init__(self, name, ctype):
      self.name, self.ctype = name, ctype
43
class Encoding(object):
   """One instruction encoding, built from a 4-tuple description.

   `description` is (exact, mask, length_short, length_long). A
   length_long of None means "same as length_short". The instance keeps
   `exact`, `mask`, `length_short` and a boolean `extensible`, which is
   true when the two lengths differ.
   """

   def __init__(self, description):
      exact, mask, short_len, long_len = description

      # Convenience: None stands for "no separate long form"
      long_len = short_len if long_len is None else long_len

      self.exact = exact
      self.mask = mask
      self.length_short = short_len
      self.extensible = (long_len != short_len)

      # The long form must exceed the short form by exactly 4 when the
      # short length is above 8, and by exactly 2 otherwise.
      if self.extensible:
         extension = 4 if short_len > 8 else 2
         assert long_len == short_len + extension
59
def op(name, encoding_32, dests = 1, srcs = 0, imms = None, is_float = False, can_eliminate = True, encoding_16 = None):
   """Create an Opcode and register it in `opcodes` under `name`.

   An existing entry with the same name is replaced.

   Parameters:
      name: key in `opcodes`, also stored on the Opcode.
      encoding_32, encoding_16: either None or a 4-tuple
         (exact, mask, length_short, length_long), which is wrapped in an
         Encoding.
      dests, srcs, is_float, can_eliminate: stored on the Opcode unchanged.
      imms: list of immediates stored on the Opcode. When omitted (or
         None), a new empty list is created for this opcode, so opcodes
         never share a single default list object.
   """
   # A mutable default would be one list aliased by every opcode that omits
   # `imms`; build a fresh one per call instead.
   if imms is None:
      imms = []

   encoding_16 = Encoding(encoding_16) if encoding_16 is not None else None
   encoding_32 = Encoding(encoding_32) if encoding_32 is not None else None

   opcodes[name] = Opcode(name, dests, srcs, imms, is_float, can_eliminate, encoding_16, encoding_32)
65
def immediate(name, ctype = "uint32_t"):
   """Create an Immediate, record it in `immediates` under `name`, and
   return it. An earlier entry with the same name is replaced."""
   immediates[name] = Immediate(name, ctype)
   return immediates[name]
70
def enum(name, value_dict):
   """Record `value_dict` in `enums` under `name`, then register and return
   an immediate of the same name whose ctype is "enum agx_" + name."""
   enums[name] = value_dict
   ctype = "enum agx_" + name
   return immediate(name, ctype)
74
# Bit 15; OR'd into the exact/mask values of many encodings in this file.
L = 1 << 15

# Shorthand for None, used where an encoding tuple has no separate long
# length and where an opcode has no encoding at all.
_ = None
77
# Immediates that opcodes reference from their `imms` lists. Each call
# registers the immediate by name; the ctype defaults to "uint32_t" when it
# is not given. Every name is registered exactly once.
FORMAT = immediate("format", "enum agx_format")
IMM = immediate("imm")
WRITEOUT = immediate("writeout")
INDEX = immediate("index")
COMPONENT = immediate("component")
CHANNELS = immediate("channels")
TRUTH_TABLE = immediate("truth_table")
ROUND = immediate("round")
SHIFT = immediate("shift")
MASK = immediate("mask")
BFI_MASK = immediate("bfi_mask")
LOD_MODE = immediate("lod_mode", "enum agx_lod_mode")
DIM = immediate("dim", "enum agx_dim")
SCOREBOARD = immediate("scoreboard")
ICOND = immediate("icond")
FCOND = immediate("fcond")
NEST = immediate("nest")
INVERT_COND = immediate("invert_cond")
TARGET = immediate("target", "agx_block *")
PERSPECTIVE = immediate("perspective", "bool")
# Values of the "sr" immediate (listed by get_sr): integer value -> name.
SR = enum("sr", {
   0: "threadgroup_position_in_grid.x",
   1: "threadgroup_position_in_grid.y",
   2: "threadgroup_position_in_grid.z",
   4: "threads_per_threadgroup.x",
   5: "threads_per_threadgroup.y",
   6: "threads_per_threadgroup.z",
   8: "dispatch_threads_per_threadgroup.x",
   9: "dispatch_threads_per_threadgroup.y",
   10: "dispatch_threads_per_threadgroup.z",
   48: "thread_position_in_threadgroup.x",
   49: "thread_position_in_threadgroup.y",
   50: "thread_position_in_threadgroup.z",
   51: "thread_index_in_threadgroup",
   52: "thread_index_in_subgroup",
   53: "subgroup_index_in_threadgroup",
   56: "active_thread_index_in_quad",
   58: "active_thread_index_in_subgroup",
   62: "backfacing",
   80: "thread_position_in_grid.x",
   81: "thread_position_in_grid.y",
   82: "thread_position_in_grid.z",
})
122
def FUNOP(x):
   """Return `x` shifted left by 28 bits."""
   return x << 28

# FUNOP applied to a 14-bit all-ones value.
FUNOP_MASK = FUNOP((1 << 14) - 1)
125
def funop(name, opcode):
   """Register `name` through op() with one source and is_float set.

   Only an encoding_32 is given: its exact value is 0x0A | L with `opcode`
   positioned by FUNOP, its mask is 0x3F | L | FUNOP_MASK, and its length
   is 6 with no separate long form.
   """
   # Use the shared FUNOP helpers rather than repeating the shift and the
   # field width here; the resulting values are the same.
   exact = 0x0A | L | FUNOP(opcode)
   mask = 0x3F | L | FUNOP_MASK

   op(name, (exact, mask, 6, _), srcs = 1, is_float = True)
130
# Listing of opcodes
#
# First, the opcodes registered through funop(), in registration order.
for (name, funop_code) in (
   ("floor",     0b000000),
   ("srsqrt",    0b000001),
   ("dfdx",      0b000100),
   ("dfdy",      0b000110),
   ("rcp",       0b001000),
   ("rsqrt",     0b001001),
   ("sin_pt_1",  0b001010),
   ("log2",      0b001100),
   ("exp2",      0b001101),
   ("sin_pt_2",  0b001110),
   ("ceil",      0b010000),
   ("trunc",     0b100000),
   ("roundeven", 0b110000),
):
   funop(name, funop_code)
145
# Each of these provides both an encoding_32 and an encoding_16 description.
op("fadd", srcs = 2, is_float = True,
   encoding_32 = (0x2A | L, 0x3F | L, 6, _),
   encoding_16 = (0x26 | L, 0x3F | L, 6, _))

op("fma", srcs = 3, is_float = True,
   encoding_32 = (0x3A, 0x3F, 6, 8),
   encoding_16 = (0x36, 0x3F, 6, 8))

op("fmul", srcs = 2, is_float = True,
   encoding_32 = (0x1A | L, 0x3F | L, 6, _),
   encoding_16 = (0x16 | L, 0x3F | L, 6, _))

op("mov_imm", imms = [IMM],
   encoding_32 = (0x62, 0xFF, 6, 8),
   encoding_16 = (0x62, 0xFF, 4, 6))
165
# Each of these provides only an encoding_32 description.
op("iadd", srcs = 2, imms = [SHIFT],
   encoding_32 = (0x0E, 0x3F | L, 8, _))

op("imad", srcs = 3, imms = [SHIFT],
   encoding_32 = (0x1E, 0x3F | L, 8, _))

op("bfi", srcs = 3, imms = [BFI_MASK],
   encoding_32 = (0x2E, 0x7F | (0x3 << 26), 8, _))

op("bfeil", srcs = 3, imms = [BFI_MASK],
   encoding_32 = (0x2E | L, 0x7F | L | (0x3 << 26), 8, _))

op("asr", srcs = 2,
   encoding_32 = (0x2E | L | (0x1 << 26), 0x7F | L | (0x3 << 26), 8, _))

op("icmpsel", srcs = 4, imms = [ICOND],
   encoding_32 = (0x12, 0x7F, 8, 10))

op("fcmpsel", srcs = 4, imms = [FCOND],
   encoding_32 = (0x02, 0x7F, 8, 10))
193
# Sources are: coordinates, LOD, texture, sampler, offset
# TODO: anything else?
op("texture_sample", srcs = 5,
   imms = [DIM, LOD_MODE, MASK, SCOREBOARD],
   encoding_32 = (0x32, 0x7F, 8, 10)) # XXX WRONG SIZE

# Sources are: base, index
op("device_load", srcs = 2,
   imms = [FORMAT, MASK, SCOREBOARD],
   encoding_32 = (0x05, 0x7F, 6, 8))
204
op("wait", dests = 0, imms = [SCOREBOARD], can_eliminate = False,
   encoding_32 = (0x38, 0xFF, 2, _))

op("get_sr", dests = 1, imms = [SR],
   encoding_32 = (0x72, 0x7F | L, 4, _))

# Essentially same encoding
op("ld_tile", dests = 1, srcs = 0, imms = [FORMAT], can_eliminate = False,
   encoding_32 = (0x49, 0x7F, 8, _))

op("st_tile", dests = 0, srcs = 1, imms = [FORMAT], can_eliminate = False,
   encoding_32 = (0x09, 0x7F, 8, _))
216
# jmp_exec_any / jmp_exec_none differ only in their exact bits.
for (name, exact) in (("any", 0xC000), ("none", 0xC200)):
   op("jmp_exec_" + name, dests = 0, srcs = 0, imms = [TARGET],
      can_eliminate = False,
      encoding_32 = (exact, 0xFFFF, 6, _))

# TODO: model implicit r0l destinations
op("pop_exec", dests = 0, srcs = 0, imms = [NEST], can_eliminate = False,
   encoding_32 = (0x52 | (0x3 << 9),
                  ((1 << 48) - 1) ^ (0x3 << 7) ^ (0x3 << 11),
                  6, _))
224
# Registers {if,else,while}_{i,f}cmp: all integer variants first, then all
# float variants.
for is_float in (False, True):
   # The two 2-bit fields at bits 26 and 38 are part of the mask only for
   # the integer variants.
   if is_float:
      mod_mask = 0
   else:
      mod_mask = (0x3 << 26) | (0x3 << 38)

   for (cf, cf_op) in (("if", 0), ("else", 1), ("while", 2)):
      name = cf + ("_fcmp" if is_float else "_icmp")

      # Integer variants additionally set 0x10 in the exact bits
      exact = 0x42 | (cf_op << 9)
      if not is_float:
         exact |= 0x10

      mask = 0x7F | (0x3 << 9) | (0x3 << 44) | mod_mask
      imms = [NEST, FCOND if is_float else ICOND, INVERT_COND]

      op(name, (exact, mask, 6, _), imms = imms, is_float = is_float,
         dests = 0, srcs = 2, can_eliminate = False)
236
op("bitop", srcs = 2, imms = [TRUTH_TABLE],
   encoding_32 = (0x7E, 0x7F, 6, _))
op("convert", srcs = 2, imms = [ROUND],
   encoding_32 = (0x3E | L, 0x7F | L | (0x3 << 38), 6, _))
op("ld_vary", srcs = 1, imms = [CHANNELS, PERSPECTIVE],
   encoding_32 = (0x21, 0xBF, 8, _))
op("ld_vary_flat", srcs = 1, imms = [CHANNELS],
   encoding_32 = (0xA1, 0xBF, 8, _))
# No encoding is given for st_vary
op("st_vary", dests = 0, srcs = 2, can_eliminate = False,
   encoding_32 = None)
op("stop", dests = 0, can_eliminate = False,
   encoding_32 = (0x88, 0xFFFF, 2, _))
op("trap", dests = 0, can_eliminate = False,
   encoding_32 = (0x08, 0xFFFF, 2, _))
op("writeout", dests = 0, imms = [WRITEOUT], can_eliminate = False,
   encoding_32 = (0x48, 0xFF, 4, _))

# No encoding is given for the p_* opcodes either
op("p_combine", srcs = 4, encoding_32 = None)
op("p_extract", srcs = 1, imms = [COMPONENT], encoding_32 = None)
248