/*
 * Copyright 2019 Google LLC
 *
 * Use of this source code is governed by a BSD-style license that can be
 * found in the LICENSE file.
 */
7 
8 #ifndef SkVM_DEFINED
9 #define SkVM_DEFINED
10 
11 #include "include/core/SkTypes.h"
12 #include "include/private/SkTHash.h"
13 #include <functional>  // std::hash
14 #include <vector>      // std::vector
15 
16 class SkWStream;
17 
18 namespace skvm {
19 
20     class Assembler {
21     public:
22         explicit Assembler(void* buf);
23 
24         size_t size() const;
25 
26         // Order matters... GP64, Xmm, Ymm values match 4-bit register encoding for each.
27         enum GP64 {
28             rax, rcx, rdx, rbx, rsp, rbp, rsi, rdi,
29             r8 , r9 , r10, r11, r12, r13, r14, r15,
30         };
31         enum Xmm {
32             xmm0, xmm1, xmm2 , xmm3 , xmm4 , xmm5 , xmm6 , xmm7 ,
33             xmm8, xmm9, xmm10, xmm11, xmm12, xmm13, xmm14, xmm15,
34         };
35         enum Ymm {
36             ymm0, ymm1, ymm2 , ymm3 , ymm4 , ymm5 , ymm6 , ymm7 ,
37             ymm8, ymm9, ymm10, ymm11, ymm12, ymm13, ymm14, ymm15,
38         };
39 
40         // X and V values match 5-bit encoding for each (nothing tricky).
41         enum X {
42             x0 , x1 , x2 , x3 , x4 , x5 , x6 , x7 ,
43             x8 , x9 , x10, x11, x12, x13, x14, x15,
44             x16, x17, x18, x19, x20, x21, x22, x23,
45             x24, x25, x26, x27, x28, x29, x30, xzr,
46         };
47         enum V {
48             v0 , v1 , v2 , v3 , v4 , v5 , v6 , v7 ,
49             v8 , v9 , v10, v11, v12, v13, v14, v15,
50             v16, v17, v18, v19, v20, v21, v22, v23,
51             v24, v25, v26, v27, v28, v29, v30, v31,
52         };
53 
54         void bytes(const void*, int);
55         void byte(uint8_t);
56         void word(uint32_t);
57 
58         // x86-64
59 
60         void align(int mod);
61 
62         void vzeroupper();
63         void ret();
64 
65         void add(GP64, int imm);
66         void sub(GP64, int imm);
67 
68         // All dst = x op y.
69         using DstEqXOpY = void(Ymm dst, Ymm x, Ymm y);
70         DstEqXOpY vpand, vpor, vpxor, vpandn,
71                   vpaddd, vpsubd, vpmulld,
72                           vpsubw, vpmullw,
73                   vaddps, vsubps, vmulps, vdivps,
74                   vfmadd132ps, vfmadd213ps, vfmadd231ps,
75                   vpackusdw, vpackuswb,
76                   vpcmpeqd, vpcmpgtd;
77 
78         using DstEqXOpImm = void(Ymm dst, Ymm x, int imm);
79         DstEqXOpImm vpslld, vpsrld, vpsrad,
80                     vpsrlw,
81                     vpermq;
82 
83         using DstEqOpX = void(Ymm dst, Ymm x);
84         DstEqOpX vmovdqa, vcvtdq2ps, vcvttps2dq;
85 
86         void vpblendvb(Ymm dst, Ymm x, Ymm y, Ymm z);
87 
88         struct Label {
89             int                                 offset = 0;
90             enum { None, ARMDisp19, X86Disp32 } kind = None;
91             std::vector<int>                    references;
92         };
93 
94         Label here();
95         void label(Label*);
96 
97         void jmp(Label*);
98         void je (Label*);
99         void jne(Label*);
100         void jl (Label*);
101         void cmp(GP64, int imm);
102 
103         void vbroadcastss(Ymm dst, Label*);
104         void vbroadcastss(Ymm dst, Xmm src);
105         void vbroadcastss(Ymm dst, GP64 ptr, int off);  // dst = *(ptr+off)
106 
107         void vpshufb(Ymm dst, Ymm x, Label*);
108 
109         void vmovups  (Ymm dst, GP64 ptr);   // dst = *ptr, 256-bit
110         void vpmovzxwd(Ymm dst, GP64 ptr);   // dst = *ptr, 128-bit, each uint16_t expanded to int
111         void vpmovzxbd(Ymm dst, GP64 ptr);   // dst = *ptr,  64-bit, each uint8_t  expanded to int
112         void vmovd    (Xmm dst, GP64 ptr);   // dst = *ptr,  32-bit
113 
114         void vmovups(GP64 ptr, Ymm src);     // *ptr = src, 256-bit
115         void vmovups(GP64 ptr, Xmm src);     // *ptr = src, 128-bit
116         void vmovq  (GP64 ptr, Xmm src);     // *ptr = src,  64-bit
117         void vmovd  (GP64 ptr, Xmm src);     // *ptr = src,  32-bit
118 
119         void movzbl(GP64 dst, GP64 ptr, int off);  // dst = *(ptr+off), uint8_t -> int
120         void movb  (GP64 ptr, GP64 src);           // *ptr = src, 8-bit
121 
122         void vmovd_direct(GP64 dst, Xmm src);  // dst = src, 32-bit
123         void vmovd_direct(Xmm dst, GP64 src);  // dst = src, 32-bit
124 
125         void vpinsrw(Xmm dst, Xmm src, GP64 ptr, int imm);  // dst = src; dst[imm] = *ptr, 16-bit
126         void vpinsrb(Xmm dst, Xmm src, GP64 ptr, int imm);  // dst = src; dst[imm] = *ptr,  8-bit
127 
128         void vpextrw(GP64 ptr, Xmm src, int imm);           // *dst = src[imm]           , 16-bit
129         void vpextrb(GP64 ptr, Xmm src, int imm);           // *dst = src[imm]           ,  8-bit
130 
131         // aarch64
132 
133         // d = op(n,m)
134         using DOpNM = void(V d, V n, V m);
135         DOpNM  and16b, orr16b, eor16b, bic16b, bsl16b,
136                add4s,  sub4s,  mul4s,
137               cmeq4s, cmgt4s,
138                        sub8h,  mul8h,
139               fadd4s, fsub4s, fmul4s, fdiv4s,
140               tbl;
141 
142         // d += n*m
143         void fmla4s(V d, V n, V m);
144 
145         // d = op(n,imm)
146         using DOpNImm = void(V d, V n, int imm);
147         DOpNImm sli4s,
148                 shl4s, sshr4s, ushr4s,
149                                ushr8h;
150 
151         // d = op(n)
152         using DOpN = void(V d, V n);
153         DOpN scvtf4s,   // int -> float
154              fcvtzs4s,  // truncate float -> int
155              xtns2h,    // u32 -> u16
156              xtnh2b,    // u16 -> u8
157              uxtlb2h,   // u8 -> u16
158              uxtlh2s;   // u16 -> u32
159 
160         // TODO: both these platforms support rounding float->int (vcvtps2dq, fcvtns.4s)... use?
161 
162         void ret (X);
163         void add (X d, X n, int imm12);
164         void sub (X d, X n, int imm12);
165         void subs(X d, X n, int imm12);  // subtract setting condition flags
166 
167         // There's another encoding for unconditional branches that can jump further,
168         // but this one encoded as b.al is simple to implement and should be fine.
b(Label * l)169         void b  (Label* l) { this->b(Condition::al, l); }
bne(Label * l)170         void bne(Label* l) { this->b(Condition::ne, l); }
blt(Label * l)171         void blt(Label* l) { this->b(Condition::lt, l); }
172 
173         // "cmp ..." is just an assembler mnemonic for "subs xzr, ..."!
cmp(X n,int imm12)174         void cmp(X n, int imm12) { this->subs(xzr, n, imm12); }
175 
176         // Compare and branch if zero/non-zero, as if
177         //      cmp(t,0)
178         //      beq/bne(l)
179         // but without setting condition flags.
180         void cbz (X t, Label* l);
181         void cbnz(X t, Label* l);
182 
183         void ldrq(V dst, Label*);  // 128-bit PC-relative load
184 
185         void ldrq(V dst, X src);  // 128-bit dst = *src
186         void ldrs(V dst, X src);  //  32-bit dst = *src
187         void ldrb(V dst, X src);  //   8-bit dst = *src
188 
189         void strq(V src, X dst);  // 128-bit *dst = src
190         void strs(V src, X dst);  //  32-bit *dst = src
191         void strb(V src, X dst);  //   8-bit *dst = src
192 
193     private:
194         // dst = op(dst, imm)
195         void op(int opcode, int opcode_ext, GP64 dst, int imm);
196 
197 
198         // dst = op(x,y) or op(x)
199         void op(int prefix, int map, int opcode, Ymm dst, Ymm x, Ymm y, bool W=false);
200         void op(int prefix, int map, int opcode, Ymm dst, Ymm x,        bool W=false) {
201             // Two arguments ops seem to pass them in dst and y, forcing x to 0 so VEX.vvvv == 1111.
202             this->op(prefix, map, opcode, dst,(Ymm)0,x, W);
203         }
204 
205         // dst = op(x,imm)
206         void op(int prefix, int map, int opcode, int opcode_ext, Ymm dst, Ymm x, int imm);
207 
208         // dst = op(x,label) or op(label)
209         void op(int prefix, int map, int opcode, Ymm dst, Ymm x, Label* l);
210 
211         // *ptr = ymm or ymm = *ptr, depending on opcode.
212         void load_store(int prefix, int map, int opcode, Ymm ymm, GP64 ptr);
213 
214         // Opcode for 3-arguments ops is split between hi and lo:
215         //    [11 bits hi] [5 bits m] [6 bits lo] [5 bits n] [5 bits d]
216         void op(uint32_t hi, V m, uint32_t lo, V n, V d);
217 
218         // 2-argument ops, with or without an immediate.
219         void op(uint32_t op22, int imm, V n, V d);
op(uint32_t op22,V n,V d)220         void op(uint32_t op22, V n, V d) { this->op(op22,0,n,d); }
op(uint32_t op22,X x,V v)221         void op(uint32_t op22, X x, V v) { this->op(op22,0,(V)x,v); }
222 
223         // Order matters... value is 4-bit encoding for condition code.
224         enum class Condition { eq,ne,cs,cc,mi,pl,vs,vc,hi,ls,ge,lt,gt,le,al };
225         void b(Condition, Label*);
226 
227         void jump(uint8_t condition, Label*);
228 
229         int disp19(Label*);
230         int disp32(Label*);
231 
232         uint8_t* fCode;
233         uint8_t* fCurr;
234         size_t   fSize;
235     };
236 
237     enum class Op : uint8_t {
238           store8,   store16,   store32,
239     // ↑ side effects / no side effects ↓
240 
241            load8,    load16,    load32,
242          gather8,  gather16,  gather32,
243     // ↑ always varying / uniforms, constants, Just Math ↓
244 
245         uniform8, uniform16, uniform32,
246         splat,
247 
248         add_f32, add_i32, add_i16x2,
249         sub_f32, sub_i32, sub_i16x2,
250         mul_f32, mul_i32, mul_i16x2,
251         div_f32,
252         mad_f32,
253                  shl_i32, shl_i16x2,
254                  shr_i32, shr_i16x2,
255                  sra_i32, sra_i16x2,
256 
257          to_i32,  to_f32,
258 
259          eq_f32,  eq_i32,  eq_i16x2,
260         neq_f32, neq_i32, neq_i16x2,
261          lt_f32,  lt_i32,  lt_i16x2,
262         lte_f32, lte_i32, lte_i16x2,
263          gt_f32,  gt_i32,  gt_i16x2,
264         gte_f32, gte_i32, gte_i16x2,
265 
266         bit_and,
267         bit_or,
268         bit_xor,
269         bit_clear,
270         select,
271 
272         bytes, extract, pack,
273     };
274 
275     using Val = int;
276     // We reserve the last Val ID as a sentinel meaning none, n/a, null, nil, etc.
277     static const Val NA = ~0;
278 
279     struct Arg { int ix; };
280     struct I32 { Val id; };
281     struct F32 { Val id; };
282 
283     class Program;
284 
285     class Builder {
286     public:
287         struct Instruction {
288             Op  op;         // v* = op(x,y,z,imm), where * == index of this Instruction.
289             Val x,y,z;      // Enough arguments for mad().
290             int imm;        // Immediate bit pattern, shift count, argument index, etc.
291 
292             // Not populated until done() has been called.
293             int  death;         // Index of last live instruction taking this input; live if != 0.
294             bool can_hoist;     // Value independent of all loop variables?
295             bool used_in_loop;  // Is the value used in the loop (or only by hoisted values)?
296         };
297 
298         Program done(const char* debug_name = nullptr);
299 
300         // Mostly for debugging, tests, etc.
program()301         std::vector<Instruction> program() const { return fProgram; }
302 
303 
304         // Declare an argument with given stride (use stride=0 for uniforms).
305         // TODO: different types for varying and uniforms?
306         Arg arg(int stride);
307 
308         // Convenience arg() wrappers for most common strides, sizeof(T) and 0.
309         template <typename T>
varying()310         Arg varying() { return this->arg(sizeof(T)); }
uniform()311         Arg uniform() { return this->arg(0); }
312 
313         // TODO: allow uniform (i.e. Arg) offsets to store* and load*?
314         // TODO: sign extension (signed types) for <32-bit loads?
315         // TODO: unsigned integer operations where relevant (just comparisons?)?
316 
317         // Store {8,16,32}-bit varying.
318         void store8 (Arg ptr, I32 val);
319         void store16(Arg ptr, I32 val);
320         void store32(Arg ptr, I32 val);
321 
322         // Load u8,u16,i32 varying.
323         I32 load8 (Arg ptr);
324         I32 load16(Arg ptr);
325         I32 load32(Arg ptr);
326 
327         // Gather u8,u16,i32 with varying element-count offset.
328         I32 gather8 (Arg ptr, I32 offset);
329         I32 gather16(Arg ptr, I32 offset);
330         I32 gather32(Arg ptr, I32 offset);
331 
332         // Load u8,u16,i32 uniform with optional byte-count offset.
333         I32 uniform8 (Arg ptr, int offset=0);
334         I32 uniform16(Arg ptr, int offset=0);
335         I32 uniform32(Arg ptr, int offset=0);
336 
337         // Load an immediate constant.
338         I32 splat(int      n);
splat(unsigned u)339         I32 splat(unsigned u) { return this->splat((int)u); }
340         F32 splat(float    f);
341 
342         // float math, comparisons, etc.
343         F32 add(F32 x, F32 y);
344         F32 sub(F32 x, F32 y);
345         F32 mul(F32 x, F32 y);
346         F32 div(F32 x, F32 y);
347         F32 mad(F32 x, F32 y, F32 z);  //  x*y+z, often an FMA
348 
349         I32 eq (F32 x, F32 y);
350         I32 neq(F32 x, F32 y);
351         I32 lt (F32 x, F32 y);
352         I32 lte(F32 x, F32 y);
353         I32 gt (F32 x, F32 y);
354         I32 gte(F32 x, F32 y);
355 
356         I32 to_i32(F32 x);
bit_cast(F32 x)357         I32 bit_cast(F32 x) { return {x.id}; }
358 
359         // int math, comparisons, etc.
360         I32 add(I32 x, I32 y);
361         I32 sub(I32 x, I32 y);
362         I32 mul(I32 x, I32 y);
363 
364         I32 shl(I32 x, int bits);
365         I32 shr(I32 x, int bits);
366         I32 sra(I32 x, int bits);
367 
368         I32 eq (I32 x, I32 y);
369         I32 neq(I32 x, I32 y);
370         I32 lt (I32 x, I32 y);
371         I32 lte(I32 x, I32 y);
372         I32 gt (I32 x, I32 y);
373         I32 gte(I32 x, I32 y);
374 
375         F32 to_f32(I32 x);
bit_cast(I32 x)376         F32 bit_cast(I32 x) { return {x.id}; }
377 
378         // Treat each 32-bit lane as a pair of 16-bit ints.
379         I32 add_16x2(I32 x, I32 y);
380         I32 sub_16x2(I32 x, I32 y);
381         I32 mul_16x2(I32 x, I32 y);
382 
383         I32 shl_16x2(I32 x, int bits);
384         I32 shr_16x2(I32 x, int bits);
385         I32 sra_16x2(I32 x, int bits);
386 
387         I32  eq_16x2(I32 x, I32 y);
388         I32 neq_16x2(I32 x, I32 y);
389         I32  lt_16x2(I32 x, I32 y);
390         I32 lte_16x2(I32 x, I32 y);
391         I32  gt_16x2(I32 x, I32 y);
392         I32 gte_16x2(I32 x, I32 y);
393 
394         // Bitwise operations.
395         I32 bit_and  (I32 x, I32 y);
396         I32 bit_or   (I32 x, I32 y);
397         I32 bit_xor  (I32 x, I32 y);
398         I32 bit_clear(I32 x, I32 y);   // x & ~y
399 
400         I32 select(I32 cond, I32 t, I32 f);  // cond ? t : f
select(I32 cond,F32 t,F32 f)401         F32 select(I32 cond, F32 t, F32 f) {
402             return this->bit_cast(this->select(cond, this->bit_cast(t)
403                                                    , this->bit_cast(f)));
404         }
405 
406         // More complex operations...
407 
408         // Shuffle the bytes in x according to each nibble of control, as if
409         //
410         //    uint8_t bytes[] = {
411         //        0,
412         //        ((uint32_t)x      ) & 0xff,
413         //        ((uint32_t)x >>  8) & 0xff,
414         //        ((uint32_t)x >> 16) & 0xff,
415         //        ((uint32_t)x >> 24) & 0xff,
416         //    };
417         //    return (uint32_t)bytes[(control >>  0) & 0xf] <<  0
418         //         | (uint32_t)bytes[(control >>  4) & 0xf] <<  8
419         //         | (uint32_t)bytes[(control >>  8) & 0xf] << 16
420         //         | (uint32_t)bytes[(control >> 12) & 0xf] << 24;
421         //
422         // So, e.g.,
423         //    - bytes(x, 0x1111) splats the low byte of x to all four bytes
424         //    - bytes(x, 0x4321) is x, an identity
425         //    - bytes(x, 0x0000) is 0
426         //    - bytes(x, 0x0404) transforms an RGBA pixel into an A0A0 bit pattern.
427         I32 bytes  (I32 x, int control);
428 
429         I32 extract(I32 x, int bits, I32 y);   // (x >> bits) & y
430         I32 pack   (I32 x, I32 y, int bits);   // x | (y << bits), assuming (x & (y << bits)) == 0
431 
432         void dump(SkWStream* = nullptr) const;
433 
434     private:
435         struct InstructionHash {
436             template <typename T>
HashInstructionHash437             static size_t Hash(T val) {
438                 return std::hash<T>{}(val);
439             }
440             size_t operator()(const Instruction& inst) const;
441         };
442 
443         Val push(Op, Val x, Val y=NA, Val z=NA, int imm=0);
444         bool isZero(Val) const;
445 
446         SkTHashMap<Instruction, Val, InstructionHash> fIndex;
447         std::vector<Instruction>                      fProgram;
448         std::vector<int>                              fStrides;
449     };
450 
451     using Reg = int;
452 
453     class Program {
454     public:
455         struct Instruction {   // d = op(x, y, z/imm)
456             Op  op;
457             Reg d,x,y;
458             union { Reg z; int imm; };
459         };
460 
461         Program(const std::vector<Builder::Instruction>& instructions,
462                 const std::vector<int>                 & strides,
463                 const char* debug_name);
464 
465         Program();
466         ~Program();
467         Program(Program&&);
468         Program& operator=(Program&&);
469         Program(const Program&) = delete;
470         Program& operator=(const Program&) = delete;
471 
472         void eval(int n, void* args[]) const;
473 
474         template <typename... T>
eval(int n,T * ...arg)475         void eval(int n, T*... arg) const {
476             SkASSERT(sizeof...(arg) == fStrides.size());
477             // This nullptr isn't important except that it makes args[] non-empty if you pass none.
478             void* args[] = { (void*)arg..., nullptr };
479             this->eval(n, args);
480         }
481 
instructions()482         std::vector<Instruction> instructions() const { return fInstructions; }
nregs()483         int nregs() const { return fRegs; }
loop()484         int loop() const { return fLoop; }
empty()485         bool empty() const { return fInstructions.empty(); }
486 
487         bool hasJIT() const;  // Has this Program been JITted?
488         void dropJIT();       // If hasJIT(), drop it, forcing interpreter fallback.
489 
490         void dump(SkWStream* = nullptr) const;
491 
492     private:
493         void setupInterpreter(const std::vector<Builder::Instruction>&);
494         void setupJIT        (const std::vector<Builder::Instruction>&, const char* debug_name);
495 
496         bool jit(const std::vector<Builder::Instruction>&,
497                  bool try_hoisting,
498                  Assembler*) const;
499 
500         // Dump jit-*.dump files for perf inject.
501         void dumpJIT(const char* debug_name, size_t size) const;
502 
503         std::vector<Instruction> fInstructions;
504         int                      fRegs = 0;
505         int                      fLoop = 0;
506         std::vector<int>         fStrides;
507 
508         // We only hang onto these to help debugging.
509         std::vector<Builder::Instruction> fOriginalProgram;
510 
511         void*  fJITBuf  = nullptr;
512         size_t fJITSize = 0;
513     };
514 
    // TODO: control flow
    // TODO: 64-bit values?
    // TODO: SSE2/SSE4.1, AVX-512F, ARMv8.2 JITs?
    // TODO: lower to LLVM or WebAssembly for comparison?
519 }
520 
521 #endif//SkVM_DEFINED
522