1 /*
2 * Copyright (c) 2012-2019 Fredrik Mellbin
3 *
4 * This file is part of VapourSynth.
5 *
6 * VapourSynth is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * VapourSynth is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with VapourSynth; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20 
21 #include <algorithm>
22 #include <cmath>
23 #include <functional>
24 #include <iostream>
25 #include <limits>
26 #include <locale>
27 #include <map>
28 #include <memory>
29 #include <set>
30 #include <sstream>
31 #include <stdexcept>
32 #include <string>
33 #include <tuple>
34 #include <unordered_map>
35 #include <unordered_set>
36 #include <vector>
37 #include "VapourSynth.h"
38 #include "VSHelper.h"
39 #include "cpufeatures.h"
40 #include "internalfilters.h"
41 #include "vslog.h"
42 #include "kernel/cpulevel.h"
43 
44 #ifdef VS_TARGET_CPU_X86
45 #include <immintrin.h>
46 #define NOMINMAX
47 #include "jitasm.h"
48 #ifndef VS_TARGET_OS_WINDOWS
49 #include <sys/mman.h>
50 #endif
51 #endif
52 
namespace {

// Maximum number of input clips one expression can reference. ExprData::node[] and the
// line-procedure pointer-offset array (MAX_EXPR_INPUTS + 1 entries, one extra for the
// destination) are sized from it. NOTE(review): 26 presumably matches one clip per letter — confirm in the parser.
#define MAX_EXPR_INPUTS 26
56 
// Opcodes of the Expr bytecode. An ExprInstruction pairs one of these with an optional
// 32-bit immediate (ExprOp::imm) and register operand indices.
enum class ExprOpType {
    // Terminals.
    // MEM_LOAD_*: the immediate is the input clip index; the loaded value is converted to float.
    // CONSTANT: the immediate holds the float value.
    // MEM_STORE_*: write the result to the output plane; for MEM_STORE_U16 the immediate is the bit depth.
    MEM_LOAD_U8, MEM_LOAD_U16, MEM_LOAD_F16, MEM_LOAD_F32, CONSTANT,
    MEM_STORE_U8, MEM_STORE_U16, MEM_STORE_F16, MEM_STORE_F32,

    // Arithmetic primitives.
    // FMA: the immediate is an FMAType. CMP: the immediate is the cmpps predicate
    // (presumably a ComparisonType value).
    ADD, SUB, MUL, DIV, FMA, SQRT, ABS, NEG, MAX, MIN, CMP,

    // Logical operators.
    AND, OR, XOR, NOT,

    // Transcendental functions.
    EXP, LOG, POW, SIN, COS,

    // Ternary operator
    TERNARY,

    // Meta-node holding true/false branches of ternary.
    MUX,

    // Stack helpers.
    // NOTE: MUX, DUP and SWAP have no code-generation hook; ExprCompiler::addInstruction treats
    // them as illegal, so they are presumably eliminated before compilation.
    DUP, SWAP,
};
80 
// Variant of fused multiply-add, carried in the immediate of an FMA instruction.
// Operand naming follows the x86 backend: a = src1 (addend), b = src2, c = src3.
enum class FMAType {
    FMADD = 0,  // (b * c) + a
    FMSUB = 1,  // (b * c) - a
    FNMADD = 2, // -(b * c) + a
    FNMSUB = 3, // -(b * c) - a
};
87 
// Comparison predicates for CMP. The numeric values are exactly the x86 cmpps/vcmpps
// predicate immediates (enforced by the static_asserts below), so the value can be passed
// straight to the instruction. EQ/LT/LE are ordered (false if either operand is NaN);
// NEQ/NLT/NLE are unordered (true if either operand is NaN).
enum class ComparisonType {
    EQ = 0,
    LT = 1,
    LE = 2,
    NEQ = 4,
    NLT = 5,
    NLE = 6,
};

#ifdef VS_TARGET_CPU_X86
// Keep the enum in lockstep with the hardware predicate encodings.
static_assert(static_cast<int>(ComparisonType::EQ) == _CMP_EQ_OQ, "");
static_assert(static_cast<int>(ComparisonType::LT) == _CMP_LT_OS, "");
static_assert(static_cast<int>(ComparisonType::LE) == _CMP_LE_OS, "");
static_assert(static_cast<int>(ComparisonType::NEQ) == _CMP_NEQ_UQ, "");
static_assert(static_cast<int>(ComparisonType::NLT) == _CMP_NLT_US, "");
static_assert(static_cast<int>(ComparisonType::NLE) == _CMP_NLE_US, "");
#endif
105 
// A 32-bit immediate that can be viewed as signed, unsigned or float.
// Each constructor activates the member matching its argument type; the
// default constructor clears all 32 bits through the unsigned member.
union ExprUnion {
    int32_t i;
    uint32_t u;
    float f;

    constexpr ExprUnion() : u{ 0 } {}

    constexpr ExprUnion(int32_t value) : i(value) {}
    constexpr ExprUnion(uint32_t value) : u(value) {}
    constexpr ExprUnion(float value) : f(value) {}
};
117 
118 struct ExprOp {
119     ExprOpType type;
120     ExprUnion imm;
121 
ExprOp__anon2a0524d50111::ExprOp122     ExprOp(ExprOpType type, ExprUnion param = {}) : type(type), imm(param) {}
123 };
124 
operator ==(const ExprOp & lhs,const ExprOp & rhs)125 bool operator==(const ExprOp &lhs, const ExprOp &rhs) { return lhs.type == rhs.type && lhs.imm.u == rhs.imm.u; }
operator !=(const ExprOp & lhs,const ExprOp & rhs)126 bool operator!=(const ExprOp &lhs, const ExprOp &rhs) { return !(lhs == rhs); }
127 
128 struct ExprInstruction {
129     ExprOp op;
130     int dst;
131     int src1;
132     int src2;
133     int src3;
134 
ExprInstruction__anon2a0524d50111::ExprInstruction135     ExprInstruction(ExprOp op) : op(op), dst(-1), src1(-1), src2(-1), src3(-1) {}
136 };
137 
// Per-plane action of the filter. NOTE(review): semantics are defined outside this view —
// presumably poProcess = evaluate the expression, poCopy = pass the input plane through,
// poUndefined = leave the plane unwritten. Confirm against the frame-processing code.
enum PlaneOp {
    poProcess, poCopy, poUndefined
};
141 
142 struct ExprData {
143     VSNodeRef *node[MAX_EXPR_INPUTS];
144     VSVideoInfo vi;
145     std::vector<ExprInstruction> bytecode[3];
146     int plane[3];
147     int numInputs;
148     typedef void (*ProcessLineProc)(void *rwptrs, intptr_t ptroff[MAX_EXPR_INPUTS + 1], intptr_t niter);
149     ProcessLineProc proc[3];
150     size_t procSize[3];
151 
ExprData__anon2a0524d50111::ExprData152     ExprData() : node(), vi(), plane(), numInputs(), proc() {}
153 
~ExprData__anon2a0524d50111::ExprData154     ~ExprData() {
155 #ifdef VS_TARGET_CPU_X86
156         for (int i = 0; i < 3; i++) {
157             if (proc[i]) {
158 #ifdef VS_TARGET_OS_WINDOWS
159                 VirtualFree((LPVOID)proc[i], 0, MEM_RELEASE);
160 #else
161                 munmap((void *)proc[i], procSize[i]);
162 #endif
163             }
164         }
165 #endif
166     }
167 };
168 
169 #ifdef VS_TARGET_CPU_X86
170 class ExprCompiler {
171     virtual void load8(const ExprInstruction &insn) = 0;
172     virtual void load16(const ExprInstruction &insn) = 0;
173     virtual void loadF16(const ExprInstruction &insn) = 0;
174     virtual void loadF32(const ExprInstruction &insn) = 0;
175     virtual void loadConst(const ExprInstruction &insn) = 0;
176     virtual void store8(const ExprInstruction &insn) = 0;
177     virtual void store16(const ExprInstruction &insn) = 0;
178     virtual void storeF16(const ExprInstruction &insn) = 0;
179     virtual void storeF32(const ExprInstruction &insn) = 0;
180     virtual void add(const ExprInstruction &insn) = 0;
181     virtual void sub(const ExprInstruction &insn) = 0;
182     virtual void mul(const ExprInstruction &insn) = 0;
183     virtual void div(const ExprInstruction &insn) = 0;
184     virtual void fma(const ExprInstruction &insn) = 0;
185     virtual void max(const ExprInstruction &insn) = 0;
186     virtual void min(const ExprInstruction &insn) = 0;
187     virtual void sqrt(const ExprInstruction &insn) = 0;
188     virtual void abs(const ExprInstruction &insn) = 0;
189     virtual void neg(const ExprInstruction &insn) = 0;
190     virtual void not_(const ExprInstruction &insn) = 0;
191     virtual void and_(const ExprInstruction &insn) = 0;
192     virtual void or_(const ExprInstruction &insn) = 0;
193     virtual void xor_(const ExprInstruction &insn) = 0;
194     virtual void cmp(const ExprInstruction &insn) = 0;
195     virtual void ternary(const ExprInstruction &insn) = 0;
196     virtual void exp(const ExprInstruction &insn) = 0;
197     virtual void log(const ExprInstruction &insn) = 0;
198     virtual void pow(const ExprInstruction &insn) = 0;
199     virtual void sin(const ExprInstruction &insn) = 0;
200     virtual void cos(const ExprInstruction &insn) = 0;
201 public:
addInstruction(const ExprInstruction & insn)202     void addInstruction(const ExprInstruction &insn)
203     {
204         switch (insn.op.type) {
205         case ExprOpType::MEM_LOAD_U8: load8(insn); break;
206         case ExprOpType::MEM_LOAD_U16: load16(insn); break;
207         case ExprOpType::MEM_LOAD_F16: loadF16(insn); break;
208         case ExprOpType::MEM_LOAD_F32: loadF32(insn); break;
209         case ExprOpType::CONSTANT: loadConst(insn); break;
210         case ExprOpType::MEM_STORE_U8: store8(insn); break;
211         case ExprOpType::MEM_STORE_U16: store16(insn); break;
212         case ExprOpType::MEM_STORE_F16: storeF16(insn); break;
213         case ExprOpType::MEM_STORE_F32: storeF32(insn); break;
214         case ExprOpType::ADD: add(insn); break;
215         case ExprOpType::SUB: sub(insn); break;
216         case ExprOpType::MUL: mul(insn); break;
217         case ExprOpType::DIV: div(insn); break;
218         case ExprOpType::FMA: fma(insn); break;
219         case ExprOpType::MAX: max(insn); break;
220         case ExprOpType::MIN: min(insn); break;
221         case ExprOpType::SQRT: sqrt(insn); break;
222         case ExprOpType::ABS: abs(insn); break;
223         case ExprOpType::NEG: neg(insn); break;
224         case ExprOpType::NOT: not_(insn); break;
225         case ExprOpType::AND: and_(insn); break;
226         case ExprOpType::OR: or_(insn); break;
227         case ExprOpType::XOR: xor_(insn); break;
228         case ExprOpType::CMP: cmp(insn); break;
229         case ExprOpType::TERNARY: ternary(insn); break;
230         case ExprOpType::EXP: exp(insn); break;
231         case ExprOpType::LOG: log(insn); break;
232         case ExprOpType::POW: pow(insn); break;
233         case ExprOpType::SIN: sin(insn); break;
234         case ExprOpType::COS: cos(insn); break;
235         default: vsFatal("illegal opcode"); break;
236         }
237     }
238 
~ExprCompiler()239     virtual ~ExprCompiler() {}
240     virtual std::pair<ExprData::ProcessLineProc, size_t> getCode() = 0;
241 };
242 
243 class ExprCompiler128 : public ExprCompiler, private jitasm::function<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t> {
    // The jitasm front end this compiler derives from (privately): emits a function taking
    // (uint8_t *, const intptr_t *, intptr_t) and returning void.
    typedef jitasm::function<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t> jit;
    // NOTE(review): befriended so jitasm's templates can call this class's private entry points
    // (presumably main()/naked_main()) — confirm against jitasm.h.
    friend struct jitasm::function<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t>;
    friend struct jitasm::function_cdecl<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t>;
247 
// Constant table used by the generated code, addressed as `constants + ConstantIndex::<name> * 16`.
// Each row is one 16-byte vector whose four lanes hold the same 32-bit value (SPLAT).
// Integer literals are raw bit patterns: masks, or IEEE-754 float encodings for the
// float_invpi..float_cosC8 rows. Row order must match ConstantIndex.
#define SPLAT(x) { (x), (x), (x), (x) }
    static constexpr ExprUnion constData alignas(16)[53][4] = {
        SPLAT(0x7FFFFFFF), // absmask
        SPLAT(0x80000000), // negmask
        SPLAT(0x7F), // x7F
        SPLAT(0x00800000), // min_norm_pos
        SPLAT(~0x7F800000), // inv_mant_mask
        SPLAT(1.0f), // float_one
        SPLAT(0.5f), // float_half
        SPLAT(255.0f), // float_255
        SPLAT(511.0f), // float_511
        SPLAT(1023.0f), // float_1023
        SPLAT(2047.0f), // float_2047
        SPLAT(4095.0f), // float_4095
        SPLAT(8191.0f), // float_8191
        SPLAT(16383.0f), // float_16383
        SPLAT(32767.0f), // float_32767
        SPLAT(65535.0f), // float_65535
        SPLAT(static_cast<int32_t>(0x80008000)), // i16min_epi16
        SPLAT(static_cast<int32_t>(0xFFFF8000)), // i16min_epi32
        SPLAT(88.3762626647949f), // exp_hi
        SPLAT(-88.3762626647949f), // exp_lo
        SPLAT(1.44269504088896341f), // log2e
        SPLAT(0.693359375f), // exp_c1
        SPLAT(-2.12194440e-4f), // exp_c2
        SPLAT(1.9875691500E-4f), // exp_p0
        SPLAT(1.3981999507E-3f), // exp_p1
        SPLAT(8.3334519073E-3f), // exp_p2
        SPLAT(4.1665795894E-2f), // exp_p3
        SPLAT(1.6666665459E-1f), // exp_p4
        SPLAT(5.0000001201E-1f), // exp_p5
        SPLAT(0.707106781186547524f), // sqrt_1_2
        SPLAT(7.0376836292E-2f), // log_p0
        SPLAT(-1.1514610310E-1f), // log_p1
        SPLAT(1.1676998740E-1f), // log_p2
        SPLAT(-1.2420140846E-1f), // log_p3
        SPLAT(+1.4249322787E-1f), // log_p4
        SPLAT(-1.6668057665E-1f), // log_p5
        SPLAT(+2.0000714765E-1f), // log_p6
        SPLAT(-2.4999993993E-1f), // log_p7
        SPLAT(+3.3333331174E-1f), // log_p8
        SPLAT(0x3ea2f983), // float_invpi, 1/pi
        SPLAT(0x4b400000), // float_rintf
        SPLAT(0x40490000), // float_pi1
        SPLAT(0x3a7da000), // float_pi2
        SPLAT(0x34222000), // float_pi3
        SPLAT(0x2cb4611a), // float_pi4
        SPLAT(0xbe2aaaa6), // float_sinC3
        SPLAT(0x3c08876a), // float_sinC5
        SPLAT(0xb94fb7ff), // float_sinC7
        SPLAT(0x362edef8), // float_sinC9
        SPLAT(static_cast<int32_t>(0xBEFFFFE2)), // float_cosC2
        SPLAT(0x3D2AA73C), // float_cosC4
        SPLAT(static_cast<int32_t>(0XBAB58D50)), // float_cosC6
        SPLAT(0x37C1AD76), // float_cosC8
    };
304 
    // Row indices into constData (byte offset = index * 16). Must stay in sync with the row
    // order of constData.
    struct ConstantIndex {
        static constexpr int absmask = 0;
        static constexpr int negmask = 1;
        static constexpr int x7F = 2;
        static constexpr int min_norm_pos = 3;
        static constexpr int inv_mant_mask = 4;
        static constexpr int float_one = 5;
        static constexpr int float_half = 6;
        // float_255 .. float_65535 must stay contiguous: store16 indexes this run as
        // float_255 + depth - 8 to get the maximum value for a given bit depth.
        static constexpr int float_255 = 7;
        static constexpr int float_511 = 8;
        static constexpr int float_1023 = 9;
        static constexpr int float_2047 = 10;
        static constexpr int float_4095 = 11;
        static constexpr int float_8191 = 12;
        static constexpr int float_16383 = 13;
        static constexpr int float_32767 = 14;
        static constexpr int float_65535 = 15;
        static constexpr int i16min_epi16 = 16;
        static constexpr int i16min_epi32 = 17;
        static constexpr int exp_hi = 18;
        static constexpr int exp_lo = 19;
        static constexpr int log2e = 20;
        static constexpr int exp_c1 = 21;
        static constexpr int exp_c2 = 22;
        static constexpr int exp_p0 = 23;
        static constexpr int exp_p1 = 24;
        static constexpr int exp_p2 = 25;
        static constexpr int exp_p3 = 26;
        static constexpr int exp_p4 = 27;
        static constexpr int exp_p5 = 28;
        static constexpr int sqrt_1_2 = 29;
        static constexpr int log_p0 = 30;
        static constexpr int log_p1 = 31;
        static constexpr int log_p2 = 32;
        static constexpr int log_p3 = 33;
        static constexpr int log_p4 = 34;
        static constexpr int log_p5 = 35;
        static constexpr int log_p6 = 36;
        static constexpr int log_p7 = 37;
        static constexpr int log_p8 = 38;
        // The log coefficients q1/q2 share rows with exp_c2/exp_c1 (same values:
        // -2.12194440e-4f and 0.693359375f), so they have no rows of their own.
        static constexpr int log_q1 = exp_c2;
        static constexpr int log_q2 = exp_c1;
        static constexpr int float_invpi = 39;
        static constexpr int float_rintf = 40;
        static constexpr int float_pi1 = 41;
        static constexpr int float_pi2 = float_pi1 + 1;
        static constexpr int float_pi3 = float_pi1 + 2;
        static constexpr int float_pi4 = float_pi1 + 3;
        static constexpr int float_sinC3 = 45;
        static constexpr int float_sinC5 = float_sinC3 + 1;
        static constexpr int float_sinC7 = float_sinC3 + 2;
        static constexpr int float_sinC9 = float_sinC3 + 3;
        static constexpr int float_cosC2 = 49;
        static constexpr int float_cosC4 = float_cosC2 + 1;
        static constexpr int float_cosC6 = float_cosC2 + 2;
        static constexpr int float_cosC8 = float_cosC2 + 3;
    };
#undef SPLAT
363 
    // JitASM compiles everything from main(), so record the operations for later.
    // Each entry emits the code for one bytecode instruction when it is called with the live
    // registers: pointer array, an all-zero vector, the constant-table base, and the map from
    // bytecode register index to its pair of XMM registers (lanes 0-3 and 4-7).
    std::vector<std::function<void(Reg, XmmReg, Reg, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &)>> deferred;

    // Host CPU capabilities; selects AVX/SSE4.1/FMA3 encodings or their fallbacks.
    CPUFeatures cpuFeatures;
    // NOTE(review): initialised and used outside this view — presumably the number of input clips.
    int numInputs;
    // NOTE(review): used outside this view — presumably a counter for generating unique labels.
    int curLabel;
370 
// Header of a deferred emitter lambda: captures `this` and the instruction by value and
// receives the pointer-array register, a zero vector, the constant-table base register and
// the bytecode-register map.
#define EMIT() [this, insn](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
// Two-operand instruction: the VEX ("v"-prefixed) encoding when AVX is available, else legacy SSE.
#define VEX1(op, arg1, arg2) \
do { \
  if (cpuFeatures.avx) \
    v##op(arg1, arg2); \
  else \
    op(arg1, arg2); \
} while (0)
// dst = op(src, imm). Legacy SSE is destructive, so src is copied into dst first when they differ.
#define VEX1IMM(op, arg1, arg2, imm) \
do { \
  if (cpuFeatures.avx) { \
    v##op(arg1, arg2, imm); \
  } else if (arg1 == arg2) { \
    op(arg2, imm); \
  } else { \
    movdqa(arg1, arg2); \
    op(arg1, imm); \
  } \
} while (0)
// dst = src1 op src2 (non-destructive form). Legacy SSE computes dst = dst op src, so:
// reuse dst when it equals src1; otherwise copy src1 into dst first; and if dst aliases src2,
// compute in a scratch register so src2 is not overwritten before it is read.
#define VEX2(op, arg1, arg2, arg3) \
do { \
  if (cpuFeatures.avx) { \
    v##op(arg1, arg2, arg3); \
  } else if (arg1 == arg2) { \
    op(arg2, arg3); \
  } else if (arg1 != arg3) { \
    movdqa(arg1, arg2); \
    op(arg1, arg3); \
  } else { \
    XmmReg tmp; \
    movdqa(tmp, arg2); \
    op(tmp, arg3); \
    movdqa(arg1, tmp); \
  } \
} while (0)
// Same as VEX2 with a trailing immediate/extra operand (cmpps predicate, shufps control, blendvps mask).
#define VEX2IMM(op, arg1, arg2, arg3, imm) \
do { \
  if (cpuFeatures.avx) { \
    v##op(arg1, arg2, arg3, imm); \
  } else if (arg1 == arg2) { \
    op(arg2, arg3, imm); \
  } else if (arg1 != arg3) { \
    movdqa(arg1, arg2); \
    op(arg1, arg3, imm); \
  } else { \
    XmmReg tmp; \
    movdqa(tmp, arg2); \
    op(tmp, arg3, imm); \
    movdqa(arg1, tmp); \
  } \
} while (0)
422 
load8(const ExprInstruction & insn)423     void load8(const ExprInstruction &insn) override
424     {
425         deferred.push_back(EMIT()
426         {
427             auto t1 = bytecodeRegs[insn.dst];
428             Reg a;
429             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
430             VEX1(movq, t1.first, mmword_ptr[a]);
431             VEX2(punpcklbw, t1.first, t1.first, zero);
432             VEX2(punpckhwd, t1.second, t1.first, zero);
433             VEX2(punpcklwd, t1.first, t1.first, zero);
434             VEX1(cvtdq2ps, t1.first, t1.first);
435             VEX1(cvtdq2ps, t1.second, t1.second);
436         });
437     }
438 
load16(const ExprInstruction & insn)439     void load16(const ExprInstruction &insn) override
440     {
441         deferred.push_back(EMIT()
442         {
443             auto t1 = bytecodeRegs[insn.dst];
444             Reg a;
445             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
446             VEX1(movdqa, t1.first, xmmword_ptr[a]);
447             VEX2(punpckhwd, t1.second, t1.first, zero);
448             VEX2(punpcklwd, t1.first, t1.first, zero);
449             VEX1(cvtdq2ps, t1.first, t1.first);
450             VEX1(cvtdq2ps, t1.second, t1.second);
451         });
452     }
453 
loadF16(const ExprInstruction & insn)454     void loadF16(const ExprInstruction &insn) override
455     {
456         deferred.push_back(EMIT()
457         {
458             auto t1 = bytecodeRegs[insn.dst];
459             Reg a;
460             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
461             vcvtph2ps(t1.first, qword_ptr[a]);
462             vcvtph2ps(t1.second, qword_ptr[a + 8]);
463         });
464     }
465 
loadF32(const ExprInstruction & insn)466     void loadF32(const ExprInstruction &insn) override
467     {
468         deferred.push_back(EMIT()
469         {
470             auto t1 = bytecodeRegs[insn.dst];
471             Reg a;
472             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
473             VEX1(movdqa, t1.first, xmmword_ptr[a]);
474             VEX1(movdqa, t1.second, xmmword_ptr[a + 16]);
475         });
476     }
477 
loadConst(const ExprInstruction & insn)478     void loadConst(const ExprInstruction &insn) override
479     {
480         deferred.push_back(EMIT()
481         {
482             auto t1 = bytecodeRegs[insn.dst];
483 
484             if (insn.op.imm.f == 0.0f) {
485                 VEX1(movaps, t1.first, zero);
486                 VEX1(movaps, t1.second, zero);
487                 return;
488             }
489 
490             Reg32 a;
491             mov(a, insn.op.imm.u);
492             VEX1(movd, t1.first, a);
493             VEX2IMM(shufps, t1.first, t1.first, t1.first, 0);
494             VEX1(movaps, t1.second, t1.first);
495         });
496     }
497 
store8(const ExprInstruction & insn)498     void store8(const ExprInstruction &insn) override
499     {
500         deferred.push_back(EMIT()
501         {
502             auto t1 = bytecodeRegs[insn.src1];
503             XmmReg r1, r2, limit;
504             Reg a;
505             VEX1(movaps, limit, xmmword_ptr[constants + ConstantIndex::float_255 * 16]);
506             VEX2(minps, r1, t1.first, limit);
507             VEX2(minps, r2, t1.second, limit);
508             VEX1(cvtps2dq, r1, r1);
509             VEX1(cvtps2dq, r2, r2);
510             VEX2(packssdw, r1, r1, r2);
511             VEX2(packuswb, r1, r1, zero);
512             mov(a, ptr[regptrs]);
513             VEX1(movq, mmword_ptr[a], r1);
514         });
515     }
516 
store16(const ExprInstruction & insn)517     void store16(const ExprInstruction &insn) override
518     {
519         deferred.push_back(EMIT()
520         {
521             int depth = insn.op.imm.u;
522             auto t1 = bytecodeRegs[insn.src1];
523             XmmReg r1, r2, limit;
524             Reg a;
525             VEX1(movaps, limit, xmmword_ptr[constants + (ConstantIndex::float_255 + depth - 8) * 16]);
526             VEX2IMM(shufps, limit, limit, limit, 0);
527             VEX2(minps, r1, t1.first, limit);
528             VEX2(minps, r2, t1.second, limit);
529             VEX1(cvtps2dq, r1, r1);
530             VEX1(cvtps2dq, r2, r2);
531 
532             if (cpuFeatures.sse4_1) {
533                 VEX2(packusdw, r1, r1, r2);
534             } else {
535                 if (depth >= 16) {
536                     VEX1(movaps, limit, xmmword_ptr[constants + ConstantIndex::i16min_epi32 * 16]);
537                     VEX2(paddd, r1, r1, limit);
538                     VEX2(paddd, r2, r2, limit);
539                 }
540                 VEX2(packssdw, r1, r1, r2);
541                 if (depth >= 16)
542                     VEX2(psubw, r1, r1, xmmword_ptr[constants + ConstantIndex::i16min_epi16 * 16]);
543             }
544             mov(a, ptr[regptrs]);
545             VEX1(movaps, xmmword_ptr[a], r1);
546         });
547     }
548 
storeF16(const ExprInstruction & insn)549     void storeF16(const ExprInstruction &insn) override
550     {
551         deferred.push_back(EMIT()
552         {
553             auto t1 = bytecodeRegs[insn.src1];
554 
555             Reg a;
556             mov(a, ptr[regptrs]);
557             vcvtps2ph(qword_ptr[a], t1.first, 0);
558             vcvtps2ph(qword_ptr[a + 8], t1.second, 0);
559         });
560     }
561 
storeF32(const ExprInstruction & insn)562     void storeF32(const ExprInstruction &insn) override
563     {
564         deferred.push_back(EMIT()
565         {
566             auto t1 = bytecodeRegs[insn.src1];
567 
568             Reg a;
569             mov(a, ptr[regptrs]);
570             VEX1(movaps, xmmword_ptr[a], t1.first);
571             VEX1(movaps, xmmword_ptr[a + 16], t1.second);
572         });
573     }
574 
// Shared body of the element-wise binary float ops: dst = src1 op src2, applied to both
// halves of the register pair. Expands inside an EMIT() lambda (uses insn and bytecodeRegs).
#define BINARYOP(op) \
do { \
  auto t1 = bytecodeRegs[insn.src1]; \
  auto t2 = bytecodeRegs[insn.src2]; \
  auto t3 = bytecodeRegs[insn.dst]; \
  VEX2(op, t3.first, t1.first, t2.first); \
  VEX2(op, t3.second, t1.second, t2.second); \
} while (0)
add(const ExprInstruction & insn)583     void add(const ExprInstruction &insn) override
584     {
585         deferred.push_back(EMIT()
586         {
587             BINARYOP(addps);
588         });
589     }
590 
sub(const ExprInstruction & insn)591     void sub(const ExprInstruction &insn) override
592     {
593         deferred.push_back(EMIT()
594         {
595             BINARYOP(subps);
596         });
597     }
598 
mul(const ExprInstruction & insn)599     void mul(const ExprInstruction &insn) override
600     {
601         deferred.push_back(EMIT()
602         {
603             BINARYOP(mulps);
604         });
605     }
606 
div(const ExprInstruction & insn)607     void div(const ExprInstruction &insn) override
608     {
609         deferred.push_back(EMIT()
610         {
611             BINARYOP(divps);
612         });
613     }
614 
fma(const ExprInstruction & insn)615     void fma(const ExprInstruction &insn) override
616     {
617         deferred.push_back(EMIT()
618         {
619             FMAType type = static_cast<FMAType>(insn.op.imm.u);
620 
621             // t1 + t2 * t3
622             auto t1 = bytecodeRegs[insn.src1];
623             auto t2 = bytecodeRegs[insn.src2];
624             auto t3 = bytecodeRegs[insn.src3];
625             auto t4 = bytecodeRegs[insn.dst];
626 
627             if (cpuFeatures.fma3) {
628 #define FMA3(op) \
629 do { \
630   if (insn.dst == insn.src1) { \
631     v##op##231ps(t1.first, t2.first, t3.first); \
632     v##op##231ps(t1.second, t2.second, t3.second); \
633   } else if (insn.dst == insn.src2) { \
634     v##op##132ps(t2.first, t1.first, t3.first); \
635     v##op##132ps(t2.second, t1.second, t3.second); \
636   } else if (insn.dst == insn.src3) { \
637     v##op##132ps(t3.first, t1.first, t2.first); \
638     v##op##132ps(t3.second, t1.second, t2.second); \
639   } else { \
640     vmovaps(t4.first, t1.first); \
641     vmovaps(t4.second, t1.second); \
642     v##op##231ps(t4.first, t2.first, t3.first); \
643     v##op##231ps(t4.second, t2.second, t3.second); \
644   } \
645 } while (0)
646                 switch (type) {
647                 case FMAType::FMADD: FMA3(fmadd); break;
648                 case FMAType::FMSUB: FMA3(fmsub); break;
649                 case FMAType::FNMADD: FMA3(fnmadd); break;
650                 case FMAType::FNMSUB: FMA3(fnmsub); break;
651                 }
652 #undef FMA3
653             } else {
654                 XmmReg r1, r2;
655                 VEX2(mulps, r1, t2.first, t3.first);
656                 VEX2(mulps, r2, t2.second, t3.second);
657 
658                 if (type == FMAType::FMADD || type == FMAType::FNMSUB) {
659                     VEX2(addps, t4.first, r1, t1.first);
660                     VEX2(addps, t4.second, r2, t1.second);
661                 } else if (type == FMAType::FMSUB) {
662                     VEX2(subps, t4.first, r1, t1.first);
663                     VEX2(subps, t4.second, r2, t1.second);
664                 } else if (type == FMAType::FNMADD) {
665                     VEX2(subps, t4.first, t1.first, r1);
666                     VEX2(subps, t4.second, t1.second, r2);
667                 }
668 
669                 if (type == FMAType::FNMSUB) {
670                     VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::negmask * 16]);
671                     VEX2(xorps, t4.first, t4.first, r1);
672                     VEX2(xorps, t4.second, t4.second, r2);
673                 }
674             }
675         });
676     }
677 
max(const ExprInstruction & insn)678     void max(const ExprInstruction &insn) override
679     {
680         deferred.push_back(EMIT()
681         {
682             BINARYOP(maxps);
683         });
684     }
685 
min(const ExprInstruction & insn)686     void min(const ExprInstruction &insn) override
687     {
688         deferred.push_back(EMIT()
689         {
690             BINARYOP(minps);
691         });
692     }
693 #undef BINARYOP
694 
sqrt(const ExprInstruction & insn)695     void sqrt(const ExprInstruction &insn) override
696     {
697         deferred.push_back(EMIT()
698         {
699             auto t1 = bytecodeRegs[insn.src1];
700             auto t2 = bytecodeRegs[insn.dst];
701             VEX2(maxps, t2.first, t1.first, zero);
702             VEX2(maxps, t2.second, t1.second, zero);
703             VEX1(sqrtps, t2.first, t2.first);
704             VEX1(sqrtps, t2.second, t2.second);
705         });
706     }
707 
abs(const ExprInstruction & insn)708     void abs(const ExprInstruction &insn) override
709     {
710         deferred.push_back(EMIT()
711         {
712             auto t1 = bytecodeRegs[insn.src1];
713             auto t2 = bytecodeRegs[insn.dst];
714             XmmReg r1;
715             VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::absmask * 16]);
716             VEX2(andps, t2.first, t1.first, r1);
717             VEX2(andps, t2.second, t1.second, r1);
718         });
719     }
720 
neg(const ExprInstruction & insn)721     void neg(const ExprInstruction &insn) override
722     {
723         deferred.push_back(EMIT()
724         {
725             auto t1 = bytecodeRegs[insn.src1];
726             auto t2 = bytecodeRegs[insn.dst];
727             XmmReg r1;
728             VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::negmask * 16]);
729             VEX2(xorps, t2.first, t1.first, r1);
730             VEX2(xorps, t2.second, t1.second, r1);
731         });
732     }
733 
not_(const ExprInstruction & insn)734     void not_(const ExprInstruction &insn) override
735     {
736         deferred.push_back(EMIT()
737         {
738             auto t1 = bytecodeRegs[insn.src1];
739             auto t2 = bytecodeRegs[insn.dst];
740             XmmReg r1;
741             VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
742             VEX2IMM(cmpps, t2.first, t1.first, zero, _CMP_LE_OS);
743             VEX2IMM(cmpps, t2.second, t1.second, zero, _CMP_LE_OS);
744             VEX2(andps, t2.first, t2.first, r1);
745             VEX2(andps, t2.second, t2.second, r1);
746         });
747     }
748 
// Shared body of the logical operators. Each operand counts as true when it is > 0 or NaN
// (_CMP_NLE_US against zero); the two all-ones/all-zero masks are combined with `op`
// (andps/orps/xorps) and then turned into 1.0f/0.0f by AND-ing with float_one.
// The src1 masks are computed into scratch registers before dst is written, so dst may alias either source.
#define LOGICOP(op) \
do { \
  auto t1 = bytecodeRegs[insn.src1]; \
  auto t2 = bytecodeRegs[insn.src2]; \
  auto t3 = bytecodeRegs[insn.dst]; \
  XmmReg r1, tmp1, tmp2; \
  VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::float_one * 16]); \
  VEX2IMM(cmpps, tmp1, t1.first, zero, _CMP_NLE_US); \
  VEX2IMM(cmpps, tmp2, t1.second, zero, _CMP_NLE_US); \
  VEX2IMM(cmpps, t3.first, t2.first, zero, _CMP_NLE_US); \
  VEX2IMM(cmpps, t3.second, t2.second, zero, _CMP_NLE_US); \
  VEX2(op, t3.first, t3.first, tmp1); \
  VEX2(op, t3.second, t3.second, tmp2); \
  VEX2(andps, t3.first, t3.first, r1); \
  VEX2(andps, t3.second, t3.second, r1); \
} while (0)
765 
and_(const ExprInstruction & insn)766     void and_(const ExprInstruction &insn) override
767     {
768         deferred.push_back(EMIT()
769         {
770             LOGICOP(andps);
771         });
772     }
773 
or_(const ExprInstruction & insn)774     void or_(const ExprInstruction &insn) override
775     {
776         deferred.push_back(EMIT()
777         {
778             LOGICOP(orps);
779         });
780     }
781 
xor_(const ExprInstruction & insn)782     void xor_(const ExprInstruction &insn) override
783     {
784         deferred.push_back(EMIT()
785         {
786             LOGICOP(xorps);
787         });
788     }
789 #undef LOGICOP
790 
cmp(const ExprInstruction & insn)791     void cmp(const ExprInstruction &insn) override
792     {
793         deferred.push_back(EMIT()
794         {
795             auto t1 = bytecodeRegs[insn.src1];
796             auto t2 = bytecodeRegs[insn.src2];
797             auto t3 = bytecodeRegs[insn.dst];
798             XmmReg r1;
799             VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
800             VEX2IMM(cmpps, t3.first, t1.first, t2.first, insn.op.imm.u);
801             VEX2IMM(cmpps, t3.second, t1.second, t2.second, insn.op.imm.u);
802             VEX2(andps, t3.first, t3.first, r1);
803             VEX2(andps, t3.second, t3.second, r1);
804         });
805     }
806 
ternary(const ExprInstruction & insn)807     void ternary(const ExprInstruction &insn) override
808     {
809         deferred.push_back(EMIT()
810         {
811             auto t1 = bytecodeRegs[insn.src1];
812             auto t2 = bytecodeRegs[insn.src2];
813             auto t3 = bytecodeRegs[insn.src3];
814             auto t4 = bytecodeRegs[insn.dst];
815 
816             XmmReg r1, r2;
817             VEX2IMM(cmpps, r1, t1.first, zero, _CMP_NLE_US);
818             VEX2IMM(cmpps, r2, t1.second, zero, _CMP_NLE_US);
819 
820             if (cpuFeatures.sse4_1) {
821                 VEX2IMM(blendvps, t4.first, t3.first, t2.first, r1);
822                 VEX2IMM(blendvps, t4.second, t3.second, t2.second, r2);
823             } else {
824                 VEX2(andps, t4.first, t3.first, r1);
825                 VEX2(andps, t4.second, t3.second, r2);
826                 VEX2(andnps, r1, r1, t2.first);
827                 VEX2(andnps, r2, r2, t2.second);
828                 VEX2(orps, t4.first, t4.first, r1);
829                 VEX2(orps, t4.second, t4.second, r2);
830             }
831         });
832     }
833 
exp_(XmmReg x,XmmReg one,Reg constants)834     void exp_(XmmReg x, XmmReg one, Reg constants)
835     {
836         XmmReg fx, emm0, etmp, y, mask, z;
837         VEX2(minps, x, x, xmmword_ptr[constants + ConstantIndex::exp_hi * 16]);
838         VEX2(maxps, x, x, xmmword_ptr[constants + ConstantIndex::exp_lo * 16]);
839         VEX2(mulps, fx, x, xmmword_ptr[constants + ConstantIndex::log2e * 16]);
840         VEX2(addps, fx, fx, xmmword_ptr[constants + ConstantIndex::float_half * 16]);
841         VEX1(cvttps2dq, emm0, fx);
842         VEX1(cvtdq2ps, etmp, emm0);
843         VEX2IMM(cmpps, mask, etmp, fx, _CMP_NLE_US);
844         VEX2(andps, mask, mask, one);
845         VEX2(subps, fx, etmp, mask);
846         VEX2(mulps, etmp, fx, xmmword_ptr[constants + ConstantIndex::exp_c1 * 16]);
847         VEX2(mulps, z, fx, xmmword_ptr[constants + ConstantIndex::exp_c2 * 16]);
848         VEX2(subps, x, x, etmp);
849         VEX2(subps, x, x, z);
850         VEX2(mulps, z, x, x);
851         VEX2(mulps, y, x, xmmword_ptr[constants + ConstantIndex::exp_p0 * 16]);
852         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p1 * 16]);
853         VEX2(mulps, y, y, x);
854         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p2 * 16]);
855         VEX2(mulps, y, y, x);
856         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p3 * 16]);
857         VEX2(mulps, y, y, x);
858         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p4 * 16]);
859         VEX2(mulps, y, y, x);
860         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p5 * 16]);
861         VEX2(mulps, y, y, z);
862         VEX2(addps, y, y, x);
863         VEX2(addps, y, y, one);
864         VEX1(cvttps2dq, emm0, fx);
865         VEX2(paddd, emm0, emm0, xmmword_ptr[constants + ConstantIndex::x7F * 16]);
866         VEX1IMM(pslld, emm0, emm0, 23);
867         VEX2(mulps, x, y, emm0);
868     }
869 
log_(XmmReg x,XmmReg zero,XmmReg one,Reg constants)870     void log_(XmmReg x, XmmReg zero, XmmReg one, Reg constants)
871     {
872         XmmReg emm0, invalid_mask, mask, y, etmp, z;
873         VEX2IMM(cmpps, invalid_mask, zero, x, _CMP_NLT_US);
874         VEX2(maxps, x, x, xmmword_ptr[constants + ConstantIndex::min_norm_pos * 16]);
875         VEX1IMM(psrld, emm0, x, 23);
876         VEX2(andps, x, x, xmmword_ptr[constants + ConstantIndex::inv_mant_mask * 16]);
877         VEX2(orps, x, x, xmmword_ptr[constants + ConstantIndex::float_half * 16]);
878         VEX2(psubd, emm0, emm0, xmmword_ptr[constants + ConstantIndex::x7F * 16]);
879         VEX1(cvtdq2ps, emm0, emm0);
880         VEX2(addps, emm0, emm0, one);
881         VEX2IMM(cmpps, mask, x, xmmword_ptr[constants + ConstantIndex::sqrt_1_2 * 16], _CMP_LT_OS);
882         VEX2(andps, etmp, x, mask);
883         VEX2(subps, x, x, one);
884         VEX2(andps, mask, mask, one);
885         VEX2(subps, emm0, emm0, mask);
886         VEX2(addps, x, x, etmp);
887         VEX2(mulps, z, x, x);
888         VEX2(mulps, y, x, xmmword_ptr[constants + ConstantIndex::log_p0 * 16]);
889         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p1 * 16]);
890         VEX2(mulps, y, y, x);
891         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p2 * 16]);
892         VEX2(mulps, y, y, x);
893         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p3 * 16]);
894         VEX2(mulps, y, y, x);
895         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p4 * 16]);
896         VEX2(mulps, y, y, x);
897         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p5 * 16]);
898         VEX2(mulps, y, y, x);
899         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p6 * 16]);
900         VEX2(mulps, y, y, x);
901         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p7 * 16]);
902         VEX2(mulps, y, y, x);
903         VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p8 * 16]);
904         VEX2(mulps, y, y, x);
905         VEX2(mulps, y, y, z);
906         VEX2(mulps, etmp, emm0, xmmword_ptr[constants + ConstantIndex::log_q1 * 16]);
907         VEX2(addps, y, y, etmp);
908         VEX2(mulps, z, z, xmmword_ptr[constants + ConstantIndex::float_half * 16]);
909         VEX2(subps, y, y, z);
910         VEX2(mulps, emm0, emm0, xmmword_ptr[constants + ConstantIndex::log_q2 * 16]);
911         VEX2(addps, x, x, y);
912         VEX2(addps, x, x, emm0);
913         VEX2(orps, x, x, invalid_mask);
914     }
915 
exp(const ExprInstruction & insn)916     void exp(const ExprInstruction &insn) override
917     {
918         int l = curLabel++;
919 
920         deferred.push_back([this, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
921         {
922             char label[] = "label-0000";
923             sprintf(label, "label-%04d", l);
924 
925             auto t1 = bytecodeRegs[insn.src1];
926             auto t2 = bytecodeRegs[insn.dst];
927             XmmReg r1, r2, one;
928             Reg a;
929             mov(a, 2);
930             VEX1(movaps, r1, t1.first);
931             VEX1(movaps, r2, t1.second);
932             VEX1(movaps, one, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
933 
934             L(label);
935 
936             exp_(r1, one, constants);
937             VEX1(movaps, t2.first, t2.second);
938             VEX1(movaps, t2.second, r1);
939             VEX1(movaps, r1, r2);
940 
941             jit::sub(a, 1);
942             jnz(label);
943         });
944     }
945 
log(const ExprInstruction & insn)946     void log(const ExprInstruction &insn) override
947     {
948         int l = curLabel++;
949 
950         deferred.push_back([this, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
951         {
952             char label[] = "label-0000";
953             sprintf(label, "label-%04d", l);
954 
955             auto t1 = bytecodeRegs[insn.src1];
956             auto t2 = bytecodeRegs[insn.dst];
957             XmmReg r1, r2, one;
958             Reg a;
959             mov(a, 2);
960             VEX1(movaps, r1, t1.first);
961             VEX1(movaps, r2, t1.second);
962             VEX1(movaps, one, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
963 
964             L(label);
965 
966             log_(r1, zero, one, constants);
967             VEX1(movaps, t2.first, t2.second);
968             VEX1(movaps, t2.second, r1);
969             VEX1(movaps, r1, r2);
970 
971             jit::sub(a, 1);
972             jnz(label);
973         });
974     }
975 
pow(const ExprInstruction & insn)976     void pow(const ExprInstruction &insn) override
977     {
978         int l = curLabel++;
979 
980         deferred.push_back([this, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
981         {
982             char label[] = "label-0000";
983             sprintf(label, "label-%04d", l);
984 
985             auto t1 = bytecodeRegs[insn.src1];
986             auto t2 = bytecodeRegs[insn.src2];
987             auto t3 = bytecodeRegs[insn.dst];
988 
989             XmmReg r1, r2, r3, r4, one;
990             Reg a;
991             mov(a, 2);
992             VEX1(movaps, r1, t1.first);
993             VEX1(movaps, r2, t1.second);
994             VEX1(movaps, r3, t2.first);
995             VEX1(movaps, r4, t2.second);
996             VEX1(movaps, one, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
997 
998             L(label);
999 
1000             log_(r1, zero, one, constants);
1001             VEX2(mulps, r1, r1, r3);
1002             exp_(r1, one, constants);
1003 
1004             VEX1(movaps, t3.first, t3.second);
1005             VEX1(movaps, t3.second, r1);
1006             VEX1(movaps, r1, r2);
1007             VEX1(movaps, r3, r4);
1008 
1009             jit::sub(a, 1);
1010             jnz(label);
1011         });
1012     }
1013 
sincos_(bool issin,XmmReg y,XmmReg x,Reg constants)1014     void sincos_(bool issin, XmmReg y, XmmReg x, Reg constants)
1015     {
1016         XmmReg t1, sign, t2, t3, t4;
1017         // Remove sign
1018         VEX1(movaps, t1, xmmword_ptr[constants + ConstantIndex::absmask * 16]);
1019         if (issin) {
1020             VEX1(movaps, sign, t1);
1021             VEX2(andnps, sign, sign, x);
1022         } else {
1023             VEX2(pxor, sign, sign, sign);
1024         }
1025         VEX2(andps, t1, t1, x);
1026         // Range reduction
1027         VEX1(movaps, t3, xmmword_ptr[constants + ConstantIndex::float_rintf * 16]);
1028         VEX2(mulps, t2, t1, xmmword_ptr[constants + ConstantIndex::float_invpi * 16]);
1029         VEX2(addps, t2, t2, t3);
1030         VEX1IMM(pslld, t4, t2, 31);
1031         VEX2(xorps, sign, sign, t4);
1032         VEX2(subps, t2, t2, t3);
1033         if (cpuFeatures.fma3) {
1034             vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi1 * 16]);
1035             vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi2 * 16]);
1036             vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi3 * 16]);
1037             vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi4 * 16]);
1038         } else {
1039             VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi1 * 16]);
1040             VEX2(subps, t1, t1, t4);
1041             VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi2 * 16]);
1042             VEX2(subps, t1, t1, t4);
1043             VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi3 * 16]);
1044             VEX2(subps, t1, t1, t4);
1045             VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi4 * 16]);
1046             VEX2(subps, t1, t1, t4);
1047         }
1048         if (issin) {
1049             // Evaluate minimax polynomial for sin(x) in [-pi/2, pi/2] interval
1050             // Y <- X + X * X^2 * (C3 + X^2 * (C5 + X^2 * (C7 + X^2 * C9)))
1051             VEX2(mulps, t2, t1, t1);
1052             if (cpuFeatures.fma3) {
1053                 vmovaps(t3, xmmword_ptr[constants + ConstantIndex::float_sinC7 * 16]);
1054                 vfmadd231ps(t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC9 * 16]);
1055                 vfmadd213ps(t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC5 * 16]);
1056                 vfmadd213ps(t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC3 * 16]);
1057                 VEX2(mulps, t3, t3, t2);
1058                 vfmadd231ps(t1, t1, t3);
1059             } else {
1060                 VEX2(mulps, t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC9 * 16]);
1061                 VEX2(addps, t3, t3, xmmword_ptr[constants + ConstantIndex::float_sinC7 * 16]);
1062                 VEX2(mulps, t3, t3, t2);
1063                 VEX2(addps, t3, t3, xmmword_ptr[constants + ConstantIndex::float_sinC5 * 16]);
1064                 VEX2(mulps, t3, t3, t2);
1065                 VEX2(addps, t3, t3, xmmword_ptr[constants + ConstantIndex::float_sinC3 * 16]);
1066                 VEX2(mulps, t3, t3, t2);
1067                 VEX2(mulps, t3, t3, t1);
1068                 VEX2(addps, t1, t1, t3);
1069             }
1070         } else {
1071             // Evaluate minimax polynomial for cos(x) in [-pi/2, pi/2] interval
1072             // Y <- 1 + X^2 * (C2 + X^2 * (C4 + X^2 * (C6 + X^2 * C8)))
1073             VEX2(mulps, t2, t1, t1);
1074             if (cpuFeatures.fma3) {
1075                 vmovaps(t1, xmmword_ptr[constants + ConstantIndex::float_cosC6 * 16]);
1076                 vfmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC8 * 16]);
1077                 vfmadd213ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC4 * 16]);
1078                 vfmadd213ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC2 * 16]);
1079                 vfmadd213ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
1080             } else {
1081                 VEX2(mulps, t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC8 * 16]);
1082                 VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_cosC6 * 16]);
1083                 VEX2(mulps, t1, t1, t2);
1084                 VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_cosC4 * 16]);
1085                 VEX2(mulps, t1, t1, t2);
1086                 VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_cosC2 * 16]);
1087                 VEX2(mulps, t1, t1, t2);
1088                 VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
1089             }
1090         }
1091         // Apply sign
1092         VEX2(xorps, y, t1, sign);
1093     }
1094 
sincos(bool issin,const ExprInstruction & insn)1095     void sincos(bool issin, const ExprInstruction &insn)
1096     {
1097         int l = curLabel++;
1098 
1099         deferred.push_back([this, issin, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
1100         {
1101             char label[] = "label-0000";
1102             sprintf(label, "label-%04d", l);
1103 
1104             auto t1 = bytecodeRegs[insn.src1];
1105             auto t3 = bytecodeRegs[insn.dst];
1106 
1107             XmmReg r1, r2;
1108             Reg a;
1109             mov(a, 2);
1110             VEX1(movaps, r1, t1.first);
1111             VEX1(movaps, r2, t1.second);
1112 
1113             L(label);
1114 
1115             sincos_(issin, r1, r1, constants);
1116             VEX1(movaps, t3.first, t3.second);
1117             VEX1(movaps, t3.second, r1);
1118             VEX1(movaps, r1, r2);
1119 
1120             jit::sub(a, 1);
1121             jnz(label);
1122         });
1123     }
1124 
sin(const ExprInstruction & insn)1125     void sin(const ExprInstruction &insn) override
1126     {
1127         sincos(true, insn);
1128     }
1129 
cos(const ExprInstruction & insn)1130     void cos(const ExprInstruction &insn) override
1131     {
1132         sincos(false, insn);
1133     }
1134 
main(Reg regptrs,Reg regoffs,Reg niter)1135     void main(Reg regptrs, Reg regoffs, Reg niter)
1136     {
1137         std::unordered_map<int, std::pair<XmmReg, XmmReg>> bytecodeRegs;
1138         XmmReg zero;
1139         VEX2(pxor, zero, zero, zero);
1140         Reg constants;
1141         mov(constants, (uintptr_t)constData);
1142 
1143         L("wloop");
1144 
1145         for (const auto &f : deferred) {
1146             f(regptrs, zero, constants, bytecodeRegs);
1147         }
1148 
1149 #if UINTPTR_MAX > UINT32_MAX
1150         for (int i = 0; i < numInputs / 2 + 1; i++) {
1151             XmmReg r1, r2;
1152             VEX1(movdqu, r1, xmmword_ptr[regptrs + 16 * i]);
1153             VEX1(movdqu, r2, xmmword_ptr[regoffs + 16 * i]);
1154             VEX2(paddq, r1, r1, r2);
1155             VEX1(movdqu, xmmword_ptr[regptrs + 16 * i], r1);
1156         }
1157 #else
1158         for (int i = 0; i < numInputs / 4 + 1; i++) {
1159             XmmReg r1, r2;
1160             VEX1(movdqu, r1, xmmword_ptr[regptrs + 16 * i]);
1161             VEX1(movdqu, r2, xmmword_ptr[regoffs + 16 * i]);
1162             VEX2(paddd, r1, r1, r2);
1163             VEX1(movdqu, xmmword_ptr[regptrs + 16 * i], r1);
1164         }
1165 #endif
1166 
1167         jit::sub(niter, 1);
1168         jnz("wloop");
1169     }
1170 
1171 public:
ExprCompiler128(int numInputs)1172     explicit ExprCompiler128(int numInputs) : cpuFeatures(*getCPUFeatures()), numInputs(numInputs), curLabel() {}
1173 
getCode()1174     std::pair<ExprData::ProcessLineProc, size_t> getCode() override
1175     {
1176         size_t size;
1177         if (jit::GetCode() && (size = GetCodeSize())) {
1178 #ifdef VS_TARGET_OS_WINDOWS
1179             void *ptr = VirtualAlloc(nullptr, size, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
1180 #else
1181             void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_ANON | MAP_PRIVATE, 0, 0);
1182 #endif
1183             memcpy(ptr, jit::GetCode(), size);
1184             return {reinterpret_cast<ExprData::ProcessLineProc>(ptr), size};
1185         }
1186         return {nullptr, 0};
1187     }
1188 #undef VEX2IMM
1189 #undef VEX2
1190 #undef VEX1IMM
1191 #undef VEX1
1192 #undef EMIT
1193 };
1194 
// Out-of-class definition of the static constexpr constant table. Pre-C++17 it is
// required because main() odr-uses the table (its address is embedded in the code).
constexpr ExprUnion ExprCompiler128::constData alignas(16)[53][4];
1196 
1197 class ExprCompiler256 : public ExprCompiler, private jitasm::function<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t> {
    // Short alias for the jitasm base, used to reach base-class members whose names
    // collide with this class's ExprCompiler overrides (e.g. sub).
    typedef jitasm::function<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t> jit;
    // The jitasm function templates generate code by calling back into this class
    // (see the note on `deferred` about main()), so they need private access.
    friend struct jitasm::function<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t>;
    friend struct jitasm::function_cdecl<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t>;
1201 
#define SPLAT(x) { (x), (x), (x), (x), (x), (x), (x), (x) }
    // Constant table read by the generated code as ymmword operands at
    // constants + index * 32; each row repeats one 32-bit value across eight lanes.
    // Rows are addressed via ConstantIndex, so the order here must match it exactly.
    // Integer literals are stored as raw 32-bit patterns: either integer masks or
    // constants (absmask, x7F, ...) or precomputed float encodings (float_invpi,
    // float_pi1, ...). Float literals are stored as float values.
    static constexpr ExprUnion constData alignas(32)[53][8] = {
        SPLAT(0x7FFFFFFF), // absmask
        SPLAT(0x80000000), // negmask
        SPLAT(0x7F), // x7F
        SPLAT(0x00800000), // min_norm_pos
        SPLAT(~0x7F800000), // inv_mant_mask
        SPLAT(1.0f), // float_one
        SPLAT(0.5f), // float_half
        SPLAT(255.0f), // float_255
        SPLAT(511.0f), // float_511
        SPLAT(1023.0f), // float_1023
        SPLAT(2047.0f), // float_2047
        SPLAT(4095.0f), // float_4095
        SPLAT(8191.0f), // float_8191
        SPLAT(16383.0f), // float_16383
        SPLAT(32767.0f), // float_32767
        SPLAT(65535.0f), // float_65535
        SPLAT(static_cast<int32_t>(0x80008000)), // i16min_epi16
        SPLAT(static_cast<int32_t>(0xFFFF8000)), // i16min_epi32
        SPLAT(88.3762626647949f), // exp_hi
        SPLAT(-88.3762626647949f), // exp_lo
        SPLAT(1.44269504088896341f), // log2e
        SPLAT(0.693359375f), // exp_c1
        SPLAT(-2.12194440e-4f), // exp_c2
        SPLAT(1.9875691500E-4f), // exp_p0
        SPLAT(1.3981999507E-3f), // exp_p1
        SPLAT(8.3334519073E-3f), // exp_p2
        SPLAT(4.1665795894E-2f), // exp_p3
        SPLAT(1.6666665459E-1f), // exp_p4
        SPLAT(5.0000001201E-1f), // exp_p5
        SPLAT(0.707106781186547524f), // sqrt_1_2
        SPLAT(7.0376836292E-2f), // log_p0
        SPLAT(-1.1514610310E-1f), // log_p1
        SPLAT(1.1676998740E-1f), // log_p2
        SPLAT(-1.2420140846E-1f), // log_p3
        SPLAT(+1.4249322787E-1f), // log_p4
        SPLAT(-1.6668057665E-1f), // log_p5
        SPLAT(+2.0000714765E-1f), // log_p6
        SPLAT(-2.4999993993E-1f), // log_p7
        SPLAT(+3.3333331174E-1f), // log_p8
        SPLAT(0x3ea2f983), // float_invpi, 1/pi
        SPLAT(0x4b400000), // float_rintf, 1.5 * 2^23 (round-to-integer magic number)
        SPLAT(0x40490000), // float_pi1, 3.140625 (leading bits of pi)
        SPLAT(0x3a7da000), // float_pi2
        SPLAT(0x34222000), // float_pi3
        SPLAT(0x2cb4611a), // float_pi4
        SPLAT(0xbe2aaaa6), // float_sinC3
        SPLAT(0x3c08876a), // float_sinC5
        SPLAT(0xb94fb7ff), // float_sinC7
        SPLAT(0x362edef8), // float_sinC9
        SPLAT(static_cast<int32_t>(0xBEFFFFE2)), // float_cosC2
        SPLAT(0x3D2AA73C), // float_cosC4
        SPLAT(static_cast<int32_t>(0XBAB58D50)), // float_cosC6
        SPLAT(0x37C1AD76), // float_cosC8
    };
1258 
    // Row indices into constData; the generated code addresses row i at byte
    // offset i * 32 from the `constants` register. The float_255..float_65535 rows
    // are consecutive, so store16 can select 2^depth - 1 as float_255 + depth - 8.
    struct ConstantIndex {
        static constexpr int absmask = 0;
        static constexpr int negmask = 1;
        static constexpr int x7F = 2;
        static constexpr int min_norm_pos = 3;
        static constexpr int inv_mant_mask = 4;
        static constexpr int float_one = 5;
        static constexpr int float_half = 6;
        static constexpr int float_255 = 7;
        static constexpr int float_511 = 8;
        static constexpr int float_1023 = 9;
        static constexpr int float_2047 = 10;
        static constexpr int float_4095 = 11;
        static constexpr int float_8191 = 12;
        static constexpr int float_16383 = 13;
        static constexpr int float_32767 = 14;
        static constexpr int float_65535 = 15;
        static constexpr int i16min_epi16 = 16;
        static constexpr int i16min_epi32 = 17;
        static constexpr int exp_hi = 18;
        static constexpr int exp_lo = 19;
        static constexpr int log2e = 20;
        static constexpr int exp_c1 = 21;
        static constexpr int exp_c2 = 22;
        static constexpr int exp_p0 = 23;
        static constexpr int exp_p1 = 24;
        static constexpr int exp_p2 = 25;
        static constexpr int exp_p3 = 26;
        static constexpr int exp_p4 = 27;
        static constexpr int exp_p5 = 28;
        static constexpr int sqrt_1_2 = 29;
        static constexpr int log_p0 = 30;
        static constexpr int log_p1 = 31;
        static constexpr int log_p2 = 32;
        static constexpr int log_p3 = 33;
        static constexpr int log_p4 = 34;
        static constexpr int log_p5 = 35;
        static constexpr int log_p6 = 36;
        static constexpr int log_p7 = 37;
        static constexpr int log_p8 = 38;
        // The log ln2 split reuses the exp rows that hold the same two values.
        static constexpr int log_q1 = exp_c2;
        static constexpr int log_q2 = exp_c1;
        static constexpr int float_invpi = 39;
        static constexpr int float_rintf = 40;
        static constexpr int float_pi1 = 41;
        static constexpr int float_pi2 = float_pi1 + 1;
        static constexpr int float_pi3 = float_pi1 + 2;
        static constexpr int float_pi4 = float_pi1 + 3;
        static constexpr int float_sinC3 = 45;
        static constexpr int float_sinC5 = float_sinC3 + 1;
        static constexpr int float_sinC7 = float_sinC3 + 2;
        static constexpr int float_sinC9 = float_sinC3 + 3;
        static constexpr int float_cosC2 = 49;
        static constexpr int float_cosC4 = float_cosC2 + 1;
        static constexpr int float_cosC6 = float_cosC2 + 2;
        static constexpr int float_cosC8 = float_cosC2 + 3;
    };
#undef SPLAT
1317 
    // JitASM compiles everything from main(), so record the operations for later.
    // Each entry emits the code for one bytecode instruction when replayed; it
    // receives the plane-pointer register, the zero register, the constant-table
    // base register and the bytecode-register -> YMM register map.
    std::vector<std::function<void(Reg, YmmReg, Reg, std::unordered_map<int, YmmReg> &)>> deferred;

    CPUFeatures cpuFeatures;
    // Presumably the number of input planes (TODO confirm against the constructor).
    int numInputs;
    // Incrementing counter; presumably used to make generated jump labels unique.
    int curLabel;

    // Lambda header shared by every recorded emitter: captures the compiler and a
    // copy of the instruction, with the parameter list stored in `deferred`.
#define EMIT() [this, insn](Reg regptrs, YmmReg zero, Reg constants, std::unordered_map<int, YmmReg> &bytecodeRegs)
1326 
load8(const ExprInstruction & insn)1327     void load8(const ExprInstruction &insn) override
1328     {
1329         deferred.push_back(EMIT()
1330         {
1331             auto t1 = bytecodeRegs[insn.dst];
1332             Reg a;
1333             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1334             vpmovzxbd(t1, mmword_ptr[a]);
1335             vcvtdq2ps(t1, t1);
1336         });
1337     }
1338 
load16(const ExprInstruction & insn)1339     void load16(const ExprInstruction &insn) override
1340     {
1341         deferred.push_back(EMIT()
1342         {
1343             auto t1 = bytecodeRegs[insn.dst];
1344             Reg a;
1345             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1346             vpmovzxwd(t1, xmmword_ptr[a]);
1347             vcvtdq2ps(t1, t1);
1348         });
1349     }
1350 
loadF16(const ExprInstruction & insn)1351     void loadF16(const ExprInstruction &insn) override
1352     {
1353         deferred.push_back(EMIT()
1354         {
1355             auto t1 = bytecodeRegs[insn.dst];
1356             Reg a;
1357             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1358             vcvtph2ps(t1, xmmword_ptr[a]);
1359         });
1360     }
1361 
loadF32(const ExprInstruction & insn)1362     void loadF32(const ExprInstruction &insn) override
1363     {
1364         deferred.push_back(EMIT()
1365         {
1366             auto t1 = bytecodeRegs[insn.dst];
1367             Reg a;
1368             mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1369             vmovaps(t1, ymmword_ptr[a]);
1370         });
1371     }
1372 
loadConst(const ExprInstruction & insn)1373     void loadConst(const ExprInstruction &insn) override
1374     {
1375         deferred.push_back(EMIT()
1376         {
1377             auto t1 = bytecodeRegs[insn.dst];
1378 
1379             if (insn.op.imm.f == 0.0f) {
1380                 vmovaps(t1, zero);
1381                 return;
1382             }
1383 
1384             XmmReg r1;
1385             Reg32 a;
1386             mov(a, insn.op.imm.u);
1387             vmovd(r1, a);
1388             vbroadcastss(t1, r1);
1389         });
1390     }
1391 
store8(const ExprInstruction & insn)1392     void store8(const ExprInstruction &insn) override
1393     {
1394         deferred.push_back(EMIT()
1395         {
1396             auto t1 = bytecodeRegs[insn.src1];
1397             YmmReg r1;
1398             Reg a;
1399             vminps(r1, t1, ymmword_ptr[constants + ConstantIndex::float_255 * 32]);
1400             vcvtps2dq(r1, r1);
1401             vpackssdw(r1, r1, r1);
1402             vpermq(r1, r1, 0x08);
1403             vpackuswb(r1, r1, zero);
1404             mov(a, ptr[regptrs]);
1405             vmovq(qword_ptr[a], r1.as128());
1406         });
1407     }
1408 
store16(const ExprInstruction & insn)1409     void store16(const ExprInstruction &insn) override
1410     {
1411         deferred.push_back(EMIT()
1412         {
1413             int depth = insn.op.imm.u;
1414             auto t1 = bytecodeRegs[insn.src1];
1415             YmmReg r1, limit;
1416             Reg a;
1417             vminps(r1, t1, ymmword_ptr[constants + (ConstantIndex::float_255 + depth - 8) * 32]);
1418             vcvtps2dq(r1, r1);
1419             vpackusdw(r1, r1, r1);
1420             vpermq(r1, r1, 0x08);
1421             mov(a, ptr[regptrs]);
1422             vmovaps(xmmword_ptr[a], r1.as128());
1423         });
1424     }
1425 
storeF16(const ExprInstruction & insn)1426     void storeF16(const ExprInstruction &insn) override
1427     {
1428         deferred.push_back(EMIT()
1429         {
1430             auto t1 = bytecodeRegs[insn.src1];
1431             Reg a;
1432             mov(a, ptr[regptrs]);
1433             vcvtps2ph(xmmword_ptr[a], t1, 0);
1434         });
1435     }
1436 
storeF32(const ExprInstruction & insn)1437     void storeF32(const ExprInstruction &insn) override
1438     {
1439         deferred.push_back(EMIT()
1440         {
1441             auto t1 = bytecodeRegs[insn.src1];
1442             Reg a;
1443             mov(a, ptr[regptrs]);
1444             vmovaps(ymmword_ptr[a], t1);
1445         });
1446     }
1447 
// Emits dst = op(src1, src2) for a three-operand AVX instruction `op`; only valid
// inside an EMIT() lambda, where insn and bytecodeRegs are in scope.
#define BINARYOP(op) \
do { \
  auto t1 = bytecodeRegs[insn.src1]; \
  auto t2 = bytecodeRegs[insn.src2]; \
  auto t3 = bytecodeRegs[insn.dst]; \
  op(t3, t1, t2); \
} while (0)
add(const ExprInstruction & insn)1455     void add(const ExprInstruction &insn) override
1456     {
1457         deferred.push_back(EMIT()
1458         {
1459             BINARYOP(vaddps);
1460         });
1461     }
1462 
sub(const ExprInstruction & insn)1463     void sub(const ExprInstruction &insn) override
1464     {
1465         deferred.push_back(EMIT()
1466         {
1467             BINARYOP(vsubps);
1468         });
1469     }
1470 
mul(const ExprInstruction & insn)1471     void mul(const ExprInstruction &insn) override
1472     {
1473         deferred.push_back(EMIT()
1474         {
1475             BINARYOP(vmulps);
1476         });
1477     }
1478 
div(const ExprInstruction & insn)1479     void div(const ExprInstruction &insn) override
1480     {
1481         deferred.push_back(EMIT()
1482         {
1483             BINARYOP(vdivps);
1484         });
1485     }
1486 
fma(const ExprInstruction & insn)1487     void fma(const ExprInstruction &insn) override
1488     {
1489         deferred.push_back(EMIT()
1490         {
1491             FMAType type = static_cast<FMAType>(insn.op.imm.u);
1492 
1493             // t1 + t2 * t3
1494             auto t1 = bytecodeRegs[insn.src1];
1495             auto t2 = bytecodeRegs[insn.src2];
1496             auto t3 = bytecodeRegs[insn.src3];
1497             auto t4 = bytecodeRegs[insn.dst];
1498 
1499 #define FMA3(op) \
1500 do { \
1501   if (insn.dst == insn.src1) { \
1502     op##231ps(t1, t2, t3); \
1503   } else if (insn.dst == insn.src2) { \
1504     op##132ps(t2, t1, t3); \
1505   } else if (insn.dst == insn.src3) { \
1506     op##132ps(t3, t1, t2); \
1507   } else { \
1508     vmovaps(t4, t1); \
1509     op##231ps(t4, t2, t3); \
1510   } \
1511 } while (0)
1512             switch (type) {
1513             case FMAType::FMADD: FMA3(vfmadd); break;
1514             case FMAType::FMSUB: FMA3(vfmsub); break;
1515             case FMAType::FNMADD: FMA3(vfnmadd); break;
1516             case FMAType::FNMSUB: FMA3(vfnmsub); break;
1517             }
1518 #undef FMA3
1519         });
1520     }
1521 
max(const ExprInstruction & insn)1522     void max(const ExprInstruction &insn) override
1523     {
1524         deferred.push_back(EMIT()
1525         {
1526             BINARYOP(vmaxps);
1527         });
1528     }
1529 
min(const ExprInstruction & insn)1530     void min(const ExprInstruction &insn) override
1531     {
1532         deferred.push_back(EMIT()
1533         {
1534             BINARYOP(vminps);
1535         });
1536     }
1537 #undef BINARYOP
1538 
sqrt(const ExprInstruction & insn)1539     void sqrt(const ExprInstruction &insn) override
1540     {
1541         deferred.push_back(EMIT()
1542         {
1543             auto t1 = bytecodeRegs[insn.src1];
1544             auto t2 = bytecodeRegs[insn.dst];
1545             vmaxps(t2, t1, zero);
1546             vsqrtps(t2, t2);
1547         });
1548     }
1549 
abs(const ExprInstruction & insn)1550     void abs(const ExprInstruction &insn) override
1551     {
1552         deferred.push_back(EMIT()
1553         {
1554             auto t1 = bytecodeRegs[insn.src1];
1555             auto t2 = bytecodeRegs[insn.dst];
1556             vandps(t2, t1, ymmword_ptr[constants + ConstantIndex::absmask * 32]);
1557         });
1558     }
1559 
neg(const ExprInstruction & insn)1560     void neg(const ExprInstruction &insn) override
1561     {
1562         deferred.push_back(EMIT()
1563         {
1564             auto t1 = bytecodeRegs[insn.src1];
1565             auto t2 = bytecodeRegs[insn.dst];
1566             vxorps(t2, t1, ymmword_ptr[constants + ConstantIndex::negmask * 32]);
1567         });
1568     }
1569 
not_(const ExprInstruction & insn)1570     void not_(const ExprInstruction &insn) override
1571     {
1572         deferred.push_back(EMIT()
1573         {
1574             auto t1 = bytecodeRegs[insn.src1];
1575             auto t2 = bytecodeRegs[insn.dst];
1576             YmmReg r1;
1577             vcmpps(t2, t1, zero, _CMP_LE_OS);
1578             vandps(t2, t2, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1579         });
1580     }
1581 
// Emits a logical binary op for use inside an EMIT() lambda. Each operand becomes
// an all-ones mask where it is "true" (> 0 or NaN, via _CMP_NLE_US), the masks are
// combined with `op`, and the result is turned into 1.0f/0.0f. src1 is converted
// into tmp before dst is written, so dst may alias either source.
#define LOGICOP(op) \
do { \
  auto t1 = bytecodeRegs[insn.src1]; \
  auto t2 = bytecodeRegs[insn.src2]; \
  auto t3 = bytecodeRegs[insn.dst]; \
  YmmReg tmp; \
  vcmpps(tmp, t1, zero, _CMP_NLE_US); \
  vcmpps(t3, t2, zero, _CMP_NLE_US); \
  op(t3, t3, tmp); \
  vandps(t3, t3, ymmword_ptr[constants + ConstantIndex::float_one * 32]); \
} while (0)
1593 
and_(const ExprInstruction & insn)1594     void and_(const ExprInstruction &insn) override
1595     {
1596         deferred.push_back(EMIT()
1597         {
1598             LOGICOP(vandps);
1599         });
1600     }
1601 
or_(const ExprInstruction & insn)1602     void or_(const ExprInstruction &insn) override
1603     {
1604         deferred.push_back(EMIT()
1605         {
1606             LOGICOP(vorps);
1607         });
1608     }
1609 
xor_(const ExprInstruction & insn)1610     void xor_(const ExprInstruction &insn) override
1611     {
1612         deferred.push_back(EMIT()
1613         {
1614             LOGICOP(vxorps);
1615         });
1616     }
1617 #undef LOGICOP
1618 
cmp(const ExprInstruction & insn)1619     void cmp(const ExprInstruction &insn) override
1620     {
1621         deferred.push_back(EMIT()
1622         {
1623             auto t1 = bytecodeRegs[insn.src1];
1624             auto t2 = bytecodeRegs[insn.src2];
1625             auto t3 = bytecodeRegs[insn.dst];
1626             vcmpps(t3, t1, t2, insn.op.imm.u);
1627             vandps(t3, t3, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1628         });
1629     }
1630 
ternary(const ExprInstruction & insn)1631     void ternary(const ExprInstruction &insn) override
1632     {
1633         deferred.push_back(EMIT()
1634         {
1635             auto t1 = bytecodeRegs[insn.src1];
1636             auto t2 = bytecodeRegs[insn.src2];
1637             auto t3 = bytecodeRegs[insn.src3];
1638             auto t4 = bytecodeRegs[insn.dst];
1639             YmmReg r1;
1640             vcmpps(r1, t1, zero, _CMP_NLE_US);
1641             vblendvps(t4, t3, t2, r1);
1642         });
1643     }
1644 
    // Emit an in-place 8-lane exp(x) approximation: x is overwritten with the
    // result, and `one` must hold 1.0f in every lane.
    // The structure follows the Cephes/sse_mathfun expf scheme: clamp x, split
    // x = n*ln(2) + r with n = floor(x*log2(e) + 0.5), evaluate a polynomial in
    // r, then scale by 2^n built directly in the exponent field. The meanings of
    // the constants are inferred from their ConstantIndex names — confirm
    // against constData.
    void exp_(YmmReg x, YmmReg one, Reg constants)
    {
        YmmReg fx, emm0, etmp, y, mask, z;
        // Clamp to [exp_lo, exp_hi]. vminps returns its second operand when
        // either input is NaN, so NaN lanes end up as exp_hi here.
        vminps(x, x, ymmword_ptr[constants + ConstantIndex::exp_hi * 32]);
        vmaxps(x, x, ymmword_ptr[constants + ConstantIndex::exp_lo * 32]);
        // fx = x * log2e + 0.5
        vmovaps(fx, ymmword_ptr[constants + ConstantIndex::log2e * 32]);
        vfmadd213ps(fx, x, ymmword_ptr[constants + ConstantIndex::float_half * 32]);
        // fx = floor(fx): truncate toward zero, then subtract 1 in lanes where
        // the truncated value exceeds fx (negative non-integers).
        vcvttps2dq(emm0, fx);
        vcvtdq2ps(etmp, emm0);
        vcmpps(mask, etmp, fx, _CMP_NLE_US);
        vandps(mask, mask, one);
        vsubps(fx, etmp, mask);
        // r = x - n*c1 - n*c2 (c1 + c2 presumably a split ln(2) for accuracy).
        vfnmadd231ps(x, fx, ymmword_ptr[constants + ConstantIndex::exp_c1 * 32]);
        vfnmadd231ps(x, fx, ymmword_ptr[constants + ConstantIndex::exp_c2 * 32]);
        // Horner evaluation: y = ((((p0*r + p1)*r + p2)*r + p3)*r + p4)*r + p5,
        // then y = y*r^2 + r + 1.
        vmulps(z, x, x);
        vmovaps(y, ymmword_ptr[constants + ConstantIndex::exp_p0 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p1 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p2 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p3 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p4 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p5 * 32]);
        vfmadd213ps(y, z, x);
        vaddps(y, y, one);
        // Build 2^n as a float bit pattern: (n + 0x7F) << 23 puts the biased
        // exponent into bits 23..30 (x7F presumably holds 127 in each lane).
        vcvttps2dq(emm0, fx);
        vpaddd(emm0, emm0, ymmword_ptr[constants + ConstantIndex::x7F * 32]);
        vpslld(emm0, emm0, 23);
        // result = poly(r) * 2^n
        vmulps(x, y, emm0);
    }
1673 
    // Emit an in-place 8-lane natural log approximation. x is overwritten with
    // the result; `zero` must hold 0.0f and `one` 1.0f in every lane.
    // Lanes where x <= 0 or x is NaN are forced to an all-ones bit pattern (a
    // NaN) at the end. Otherwise x is split into an exponent e and a mantissa m
    // in [0.5, 1), and log(x) is rebuilt as polynomial(m - 1) + e*ln(2). This
    // follows the Cephes/sse_mathfun logf scheme; the meanings of the constants
    // are inferred from their ConstantIndex names — confirm against constData.
    void log_(YmmReg x, YmmReg zero, YmmReg one, Reg constants)
    {
        YmmReg emm0, invalid_mask, mask, y, etmp, z;
        // invalid_mask = !(0 < x): set for x <= 0 and for NaN (unordered).
        vcmpps(invalid_mask, zero, x, _CMP_NLT_US);
        // Keep x at or above min_norm_pos (presumably the smallest normal float)
        // so the exponent/mantissa split below sees a positive normal number.
        vmaxps(x, x, ymmword_ptr[constants + ConstantIndex::min_norm_pos * 32]);
        // Raw biased exponent bits.
        vpsrld(emm0, x, 23);
        // Replace the exponent so the mantissa lands in [0.5, 1).
        vandps(x, x, ymmword_ptr[constants + ConstantIndex::inv_mant_mask * 32]);
        vorps(x, x, ymmword_ptr[constants + ConstantIndex::float_half * 32]);
        // e = float(biased - 0x7F) + 1, matching the [0.5, 1) mantissa range.
        vpsubd(emm0, emm0, ymmword_ptr[constants + ConstantIndex::x7F * 32]);
        vcvtdq2ps(emm0, emm0);
        vaddps(emm0, emm0, one);
        // If m < sqrt(1/2): e -= 1 and x = 2m - 1; otherwise x = m - 1.
        // This keeps the polynomial argument centred on zero.
        vcmpps(mask, x, ymmword_ptr[constants + ConstantIndex::sqrt_1_2 * 32], _CMP_LT_OS);
        vandps(etmp, x, mask);
        vsubps(x, x, one);
        vandps(mask, mask, one);
        vsubps(emm0, emm0, mask);
        vaddps(x, x, etmp);
        // Horner evaluation of p0..p8 in x, then y *= x^3.
        vmulps(z, x, x);
        vmovaps(y, ymmword_ptr[constants + ConstantIndex::log_p0 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p1 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p2 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p3 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p4 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p5 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p6 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p7 * 32]);
        vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p8 * 32]);
        vmulps(y, y, x);
        vmulps(y, y, z);
        // Recombine: y += e*q1; y -= x^2/2; x += y; x += e*q2. Here q1 + q2 is
        // presumably a split ln(2).
        vfmadd231ps(y, emm0, ymmword_ptr[constants + ConstantIndex::log_q1 * 32]);
        vfnmadd231ps(y, z, ymmword_ptr[constants + ConstantIndex::float_half * 32]);
        vaddps(x, x, y);
        vfmadd231ps(x, emm0, ymmword_ptr[constants + ConstantIndex::log_q2 * 32]);
        // Invalid lanes become all-ones, which is a NaN bit pattern.
        vorps(x, x, invalid_mask);
    }
1709 
exp(const ExprInstruction & insn)1710     void exp(const ExprInstruction &insn) override
1711     {
1712         deferred.push_back(EMIT()
1713         {
1714             auto t1 = bytecodeRegs[insn.src1];
1715             auto t2 = bytecodeRegs[insn.dst];
1716             YmmReg one;
1717             vmovaps(one, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1718             vmovaps(t1, t2);
1719             exp_(t1, one, constants);
1720         });
1721     }
1722 
log(const ExprInstruction & insn)1723     void log(const ExprInstruction &insn) override
1724     {
1725         deferred.push_back(EMIT()
1726         {
1727             auto t1 = bytecodeRegs[insn.src1];
1728             auto t2 = bytecodeRegs[insn.dst];
1729             YmmReg one;
1730             vmovaps(one, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1731             vmovaps(t1, t2);
1732             log_(t1, zero, one, constants);
1733         });
1734     }
1735 
pow(const ExprInstruction & insn)1736     void pow(const ExprInstruction &insn) override
1737     {
1738         deferred.push_back(EMIT()
1739         {
1740             auto t1 = bytecodeRegs[insn.src1];
1741             auto t2 = bytecodeRegs[insn.src2];
1742             auto t3 = bytecodeRegs[insn.dst];
1743 
1744             YmmReg r1, one;
1745             vmovaps(one, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1746             vmovaps(r1, t1);
1747             log_(r1, zero, one, constants);
1748             vmulps(r1, r1, t2);
1749             exp_(r1, one, constants);
1750             vmovaps(t3, r1);
1751         });
1752     }
1753 
    // Emit dst = sin(src1) when issin is true, otherwise dst = cos(src1), for 8
    // float lanes.
    // |x| is reduced by the nearest multiple k of pi. The rounding uses the
    // float_rintf add/subtract trick: float_rintf is presumably a large magic
    // constant such as 1.5*2^23, so adding it leaves k's low bit in the lowest
    // mantissa bit. pi is split across float_pi1..pi4 for precision. The parity
    // of k flips the result's sign. For sin, the input's own sign is folded in
    // as well (sin is odd); cos is even, so its sign starts at zero.
    // bytecodeRegs is looked up with operator[], which inserts registers that
    // are not present yet.
    void sincos_(bool issin, const ExprInstruction &insn, Reg constants, std::unordered_map<int, YmmReg> &bytecodeRegs)
    {
        auto x = bytecodeRegs[insn.src1];
        auto y = bytecodeRegs[insn.dst];
        YmmReg t1, sign, t2, t3, t4;
        // Remove sign
        // (sin: sign = ~absmask & x, i.e. just the input's sign bit; cos: sign = 0.)
        vmovaps(t1, ymmword_ptr[constants + ConstantIndex::absmask * 32]);
        if (issin) {
            vmovaps(sign, t1);
            vandnps(sign, sign, x);
        } else {
            vxorps(sign, sign, sign);
        }
        // t1 = |x|
        vandps(t1, t1, x);
        // Range reduction
        // t2 = |x| / pi + rintf; its low bit (shifted into the sign position)
        // records the parity of k and is XOR-ed into the result's sign.
        vmovaps(t3, ymmword_ptr[constants + ConstantIndex::float_rintf * 32]);
        vmulps(t2, t1, ymmword_ptr[constants + ConstantIndex::float_invpi * 32]);
        vaddps(t2, t2, t3);
        vpslld(t4, t2, 31);
        vxorps(sign, sign, t4);
        // t2 = k as a float; then t1 = |x| - k*pi, using four pi pieces.
        vsubps(t2, t2, t3);
        vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi1 * 32]);
        vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi2 * 32]);
        vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi3 * 32]);
        vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi4 * 32]);
        if (issin) {
            // Evaluate minimax polynomial for sin(x) in [-pi/2, pi/2] interval
            // Y <- X + X * X^2 * (C3 + X^2 * (C5 + X^2 * (C7 + X^2 * C9)))
            vmulps(t2, t1, t1);
            vmovaps(t3, ymmword_ptr[constants + ConstantIndex::float_sinC7 * 32]);
            vfmadd231ps(t3, t2, ymmword_ptr[constants + ConstantIndex::float_sinC9 * 32]);
            vfmadd213ps(t3, t2, ymmword_ptr[constants + ConstantIndex::float_sinC5 * 32]);
            vfmadd213ps(t3, t2, ymmword_ptr[constants + ConstantIndex::float_sinC3 * 32]);
            vmulps(t3, t3, t2);
            vfmadd231ps(t1, t1, t3);
        } else {
            // Evaluate minimax polynomial for cos(x) in [-pi/2, pi/2] interval
            // Y <- 1 + X^2 * (C2 + X^2 * (C4 + X^2 * (C6 + X^2 * C8)))
            vmulps(t2, t1, t1);
            vmovaps(t1, ymmword_ptr[constants + ConstantIndex::float_cosC6 * 32]);
            vfmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_cosC8 * 32]);
            vfmadd213ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_cosC4 * 32]);
            vfmadd213ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_cosC2 * 32]);
            vfmadd213ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
        }
        // Apply sign
        vxorps(y, t1, sign);
    }
1802 
sin(const ExprInstruction & insn)1803     void sin(const ExprInstruction &insn) override
1804     {
1805         deferred.push_back(EMIT()
1806         {
1807             sincos_(true, insn, constants, bytecodeRegs);
1808         });
1809     }
1810 
cos(const ExprInstruction & insn)1811     void cos(const ExprInstruction &insn) override
1812     {
1813         deferred.push_back(EMIT()
1814         {
1815             sincos_(false, insn, constants, bytecodeRegs);
1816         });
1817     }
1818 
    // Emit the body of the generated line loop.
    // regptrs: array of pointers that the emitted code loads from and stores to.
    // regoffs: a parallel array of per-pointer increments added after each
    //          iteration.
    // niter:   iteration count. The loop is bottom-tested (sub + jnz), so it
    //          assumes niter >= 1 on entry; 0 would wrap — NOTE(review): confirm
    //          callers guarantee this.
    void main(Reg regptrs, Reg regoffs, Reg niter)
    {
        // Bytecode register index -> YMM register. Filled lazily by the deferred
        // emitters through operator[].
        std::unordered_map<int, YmmReg> bytecodeRegs;
        YmmReg zero;
        vpxor(zero, zero, zero);
        // Base address of the shared constant table (see constData).
        Reg constants;
        mov(constants, (uintptr_t)constData);

        L("wloop");

        // Emission time: every recorded instruction emitter appends its
        // instructions to the loop body, in bytecode order.
        for (const auto &f : deferred) {
            f(regptrs, zero, constants, bytecodeRegs);
        }

        // Advance every pointer by its offset. One YMM register holds 4 64-bit
        // (or 8 32-bit) pointers, and numInputs / 4 + 1 (resp. / 8 + 1) equals
        // ceil((numInputs + 1) / 4) (resp. / 8). That covers numInputs + 1
        // pointers, presumably the destination plus each input — confirm
        // against the caller. Whole vectors are read and written, so both arrays
        // must be padded to that size.
#if UINTPTR_MAX > UINT32_MAX
        for (int i = 0; i < numInputs / 4 + 1; i++) {
            YmmReg r1, r2;
            vmovdqu(r1, ymmword_ptr[regptrs + 32 * i]);
            vmovdqu(r2, ymmword_ptr[regoffs + 32 * i]);
            vpaddq(r1, r1, r2);
            vmovdqu(ymmword_ptr[regptrs + 32 * i], r1);
        }
#else
        for (int i = 0; i < numInputs / 8 + 1; i++) {
            YmmReg r1, r2;
            vmovdqu(r1, ymmword_ptr[regptrs + 32 * i]);
            vmovdqu(r2, ymmword_ptr[regoffs + 32 * i]);
            vpaddd(r1, r1, r2);
            vmovdqu(ymmword_ptr[regptrs + 32 * i], r1);
        }
#endif

        // Loop until the iteration counter reaches zero.
        jit::sub(niter, 1);
        jnz("wloop");
    }
1854 
1855 public:
    // Snapshot the detected CPU features and remember how many input clips the
    // expression reads; main() uses numInputs to size the pointer-advance loop.
    explicit ExprCompiler256(int numInputs) : cpuFeatures(*getCPUFeatures()), numInputs(numInputs) {}
1857 
getCode()1858     std::pair<ExprData::ProcessLineProc, size_t> getCode() override
1859     {
1860         size_t size;
1861         if (jit::GetCode(true) && (size = GetCodeSize())) {
1862 #ifdef VS_TARGET_OS_WINDOWS
1863             void *ptr = VirtualAlloc(nullptr, size, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
1864 #else
1865             void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_ANON | MAP_PRIVATE, 0, 0);
1866 #endif
1867             memcpy(ptr, jit::GetCode(true), size);
1868             return {reinterpret_cast<ExprData::ProcessLineProc>(ptr), size};
1869         }
1870         return {nullptr, 0};
1871     }
1872 #undef EMIT
1873 };
1874 
// Out-of-class definition of the static constexpr constant table. main() takes
// its address (mov(constants, (uintptr_t)constData)), which ODR-uses the
// member; before C++17 that requires this namespace-scope definition.
constexpr ExprUnion ExprCompiler256::constData alignas(32)[53][8];
1876 
make_compiler(int numInputs,int cpulevel)1877 std::unique_ptr<ExprCompiler> make_compiler(int numInputs, int cpulevel)
1878 {
1879     if (getCPUFeatures()->avx2 && cpulevel >= VS_CPU_LEVEL_AVX2)
1880         return std::unique_ptr<ExprCompiler>(new ExprCompiler256(numInputs));
1881     else
1882         return std::unique_ptr<ExprCompiler>(new ExprCompiler128(numInputs));
1883 }
1884 #endif
1885 
1886 class ExprInterpreter {
1887     const ExprInstruction *bytecode;
1888     size_t numInsns;
1889     std::vector<float> registers;
1890 
1891     template <class T>
clamp_int(float x,int depth=std::numeric_limits<T>::digits)1892     static T clamp_int(float x, int depth = std::numeric_limits<T>::digits)
1893     {
1894         float maxval = static_cast<float>((1U << depth) - 1);
1895         return static_cast<T>(std::lrint(std::min(std::max(x, static_cast<float>(std::numeric_limits<T>::min())), maxval)));
1896     }
1897 
bool2float(bool x)1898     static float bool2float(bool x) { return x ? 1.0f : 0.0f; }
float2bool(float x)1899     static bool float2bool(float x) { return x > 0.0f; }
1900 public:
ExprInterpreter(const ExprInstruction * bytecode,size_t numInsns)1901     ExprInterpreter(const ExprInstruction *bytecode, size_t numInsns) : bytecode(bytecode), numInsns(numInsns)
1902     {
1903         int maxreg = 0;
1904         for (size_t i = 0; i < numInsns; ++i) {
1905             maxreg = std::max(maxreg, bytecode[i].dst);
1906         }
1907         registers.resize(maxreg + 1);
1908     }
1909 
eval(const uint8_t * const * srcp,uint8_t * dstp,int x)1910     void eval(const uint8_t * const *srcp, uint8_t *dstp, int x)
1911     {
1912         for (size_t i = 0; i < numInsns; ++i) {
1913             const ExprInstruction &insn = bytecode[i];
1914 
1915 #define SRC1 registers[insn.src1]
1916 #define SRC2 registers[insn.src2]
1917 #define SRC3 registers[insn.src3]
1918 #define DST registers[insn.dst]
1919             switch (insn.op.type) {
1920             case ExprOpType::MEM_LOAD_U8: DST = reinterpret_cast<const uint8_t *>(srcp[insn.op.imm.u])[x]; break;
1921             case ExprOpType::MEM_LOAD_U16: DST = reinterpret_cast<const uint16_t *>(srcp[insn.op.imm.u])[x]; break;
1922             case ExprOpType::MEM_LOAD_F16: DST = 0; break;
1923             case ExprOpType::MEM_LOAD_F32: DST = reinterpret_cast<const float *>(srcp[insn.op.imm.u])[x]; break;
1924             case ExprOpType::CONSTANT: DST = insn.op.imm.f; break;
1925             case ExprOpType::ADD: DST = SRC1 + SRC2; break;
1926             case ExprOpType::SUB: DST = SRC1 - SRC2; break;
1927             case ExprOpType::MUL: DST = SRC1 * SRC2; break;
1928             case ExprOpType::DIV: DST = SRC1 / SRC2; break;
1929             case ExprOpType::FMA:
1930                 switch (static_cast<FMAType>(insn.op.imm.u)) {
1931                 case FMAType::FMADD: DST = SRC2 * SRC3 + SRC1; break;
1932                 case FMAType::FMSUB: DST = SRC2 * SRC3 - SRC1; break;
1933                 case FMAType::FNMADD: DST = -(SRC2 * SRC3) + SRC1; break;
1934                 case FMAType::FNMSUB: DST = -(SRC2 * SRC3) - SRC1; break;
1935                 };
1936                 break;
1937             case ExprOpType::MAX: DST = std::max(SRC1, SRC2); break;
1938             case ExprOpType::MIN: DST = std::min(SRC1, SRC2); break;
1939             case ExprOpType::EXP: DST = std::exp(SRC1); break;
1940             case ExprOpType::LOG: DST = std::log(SRC1); break;
1941             case ExprOpType::POW: DST = std::pow(SRC1, SRC2); break;
1942             case ExprOpType::SQRT: DST = std::sqrt(SRC1); break;
1943             case ExprOpType::SIN: DST = std::sin(SRC1); break;
1944             case ExprOpType::COS: DST = std::cos(SRC1); break;
1945             case ExprOpType::ABS: DST = std::fabs(SRC1); break;
1946             case ExprOpType::NEG: DST = -SRC1; break;
1947             case ExprOpType::CMP:
1948                 switch (static_cast<ComparisonType>(insn.op.imm.u)) {
1949                 case ComparisonType::EQ: DST = bool2float(SRC1 == SRC2); break;
1950                 case ComparisonType::LT: DST = bool2float(SRC1 < SRC2); break;
1951                 case ComparisonType::LE: DST = bool2float(SRC1 <= SRC2); break;
1952                 case ComparisonType::NEQ: DST = bool2float(SRC1 != SRC2); break;
1953                 case ComparisonType::NLT: DST = bool2float(SRC1 >= SRC2); break;
1954                 case ComparisonType::NLE: DST = bool2float(SRC1 > SRC2); break;
1955                 }
1956                 break;
1957             case ExprOpType::TERNARY: DST = float2bool(SRC1) ? SRC2 : SRC3; break;
1958             case ExprOpType::AND: DST = bool2float((float2bool(SRC1) && float2bool(SRC2))); break;
1959             case ExprOpType::OR:  DST = bool2float((float2bool(SRC1) || float2bool(SRC2))); break;
1960             case ExprOpType::XOR: DST = bool2float((float2bool(SRC1) != float2bool(SRC2))); break;
1961             case ExprOpType::NOT: DST = bool2float(!float2bool(SRC1)); break;
1962             case ExprOpType::MEM_STORE_U8:  reinterpret_cast<uint8_t *>(dstp)[x] = clamp_int<uint8_t>(SRC1); return;
1963             case ExprOpType::MEM_STORE_U16: reinterpret_cast<uint16_t *>(dstp)[x] = clamp_int<uint16_t>(SRC1, insn.op.imm.u); return;
1964             case ExprOpType::MEM_STORE_F16: reinterpret_cast<uint16_t *>(dstp)[x] = 0; return;
1965             case ExprOpType::MEM_STORE_F32: reinterpret_cast<float *>(dstp)[x] = SRC1; return;
1966             default: vsFatal("illegal opcode"); return;
1967             }
1968 #undef DST
1969 #undef SRC3
1970 #undef SRC2
1971 #undef SRC1
1972         }
1973     }
1974 };
1975 
1976 struct ExpressionTreeNode {
1977     ExpressionTreeNode *parent;
1978     ExpressionTreeNode *left;
1979     ExpressionTreeNode *right;
1980     ExprOp op;
1981     int valueNum;
1982 
ExpressionTreeNode__anon2a0524d50111::ExpressionTreeNode1983     explicit ExpressionTreeNode(ExprOp op) : parent(), left(), right(), op(op), valueNum(-1) {}
1984 
setLeft__anon2a0524d50111::ExpressionTreeNode1985     void setLeft(ExpressionTreeNode *node)
1986     {
1987         if (left)
1988             left->parent = nullptr;
1989 
1990         left = node;
1991 
1992         if (left)
1993             left->parent = this;
1994     }
1995 
setRight__anon2a0524d50111::ExpressionTreeNode1996     void setRight(ExpressionTreeNode *node)
1997     {
1998         if (right)
1999             right->parent = nullptr;
2000 
2001         right = node;
2002 
2003         if (right)
2004             right->parent = this;
2005     }
2006 
2007     template <class T>
preorder__anon2a0524d50111::ExpressionTreeNode2008     void preorder(T visitor)
2009     {
2010         if (visitor(*this))
2011             return;
2012 
2013         if (left)
2014             left->preorder(visitor);
2015         if (right)
2016             right->preorder(visitor);
2017     }
2018 
2019     template <class T>
postorder__anon2a0524d50111::ExpressionTreeNode2020     void postorder(T visitor)
2021     {
2022         if (left)
2023             left->postorder(visitor);
2024         if (right)
2025             right->postorder(visitor);
2026         visitor(*this);
2027     }
2028 };
2029 
2030 class ExpressionTree {
2031     std::vector<std::unique_ptr<ExpressionTreeNode>> nodes;
2032     ExpressionTreeNode *root;
2033 public:
ExpressionTree()2034     ExpressionTree() : root() {}
2035 
getRoot()2036     ExpressionTreeNode *getRoot() { return root; }
getRoot() const2037     const ExpressionTreeNode *getRoot() const { return root; }
2038 
setRoot(ExpressionTreeNode * node)2039     void setRoot(ExpressionTreeNode *node) { root = node; }
2040 
makeNode(ExprOp data)2041     ExpressionTreeNode *makeNode(ExprOp data)
2042     {
2043         nodes.push_back(std::unique_ptr<ExpressionTreeNode>(new ExpressionTreeNode(data)));
2044         return nodes.back().get();
2045     }
2046 
clone(const ExpressionTreeNode * node)2047     ExpressionTreeNode *clone(const ExpressionTreeNode *node)
2048     {
2049         if (!node)
2050             return nullptr;
2051 
2052         ExpressionTreeNode *newnode = makeNode(node->op);
2053         newnode->setLeft(clone(node->left));
2054         newnode->setRight(clone(node->right));
2055         return newnode;
2056     }
2057 };
2058 
makeTreeNode(ExprOp data)2059 std::unique_ptr<ExpressionTreeNode> makeTreeNode(ExprOp data)
2060 {
2061     return std::unique_ptr<ExpressionTreeNode>(new ExpressionTreeNode(data));
2062 }
2063 
equalSubTree(const ExpressionTreeNode * lhs,const ExpressionTreeNode * rhs)2064 bool equalSubTree(const ExpressionTreeNode *lhs, const ExpressionTreeNode *rhs)
2065 {
2066     if (lhs->valueNum >= 0 && rhs->valueNum >= 0)
2067         return lhs->valueNum == rhs->valueNum;
2068     if (lhs->op.type != rhs->op.type || lhs->op.imm.u != rhs->op.imm.u)
2069         return false;
2070     if (!!lhs->left != !!rhs->left || !!lhs->right != !!rhs->right)
2071         return false;
2072     if (lhs->left && !equalSubTree(lhs->left, rhs->left))
2073         return false;
2074     if (lhs->right && !equalSubTree(lhs->right, rhs->right))
2075         return false;
2076     return true;
2077 }
2078 
// Split an expression string into whitespace-separated tokens.
// Runs of whitespace (as classified by std::isspace) act as one separator.
// Leading and trailing whitespace never produce empty tokens.
std::vector<std::string> tokenize(const std::string &expr)
{
    std::vector<std::string> tokens;
    auto it = expr.begin();
    auto prev = expr.begin();

    while (it != expr.end()) {
        // std::isspace is undefined for negative values other than EOF, so
        // bytes >= 0x80 (e.g. UTF-8 in user-supplied expressions) must be
        // passed as unsigned char.
        const unsigned char c = static_cast<unsigned char>(*it);

        if (std::isspace(c)) {
            if (it != prev)
                tokens.emplace_back(prev, it);
            prev = it + 1;
        }
        ++it;
    }
    if (prev != expr.end())
        tokens.emplace_back(prev, expr.end());

    return tokens;
}
2100 
decodeToken(const std::string & token)2101 ExprOp decodeToken(const std::string &token)
2102 {
2103     static const std::unordered_map<std::string, ExprOp> simple{
2104         { "+",    { ExprOpType::ADD } },
2105         { "-",    { ExprOpType::SUB } },
2106         { "*",    { ExprOpType::MUL } },
2107         { "/",    { ExprOpType::DIV } } ,
2108         { "sqrt", { ExprOpType::SQRT } },
2109         { "abs",  { ExprOpType::ABS } },
2110         { "max",  { ExprOpType::MAX } },
2111         { "min",  { ExprOpType::MIN } },
2112         { "<",    { ExprOpType::CMP, static_cast<int>(ComparisonType::LT) } },
2113         { ">",    { ExprOpType::CMP, static_cast<int>(ComparisonType::NLE) } },
2114         { "=",    { ExprOpType::CMP, static_cast<int>(ComparisonType::EQ) } },
2115         { ">=",   { ExprOpType::CMP, static_cast<int>(ComparisonType::NLT) } },
2116         { "<=",   { ExprOpType::CMP, static_cast<int>(ComparisonType::LE) } },
2117         { "and",  { ExprOpType::AND } },
2118         { "or",   { ExprOpType::OR } },
2119         { "xor",  { ExprOpType::XOR } },
2120         { "not",  { ExprOpType::NOT } },
2121         { "?",    { ExprOpType::TERNARY } },
2122         { "exp",  { ExprOpType::EXP } },
2123         { "log",  { ExprOpType::LOG } },
2124         { "pow",  { ExprOpType::POW } },
2125         { "sin",  { ExprOpType::SIN } },
2126         { "cos",  { ExprOpType::COS } },
2127         { "dup",  { ExprOpType::DUP, 0 } },
2128         { "swap", { ExprOpType::SWAP, 1 } },
2129     };
2130 
2131     auto it = simple.find(token);
2132     if (it != simple.end()) {
2133         return it->second;
2134     } else if (token.size() == 1 && token[0] >= 'a' && token[0] <= 'z') {
2135         return{ ExprOpType::MEM_LOAD_U8, token[0] >= 'x' ? token[0] - 'x' : token[0] - 'a' + 3 };
2136     } else if (token.substr(0, 3) == "dup" || token.substr(0, 4) == "swap") {
2137         size_t prefix = token[0] == 'd' ? 3 : 4;
2138         size_t count = 0;
2139         int idx = -1;
2140 
2141         try {
2142             idx = std::stoi(token.substr(prefix), &count);
2143         } catch (...) {
2144             // ...
2145         }
2146 
2147         if (idx < 0 || prefix + count != token.size())
2148             throw std::runtime_error("illegal token: " + token);
2149         return{ token[0] == 'd' ? ExprOpType::DUP : ExprOpType::SWAP, idx };
2150     } else {
2151         float f;
2152         std::string s;
2153         std::istringstream numStream(token);
2154         numStream.imbue(std::locale::classic());
2155         if (!(numStream >> f))
2156             throw std::runtime_error("failed to convert '" + token + "' to float");
2157         if (numStream >> s)
2158             throw std::runtime_error("failed to convert '" + token + "' to float, not the whole token could be converted");
2159         return{ ExprOpType::CONSTANT, f };
2160     }
2161 }
2162 
parseExpr(const std::string & expr,const VSVideoInfo * const * vi,int numInputs)2163 ExpressionTree parseExpr(const std::string &expr, const VSVideoInfo * const *vi, int numInputs)
2164 {
2165     constexpr unsigned char numOperands[] = {
2166         0, // MEM_LOAD_U8
2167         0, // MEM_LOAD_U16
2168         0, // MEM_LOAD_F16
2169         0, // MEM_LOAD_F32
2170         0, // CONSTANT
2171         0, // MEM_STORE_U8
2172         0, // MEM_STORE_U16
2173         0, // MEM_STORE_F16
2174         0, // MEM_STORE_F32
2175         2, // ADD
2176         2, // SUB
2177         2, // MUL
2178         2, // DIV
2179         3, // FMA
2180         1, // SQRT
2181         1, // ABS
2182         1, // NEG
2183         2, // MAX
2184         2, // MIN
2185         2, // CMP
2186         2, // AND
2187         2, // OR
2188         2, // XOR
2189         1, // NOT
2190         1, // EXP
2191         1, // LOG
2192         2, // POW
2193         1, // SIN
2194         1, // COS
2195         3, // TERNARY
2196         0, // MUX
2197         0, // DUP
2198         0, // SWAP
2199     };
2200     static_assert(sizeof(numOperands) == static_cast<unsigned>(ExprOpType::SWAP) + 1, "invalid table");
2201 
2202     auto tokens = tokenize(expr);
2203 
2204     ExpressionTree tree;
2205     std::vector<ExpressionTreeNode *> stack;
2206 
2207     for (const std::string &tok : tokens) {
2208         ExprOp op = decodeToken(tok);
2209 
2210         // Check validity.
2211         if (op.type == ExprOpType::MEM_LOAD_U8 && op.imm.i >= numInputs)
2212             throw std::runtime_error("reference to undefined clip: " + tok);
2213         if ((op.type == ExprOpType::DUP || op.type == ExprOpType::SWAP) && op.imm.u >= stack.size())
2214             throw std::runtime_error("insufficient values on stack: " + tok);
2215         if (stack.size() < numOperands[static_cast<size_t>(op.type)])
2216             throw std::runtime_error("insufficient values on stack: " + tok);
2217 
2218         // Rename load operations with the correct data type.
2219         if (op.type == ExprOpType::MEM_LOAD_U8) {
2220             const VSFormat *format = vi[op.imm.i]->format;
2221 
2222             if (format->sampleType == stInteger && format->bytesPerSample == 1)
2223                 op.type = ExprOpType::MEM_LOAD_U8;
2224             else if (format->sampleType == stInteger && format->bytesPerSample == 2)
2225                 op.type = ExprOpType::MEM_LOAD_U16;
2226             else if (format->sampleType == stFloat && format->bytesPerSample == 2)
2227                 op.type = ExprOpType::MEM_LOAD_F16;
2228             else if (format->sampleType == stFloat && format->bytesPerSample == 4)
2229                 op.type = ExprOpType::MEM_LOAD_F32;
2230         }
2231 
2232         // Apply DUP and SWAP in the frontend.
2233         if (op.type == ExprOpType::DUP) {
2234             stack.push_back(tree.clone(stack[stack.size() - 1 - op.imm.u]));
2235         } else if (op.type == ExprOpType::SWAP) {
2236             std::swap(stack.back(), stack[stack.size() - 1 - op.imm.u]);
2237         } else {
2238             size_t operands = numOperands[static_cast<size_t>(op.type)];
2239 
2240             if (operands == 0) {
2241                 stack.push_back(tree.makeNode(op));
2242             } else if (operands == 1) {
2243                 ExpressionTreeNode *child = stack.back();
2244                 stack.pop_back();
2245 
2246                 ExpressionTreeNode *node = tree.makeNode(op);
2247                 node->setLeft(child);
2248                 stack.push_back(node);
2249             } else if (operands == 2) {
2250                 ExpressionTreeNode *left = stack[stack.size() - 2];
2251                 ExpressionTreeNode *right = stack[stack.size() - 1];
2252                 stack.resize(stack.size() - 2);
2253 
2254                 ExpressionTreeNode *node = tree.makeNode(op);
2255                 node->setLeft(left);
2256                 node->setRight(right);
2257                 stack.push_back(node);
2258             } else if (operands == 3) {
2259                 ExpressionTreeNode *arg1 = stack[stack.size() - 3];
2260                 ExpressionTreeNode *arg2 = stack[stack.size() - 2];
2261                 ExpressionTreeNode *arg3 = stack[stack.size() - 1];
2262                 stack.resize(stack.size() - 3);
2263 
2264                 ExpressionTreeNode *mux = tree.makeNode(ExprOpType::MUX);
2265                 mux->setLeft(arg2);
2266                 mux->setRight(arg3);
2267 
2268                 ExpressionTreeNode *node = tree.makeNode(op);
2269                 node->setLeft(arg1);
2270                 node->setRight(mux);
2271                 stack.push_back(node);
2272             }
2273         }
2274     }
2275 
2276     if (stack.empty())
2277         throw std::runtime_error("empty expression: " + expr);
2278     if (stack.size() > 1)
2279         throw std::runtime_error("unconsumed values on stack: " + expr);
2280 
2281     tree.setRoot(stack.back());
2282     return tree;
2283 }
2284 
isConstantExpr(const ExpressionTreeNode & node)2285 bool isConstantExpr(const ExpressionTreeNode &node)
2286 {
2287     switch (node.op.type) {
2288     case ExprOpType::MEM_LOAD_U8:
2289     case ExprOpType::MEM_LOAD_U16:
2290     case ExprOpType::MEM_LOAD_F16:
2291     case ExprOpType::MEM_LOAD_F32:
2292         return false;
2293     case ExprOpType::CONSTANT:
2294         return true;
2295     default:
2296         return (!node.left || isConstantExpr(*node.left)) && (!node.right || isConstantExpr(*node.right));
2297     }
2298 }
2299 
isConstant(const ExpressionTreeNode & node)2300 bool isConstant(const ExpressionTreeNode &node)
2301 {
2302     return node.op.type == ExprOpType::CONSTANT;
2303 }
2304 
isConstant(const ExpressionTreeNode & node,float val)2305 bool isConstant(const ExpressionTreeNode &node, float val)
2306 {
2307     return node.op.type == ExprOpType::CONSTANT && node.op.imm.f == val;
2308 }
2309 
evalConstantExpr(const ExpressionTreeNode & node)2310 float evalConstantExpr(const ExpressionTreeNode &node)
2311 {
2312     auto bool2float = [](bool x) { return x ? 1.0f : 0.0f; };
2313     auto float2bool = [](float x) { return x > 0.0f; };
2314 
2315 #define LEFT evalConstantExpr(*node.left)
2316 #define RIGHT evalConstantExpr(*node.right)
2317 #define RIGHTLEFT evalConstantExpr(*node.right->left)
2318 #define RIGHTRIGHT evalConstantExpr(*node.right->right)
2319     switch (node.op.type) {
2320     case ExprOpType::CONSTANT: return node.op.imm.f;
2321     case ExprOpType::ADD: return LEFT + RIGHT;
2322     case ExprOpType::SUB: return LEFT - RIGHT;
2323     case ExprOpType::MUL: return LEFT * RIGHT;
2324     case ExprOpType::DIV: return LEFT / RIGHT;
2325     case ExprOpType::FMA:
2326         switch (static_cast<FMAType>(node.op.imm.u)) {
2327         case FMAType::FMADD: return RIGHTLEFT * RIGHTRIGHT + LEFT;
2328         case FMAType::FMSUB: return RIGHTLEFT * RIGHTRIGHT - LEFT;
2329         case FMAType::FNMADD: return -(RIGHTLEFT * RIGHTRIGHT) + LEFT;
2330         case FMAType::FNMSUB: return -(RIGHTLEFT * RIGHTRIGHT) - LEFT;
2331         }
2332         return NAN;
2333     case ExprOpType::SQRT: return std::sqrt(LEFT);
2334     case ExprOpType::ABS: return std::fabs(LEFT);
2335     case ExprOpType::NEG: return -LEFT;
2336     case ExprOpType::MAX: return std::max(LEFT, RIGHT);
2337     case ExprOpType::MIN: return std::min(LEFT, RIGHT);
2338     case ExprOpType::CMP:
2339         switch (static_cast<ComparisonType>(node.op.imm.u)) {
2340         case ComparisonType::EQ: return bool2float(LEFT == RIGHT);
2341         case ComparisonType::LT: return bool2float(LEFT < RIGHT);
2342         case ComparisonType::LE: return bool2float(LEFT <= RIGHT);
2343         case ComparisonType::NEQ: return bool2float(LEFT != RIGHT);
2344         case ComparisonType::NLT: return bool2float(LEFT >= RIGHT);
2345         case ComparisonType::NLE: return bool2float(LEFT > RIGHT);
2346         }
2347         return NAN;
2348     case ExprOpType::AND: return bool2float(float2bool(LEFT) && float2bool(RIGHT));
2349     case ExprOpType::OR: return bool2float(float2bool(LEFT) || float2bool(RIGHT));
2350     case ExprOpType::XOR: return bool2float(float2bool(LEFT) != float2bool(RIGHT));
2351     case ExprOpType::NOT: return bool2float(!float2bool(LEFT));
2352     case ExprOpType::EXP: return std::exp(LEFT);
2353     case ExprOpType::LOG: return std::log(LEFT);
2354     case ExprOpType::POW: return std::pow(LEFT, RIGHT);
2355     case ExprOpType::SIN: return std::sin(LEFT);
2356     case ExprOpType::COS: return std::cos(LEFT);
2357     case ExprOpType::TERNARY: return float2bool(LEFT) ? RIGHTLEFT : RIGHTRIGHT;
2358     default: return NAN;
2359     }
2360 #undef RIGHTRIGHT
2361 #undef RIGHTLEFT
2362 #undef RIGHT
2363 #undef LEFT
2364 }
2365 
isOpCode(const ExpressionTreeNode & node,std::initializer_list<ExprOpType> types)2366 bool isOpCode(const ExpressionTreeNode &node, std::initializer_list<ExprOpType> types)
2367 {
2368     for (ExprOpType type : types) {
2369         if (node.op.type == type)
2370             return true;
2371     }
2372     return false;
2373 }
2374 
isInteger(float x)2375 bool isInteger(float x)
2376 {
2377     return std::floor(x) == x;
2378 }
2379 
replaceNode(ExpressionTreeNode & node,const ExpressionTreeNode & replacement)2380 void replaceNode(ExpressionTreeNode &node, const ExpressionTreeNode &replacement)
2381 {
2382     node.op = replacement.op;
2383     node.setLeft(replacement.left);
2384     node.setRight(replacement.right);
2385 }
2386 
swapNodeContents(ExpressionTreeNode & lhs,ExpressionTreeNode & rhs)2387 void swapNodeContents(ExpressionTreeNode &lhs, ExpressionTreeNode &rhs)
2388 {
2389     std::swap(lhs, rhs);
2390     std::swap(lhs.parent, rhs.parent);
2391 }
2392 
applyValueNumbering(ExpressionTree & tree)2393 void applyValueNumbering(ExpressionTree &tree)
2394 {
2395     std::vector<ExpressionTreeNode *> numbered;
2396     int valueNum = 0;
2397 
2398     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
2399     {
2400         node.valueNum = -1;
2401     });
2402 
2403     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
2404     {
2405         if (node.op.type == ExprOpType::MUX)
2406             return;
2407 
2408         for (ExpressionTreeNode *testnode : numbered) {
2409             if (equalSubTree(&node, testnode)) {
2410                 node.valueNum = testnode->valueNum;
2411                 return;
2412             }
2413         }
2414 
2415         node.valueNum = valueNum++;
2416         numbered.push_back(&node);
2417     });
2418 }
2419 
emitIntegerPow(ExpressionTree & tree,const ExpressionTreeNode & node,int exponent)2420 ExpressionTreeNode *emitIntegerPow(ExpressionTree &tree, const ExpressionTreeNode &node, int exponent)
2421 {
2422     if (exponent == 1)
2423         return tree.clone(&node);
2424 
2425     ExpressionTreeNode *mulNode = tree.makeNode({ ExprOpType::MUL });
2426     mulNode->setLeft(emitIntegerPow(tree, node, (exponent + 1) / 2));
2427     mulNode->setRight(emitIntegerPow(tree, node, exponent - (exponent + 1) / 2));
2428     return mulNode;
2429 }
2430 
2431 typedef std::unordered_map<int, const ExpressionTreeNode *> ValueIndex;
2432 
2433 class ExponentMap {
2434     struct CanonicalCompare {
2435         const ValueIndex &index;
2436 
operator ()__anon2a0524d50111::ExponentMap::CanonicalCompare2437         bool operator()(const std::pair<int, float> &lhs, const std::pair<int, float> &rhs) const
2438         {
2439             const std::initializer_list<ExprOpType> memOpCodes = { ExprOpType::MEM_LOAD_U8, ExprOpType::MEM_LOAD_U16, ExprOpType::MEM_LOAD_F16, ExprOpType::MEM_LOAD_F32 };
2440 
2441             // Order equivalent terms by exponent.
2442             if (lhs.first == rhs.first)
2443                 return lhs.second < rhs.second;
2444 
2445             const ExpressionTreeNode *lhsNode = index.at(lhs.first);
2446             const ExpressionTreeNode *rhsNode = index.at(rhs.first);
2447 
2448             // Ordering: complex values, memory, constants
2449             int lhsCategory = isConstant(*lhsNode) ? 2 : isOpCode(*lhsNode, memOpCodes) ? 1 : 0;
2450             int rhsCategory = isConstant(*rhsNode) ? 2 : isOpCode(*rhsNode, memOpCodes) ? 1 : 0;
2451 
2452             if (lhsCategory != rhsCategory)
2453                 return lhsCategory < rhsCategory;
2454 
2455             // Ordering criteria for each category:
2456             //
2457             // constants: order by value
2458             // memory: order by variable name
2459             // other: order by value number (unstable)
2460             if (lhsCategory == 2)
2461                 return lhsNode->op.imm.f < rhsNode->op.imm.f;
2462             else if (lhsCategory == 1)
2463                 return lhsNode->op.imm.u < rhsNode->op.imm.f;
2464             else
2465                 return lhs.first < rhs.first;
2466         };
2467     };
2468 
2469     // e.g. 3 * v0^2 * v1^3
2470     // map = { 0: 2, 1: 3 }, coeff = 3
2471     std::map<int, float> map; // key = valueNum, value = exponent
2472     std::vector<int> origSequence;
2473     float coeff;
2474 
expandOrigSequence(ValueIndex & index)2475     bool expandOrigSequence(ValueIndex &index)
2476     {
2477         bool changed = false;
2478 
2479         for (size_t i = 0; i < origSequence.size(); ++i) {
2480             const ExpressionTreeNode *value = index.at(origSequence[i]);
2481 
2482             if (value->op == ExprOpType::POW && isConstant(*value->right)) {
2483                 origSequence[i] = value->left->valueNum;
2484                 changed = true;
2485             } else if (value->op == ExprOpType::MUL || value->op == ExprOpType::DIV) {
2486                 origSequence[i] = value->left->valueNum;
2487                 origSequence.insert(origSequence.begin() + i + 1, value->right->valueNum);
2488                 changed = true;
2489             }
2490         }
2491 
2492         return changed;
2493     }
2494 
expandOnePass(ValueIndex & index)2495     bool expandOnePass(ValueIndex &index)
2496     {
2497         bool changed = false;
2498 
2499         for (auto it = map.begin(); it != map.end();) {
2500             const ExpressionTreeNode *value = index.at(it->first);
2501             bool erase = false;
2502 
2503             if (value->op == ExprOpType::POW && isConstant(*value->right)) {
2504                 index[value->left->valueNum] = value->left;
2505 
2506                 map[value->left->valueNum] += it->second * value->right->op.imm.f;
2507                 erase = true;
2508             } else if (value->op == ExprOpType::MUL) {
2509                 index[value->left->valueNum] = value->left;
2510                 index[value->right->valueNum] = value->right;
2511 
2512                 map[value->left->valueNum] += it->second;
2513                 map[value->right->valueNum] += it->second;
2514                 erase = true;
2515             } else if (value->op == ExprOpType::DIV) {
2516                 index[value->left->valueNum] = value->left;
2517                 index[value->right->valueNum] = value->right;
2518 
2519                 map[value->left->valueNum] += it->second;
2520                 map[value->right->valueNum] -= it->second;
2521                 erase = true;
2522             }
2523 
2524             if (erase) {
2525                 it = map.erase(it);
2526                 changed = true;
2527                 continue;
2528             }
2529 
2530             ++it;
2531         }
2532 
2533         return changed;
2534     }
2535 
combineConstants(const ValueIndex & index)2536     void combineConstants(const ValueIndex &index)
2537     {
2538         for (auto it = map.begin(); it != map.end();) {
2539             const ExpressionTreeNode *node = index.at(it->first);
2540             if (isConstant(*node)) {
2541                 coeff *= std::pow(node->op.imm.f, it->second);
2542                 it = map.erase(it);
2543                 continue;
2544             }
2545             ++it;
2546         }
2547     }
2548 public:
ExponentMap()2549     ExponentMap() : coeff(1.0f) {}
2550 
addTerm(int valueNum,float exp)2551     void addTerm(int valueNum, float exp)
2552     {
2553         map[valueNum] += exp;
2554         origSequence.push_back(valueNum);
2555     }
2556 
addCoeff(float val)2557     void addCoeff(float val) { coeff += val; }
2558 
mulCoeff(float val)2559     void mulCoeff(float val) { coeff *= val; }
2560 
getCoeff() const2561     float getCoeff() const { return coeff; }
2562 
isScalar() const2563     bool isScalar() const { return map.empty(); }
2564 
numTerms() const2565     size_t numTerms() const { return map.size() + 1; }
2566 
isSameTerm(const ExponentMap & other) const2567     bool isSameTerm(const ExponentMap &other) const
2568     {
2569         auto it1 = map.begin();
2570         auto it2 = other.map.begin();
2571 
2572         while (it1 != map.end() && it2 != other.map.end()) {
2573             if (it1->first != it2->first || it1->second != it2->second)
2574                 return false;
2575 
2576             ++it1;
2577             ++it2;
2578         }
2579 
2580         return it1 == map.end() && it2 == other.map.end();
2581     }
2582 
expand(ValueIndex & index)2583     void expand(ValueIndex &index)
2584     {
2585         while (expandOnePass(index)) {
2586             // ...
2587         }
2588         combineConstants(index);
2589 
2590         while (expandOrigSequence(index)) {
2591             // ...
2592         }
2593     }
2594 
isCanonical(const ValueIndex & index) const2595     bool isCanonical(const ValueIndex &index) const
2596     {
2597         std::vector<std::pair<int, float>> tmp;
2598         for (int x : origSequence) {
2599             tmp.push_back({ x, 1.0f });
2600         }
2601         return std::is_sorted(tmp.begin(), tmp.end(), CanonicalCompare{ index });
2602     }
2603 
emit(ExpressionTree & tree,const ValueIndex & index) const2604     ExpressionTreeNode *emit(ExpressionTree &tree, const ValueIndex &index) const
2605     {
2606         std::vector<std::pair<int, float>> flat(map.begin(), map.end());
2607         std::sort(flat.begin(), flat.end(), CanonicalCompare{ index });
2608 
2609         ExpressionTreeNode *node = nullptr;
2610 
2611         for (auto &term : flat) {
2612             ExpressionTreeNode *powNode = tree.makeNode(ExprOpType::POW);
2613             powNode->setLeft(tree.clone(index.at(term.first)));
2614             powNode->setRight(tree.makeNode({ ExprOpType::CONSTANT, term.second }));
2615 
2616             if (node) {
2617                 ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
2618                 mulNode->setLeft(node);
2619                 mulNode->setRight(powNode);
2620                 node = mulNode;
2621             } else {
2622                 node = powNode;
2623             }
2624         }
2625 
2626         if (node) {
2627             ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
2628             mulNode->setLeft(node);
2629             mulNode->setRight(tree.makeNode({ ExprOpType::CONSTANT, coeff }));
2630             node = mulNode;
2631         } else {
2632             node = tree.makeNode({ ExprOpType::CONSTANT, coeff });
2633         }
2634 
2635         return node;
2636     }
2637 
canonicalOrder(const ExponentMap & other,const ValueIndex & index) const2638     bool canonicalOrder(const ExponentMap &other, const ValueIndex &index) const
2639     {
2640         // Convert map to flat array, as canonical order is different from value numbering.
2641         std::vector<std::pair<int, float>> lhsFlat(map.begin(), map.end());
2642         std::vector<std::pair<int, float>> rhsFlat(other.map.begin(), other.map.end());
2643 
2644         CanonicalCompare pred{ index };
2645         std::sort(lhsFlat.begin(), lhsFlat.end(), pred);
2646         std::sort(rhsFlat.begin(), rhsFlat.end(), pred);
2647         return std::lexicographical_compare(lhsFlat.begin(), lhsFlat.end(), rhsFlat.begin(), rhsFlat.end(), pred);
2648     }
2649 };
2650 
2651 class AdditiveSequence {
2652     std::vector<ExponentMap> terms;
2653     float scalarTerm;
2654 public:
AdditiveSequence()2655     AdditiveSequence() : scalarTerm() {}
2656 
addTerm(int valueNum,int sign)2657     void addTerm(int valueNum, int sign)
2658     {
2659         ExponentMap map;
2660         map.addTerm(valueNum, 1.0f);
2661         map.mulCoeff(static_cast<float>(sign));
2662         terms.push_back(std::move(map));
2663     }
2664 
numTerms() const2665     size_t numTerms() const { return terms.size() + 1; }
2666 
expand(ValueIndex & index)2667     void expand(ValueIndex &index)
2668     {
2669         for (auto &term : terms) {
2670             term.expand(index);
2671         }
2672 
2673         for (auto it = terms.begin(); it != terms.end();) {
2674             if (it->isScalar()) {
2675                 scalarTerm += it->getCoeff();
2676                 it = terms.erase(it);
2677                 continue;
2678             }
2679 
2680             ++it;
2681         }
2682 
2683         for (auto it1 = terms.begin(); it1 != terms.end();) {
2684             for (auto it2 = it1 + 1; it2 != terms.end(); ++it2) {
2685                 if (it1->isSameTerm(*it2)) {
2686                     it1->addCoeff(it2->getCoeff());
2687                     it2->mulCoeff(0.0f);
2688                 }
2689             }
2690 
2691             if (it1->getCoeff() == 0.0f) {
2692                 it1 = terms.erase(it1);
2693                 continue;
2694             }
2695 
2696             ++it1;
2697         }
2698     }
2699 
canonicalize(const ValueIndex & index)2700     bool canonicalize(const ValueIndex &index)
2701     {
2702         auto pred = [&](const ExponentMap &lhs, const ExponentMap &rhs)
2703         {
2704             return lhs.canonicalOrder(rhs, index);
2705         };
2706 
2707         if (std::is_sorted(terms.begin(), terms.end(), pred))
2708             return true;
2709 
2710         std::sort(terms.begin(), terms.end(), pred);
2711         return false;
2712     }
2713 
emit(ExpressionTree & tree,const ValueIndex & index) const2714     ExpressionTreeNode *emit(ExpressionTree &tree, const ValueIndex &index) const
2715     {
2716         ExpressionTreeNode *head = nullptr;
2717 
2718         for (const auto &term : terms) {
2719             ExpressionTreeNode *node = term.emit(tree, index);
2720 
2721             if (head) {
2722                 ExpressionTreeNode *addNode = tree.makeNode(ExprOpType::ADD);
2723                 addNode->setLeft(head);
2724                 addNode->setRight(node);
2725                 head = addNode;
2726             } else {
2727                 head = node;
2728             }
2729         }
2730 
2731         if (head) {
2732             ExpressionTreeNode *addNode = tree.makeNode(scalarTerm < 0 ? ExprOpType::SUB : ExprOpType::ADD);
2733             addNode->setLeft(head);
2734             addNode->setRight(tree.makeNode({ ExprOpType::CONSTANT, std::fabs(scalarTerm) }));
2735             head = addNode;
2736         } else {
2737             head = tree.makeNode({ ExprOpType::CONSTANT, 0.0f });
2738         }
2739 
2740         return head;
2741     }
2742 };
2743 
analyzeAdditiveExpression(ExpressionTree & tree,ExpressionTreeNode & node)2744 bool analyzeAdditiveExpression(ExpressionTree &tree, ExpressionTreeNode &node)
2745 {
2746     size_t origNumTerms = 0;
2747     AdditiveSequence expr;
2748     ValueIndex index;
2749 
2750     node.preorder([&](ExpressionTreeNode &node)
2751     {
2752         if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }))
2753             return false;
2754 
2755         // Deduce net sign of term.
2756         const ExpressionTreeNode *parent = node.parent;
2757         const ExpressionTreeNode *cur = &node;
2758         int polarity = 1;
2759 
2760         while (parent && isOpCode(*parent, { ExprOpType::ADD, ExprOpType::SUB })) {
2761             if (parent->op == ExprOpType::SUB && cur == parent->right)
2762                 polarity = -polarity;
2763 
2764             cur = parent;
2765             parent = parent->parent;
2766         }
2767 
2768         ++origNumTerms;
2769         expr.addTerm(node.valueNum, polarity);
2770         index[node.valueNum] = &node;
2771         return true;
2772     });
2773 
2774     expr.expand(index);
2775     bool canonical = expr.canonicalize(index);
2776 
2777     if (expr.numTerms() < origNumTerms || !canonical) {
2778         ExpressionTreeNode *seq = expr.emit(tree, index);
2779         replaceNode(node, *seq);
2780         return true;
2781     }
2782 
2783     return false;
2784 }
2785 
analyzeMultiplicativeExpression(ExpressionTree & tree,ExpressionTreeNode & node)2786 bool analyzeMultiplicativeExpression(ExpressionTree &tree, ExpressionTreeNode &node)
2787 {
2788     std::unordered_map<int, const ExpressionTreeNode *> index;
2789 
2790     ExponentMap expr;
2791     size_t origNumTerms = 0;
2792     size_t numDivs = 0;
2793 
2794     node.preorder([&](ExpressionTreeNode &node)
2795     {
2796         if (node.op == ExprOpType::DIV)
2797             ++numDivs;
2798 
2799         if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }))
2800             return false;
2801 
2802         // Deduce net sign of term.
2803         const ExpressionTreeNode *parent = node.parent;
2804         const ExpressionTreeNode *cur = &node;
2805         int polarity = 1;
2806 
2807         while (parent && isOpCode(*parent, { ExprOpType::MUL, ExprOpType::DIV })) {
2808             if (parent->op == ExprOpType::DIV && cur == parent->right)
2809                 polarity = -polarity;
2810 
2811             cur = parent;
2812             parent = parent->parent;
2813         }
2814 
2815         expr.addTerm(node.valueNum, static_cast<float>(polarity));
2816         index[node.valueNum] = &node;
2817         ++origNumTerms;
2818         return true;
2819     });
2820 
2821     expr.expand(index);
2822 
2823     if (expr.numTerms() < origNumTerms || !expr.isCanonical(index) || numDivs) {
2824         ExpressionTreeNode *seq = expr.emit(tree, index);
2825         replaceNode(node, *seq);
2826         return true;
2827     }
2828 
2829     return false;
2830 }
2831 
applyAlgebraicOptimizations(ExpressionTree & tree)2832 bool applyAlgebraicOptimizations(ExpressionTree &tree)
2833 {
2834     bool changed = false;
2835 
2836     applyValueNumbering(tree);
2837 
2838     tree.getRoot()->preorder([&](ExpressionTreeNode &node)
2839     {
2840         if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::ADD, ExprOpType::SUB }))) {
2841             changed = changed || analyzeAdditiveExpression(tree, node);
2842             return changed;
2843         }
2844 
2845         if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::MUL, ExprOpType::DIV }))) {
2846             changed = changed || analyzeMultiplicativeExpression(tree, node);
2847             return changed;
2848         }
2849 
2850         return false;
2851     });
2852 
2853     return changed;
2854 }
2855 
applyComparisonOptimizations(ExpressionTree & tree)2856 bool applyComparisonOptimizations(ExpressionTree &tree)
2857 {
2858     bool changed = false;
2859 
2860     applyValueNumbering(tree);
2861 
2862     tree.getRoot()->preorder([&](ExpressionTreeNode &node)
2863     {
2864         // Eliminate constant conditions.
2865         if (node.op.type == ExprOpType::CMP && node.left->valueNum == node.right->valueNum) {
2866             ComparisonType type = static_cast<ComparisonType>(node.op.imm.u);
2867             if (type == ComparisonType::EQ || type == ComparisonType::LE || type == ComparisonType::NLT)
2868                 replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
2869             else
2870                 replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
2871 
2872             changed = true;
2873             return changed;
2874         }
2875 
2876         // Eliminate identical branches.
2877         if (node.op == ExprOpType::TERNARY && node.right->left->valueNum == node.right->right->valueNum) {
2878             replaceNode(node, *node.right->left);
2879             changed = true;
2880             return changed;
2881         }
2882 
2883         // MIN/MAX detection.
2884         if (node.op == ExprOpType::TERNARY && node.left->op.type == ExprOpType::CMP) {
2885             ComparisonType type = static_cast<ComparisonType>(node.left->op.imm.u);
2886             int cmpTerms[2] = { node.left->left->valueNum, node.left->right->valueNum };
2887             int muxTerms[2] = { node.right->left->valueNum, node.right->right->valueNum };
2888 
2889             bool isSameTerms = (cmpTerms[0] == muxTerms[0] && cmpTerms[1] == muxTerms[1]) || (cmpTerms[0] == muxTerms[1] && cmpTerms[1] == muxTerms[0]);
2890             bool isLessOrGreater = type == ComparisonType::LT || type == ComparisonType::LE || type == ComparisonType::NLE || type == ComparisonType::NLT;
2891 
2892             if (isSameTerms && isLessOrGreater) {
2893                 // a < b ? a : b --> min(a, b)     a > b ? b : a --> min(a, b)
2894                 // a > b ? a : b --> max(a, b)     a < b ? b : a --> max(a, b)
2895                 bool min = (type == ComparisonType::LT || type == ComparisonType::LE) ? cmpTerms[0] == muxTerms[0] : cmpTerms[0] != muxTerms[0];
2896                 ExpressionTreeNode *a = node.left->left;
2897                 ExpressionTreeNode *b = node.left->right;
2898 
2899                 replaceNode(node, ExpressionTreeNode{ min ? ExprOpType::MIN : ExprOpType::MAX });
2900                 node.setLeft(a);
2901                 node.setRight(b);
2902 
2903                 changed = true;
2904                 return changed;
2905             }
2906         }
2907 
2908         // CMP to SUB conversion. It has lower priority than other comparison transformations.
2909         if (node.op.type == ExprOpType::CMP && node.parent && isOpCode(*node.parent, { ExprOpType::AND, ExprOpType::OR, ExprOpType::XOR, ExprOpType::TERNARY })) {
2910             ComparisonType type = static_cast<ComparisonType>(node.op.imm.u);
2911 
2912             // a < b --> b - a    a > b --> a - b
2913             if (type == ComparisonType::LT || type == ComparisonType::NLE) {
2914                 if (type == ComparisonType::LT)
2915                     std::swap(node.left, node.right);
2916 
2917                 node.op = ExprOpType::SUB;
2918                 changed = true;
2919                 return changed;
2920             }
2921         }
2922 
2923         return false;
2924     });
2925 
2926     return changed;
2927 }
2928 
applyLocalOptimizations(ExpressionTree & tree)2929 bool applyLocalOptimizations(ExpressionTree &tree)
2930 {
2931     bool changed = false;
2932 
2933     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
2934     {
2935         if (node.op.type == ExprOpType::MUX)
2936             return;
2937 
2938         // Constant folding.
2939         if (node.op.type != ExprOpType::CONSTANT && isConstantExpr(node)) {
2940             float val = evalConstantExpr(node);
2941             replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, val } });
2942             changed = true;
2943         }
2944 
2945         // Move constants to right-hand side to simplify identities.
2946         if (isOpCode(node, { ExprOpType::ADD, ExprOpType::MUL }) && isConstant(*node.left) && !isConstant(*node.right)) {
2947             std::swap(node.left, node.right);
2948             changed = true;
2949         }
2950 
2951         // x * 0 = 0    0 / x = 0
2952         if ((node.op == ExprOpType::MUL && isConstant(*node.right, 0.0f)) || (node.op == ExprOpType::DIV && isConstant(*node.left, 0.0f))) {
2953             replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
2954             changed = true;
2955         }
2956 
2957         // sqrt(x) = x ** 0.5
2958         if (node.op == ExprOpType::SQRT) {
2959             node.op = ExprOpType::POW;
2960             node.setRight(tree.makeNode({ ExprOpType::CONSTANT, 0.5f }));
2961             changed = true;
2962         }
2963 
2964         // log(exp(x)) = x    exp(log(x)) = x
2965         if ((node.op == ExprOpType::LOG && node.left->op == ExprOpType::EXP) || (node.op == ExprOpType::EXP && node.left->op == ExprOpType::LOG)) {
2966             replaceNode(node, *node.left->left);
2967             changed = true;
2968         }
2969 
2970         // x ** 0 = 1
2971         if (node.op == ExprOpType::POW && isConstant(*node.right, 0.0f)) {
2972             replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
2973             changed = true;
2974         }
2975 
2976         // (a ** b) ** c = a ** (b * c)
2977         if (node.op == ExprOpType::POW && node.left->op == ExprOpType::POW) {
2978             ExpressionTreeNode *a = node.left->left;
2979             ExpressionTreeNode *b = node.left->right;
2980             ExpressionTreeNode *c = node.right;
2981             replaceNode(*node.left, *a);
2982             node.setRight(tree.makeNode(ExprOpType::MUL));
2983             node.right->setLeft(b);
2984             node.right->setRight(c);
2985             changed = true;
2986         }
2987 
2988         // 0 ? x : y = y    1 ? x : y = x
2989         if (node.op == ExprOpType::TERNARY && isConstant(*node.left)) {
2990             ExpressionTreeNode *replacement = node.left->op.imm.f > 0.0f ? node.right->left : node.right->right;
2991             replaceNode(node, *replacement);
2992             changed = true;
2993         }
2994 
2995         // a <= b ? x : y --> a > b ? y : x    a >= b ? x : y --> a < b ? y : x
2996         if (node.op == ExprOpType::TERNARY && node.left->op.type == ExprOpType::CMP) {
2997             ComparisonType type = static_cast<ComparisonType>(node.left->op.imm.u);
2998 
2999             if (type == ComparisonType::LE || type == ComparisonType::NLT) {
3000                 node.left->op.imm.u = static_cast<unsigned>(type == ComparisonType::LE ? ComparisonType::NLE : ComparisonType::LT);
3001                 std::swap(node.right->left, node.right->right);
3002                 changed = true;
3003             }
3004         }
3005 
3006         // !a ? b : c --> a ? c : b
3007         if (node.op == ExprOpType::TERNARY && node.left->op == ExprOpType::NOT) {
3008             replaceNode(*node.left, *node.left->left);
3009             std::swap(node.right->left, node.right->right);
3010             changed = true;
3011         }
3012 
3013         // !(a < b) --> a >= b
3014         if (node.op == ExprOpType::NOT && node.left->op.type == ExprOpType::CMP) {
3015             switch (static_cast<ComparisonType>(node.left->op.imm.u)) {
3016             case ComparisonType::EQ: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::NEQ); break;
3017             case ComparisonType::LT: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::NLT); break;
3018             case ComparisonType::LE: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::NLE); break;
3019             case ComparisonType::NEQ: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::EQ); break;
3020             case ComparisonType::NLT: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::LT); break;
3021             case ComparisonType::NLE: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::LE); break;
3022             }
3023             replaceNode(node, *node.left);
3024             changed = true;
3025         }
3026     });
3027 
3028     return changed;
3029 }
3030 
applyAlgebraicCleanup(ExpressionTree & tree)3031 bool applyAlgebraicCleanup(ExpressionTree &tree)
3032 {
3033     bool changed = false;
3034 
3035     // Prune extra terms introduced by the algebraic analysis. These need to run in a later pass to prevent cycles.
3036     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3037     {
3038         // x + 0 = x    x - 0 = x
3039         if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && isConstant(*node.right, 0.0f)) {
3040             replaceNode(node, *node.left);
3041             changed = true;
3042         }
3043 
3044         // x * 1 = x    x / 1 = x
3045         if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, 1.0f)) {
3046             replaceNode(node, *node.left);
3047             changed = true;
3048         }
3049 
3050         // x ** 1 = x
3051         if (node.op == ExprOpType::POW && isConstant(*node.right, 1.0f)) {
3052             replaceNode(node, *node.left);
3053             changed = true;
3054         }
3055     });
3056 
3057     return changed;
3058 }
3059 
3060 
applyStrengthReduction(ExpressionTree & tree)3061 bool applyStrengthReduction(ExpressionTree &tree)
3062 {
3063     bool changed = false;
3064 
3065     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3066     {
3067         if (node.op == ExprOpType::MUX)
3068             return;
3069 
3070         // 0 - x = -x
3071         if (node.op == ExprOpType::SUB && isConstant(*node.left, 0.0f)) {
3072             ExpressionTreeNode *tmp = node.right;
3073             replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
3074             node.setLeft(tmp);
3075             changed = true;
3076         }
3077 
3078         // x * -1 = -x    x / -1 = -x
3079         if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, -1.0f)) {
3080             ExpressionTreeNode *tmp = node.left;
3081             replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
3082             node.setLeft(tmp);
3083             changed = true;
3084         }
3085 
3086         // a + -b = a - b    a - -b = a + b
3087         if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && node.right->op.type == ExprOpType::NEG) {
3088             node.op = node.op == ExprOpType::ADD ? ExprOpType::SUB : ExprOpType::ADD;
3089             replaceNode(*node.right, *node.right->left);
3090             changed = true;
3091         }
3092 
3093         // -a + b = b - a
3094         if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::NEG) {
3095             node.op = ExprOpType::SUB;
3096             replaceNode(*node.left, *node.left->left);
3097             std::swap(node.left, node.right);
3098         }
3099 
3100         // -(a - b) = b - a
3101         if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::SUB) {
3102             replaceNode(node, *node.left);
3103             std::swap(node.left, node.right);
3104             changed = true;
3105         }
3106 
3107         // x * 2 = x + x
3108         if (node.op == ExprOpType::MUL && isConstant(*node.right, 2.0f) && (!node.parent || node.parent->op != ExprOpType::ADD)) {
3109             ExpressionTreeNode *replacement = tree.clone(node.left);
3110             node.op = ExprOpType::ADD;
3111             replaceNode(*node.right, *replacement);
3112             changed = true;
3113         }
3114 
3115         // x / y = x * (1 / y)
3116         if (node.op == ExprOpType::DIV && isConstant(*node.right)) {
3117             node.op = ExprOpType::MUL;
3118             node.right->op.imm.f = 1.0f / node.right->op.imm.f;
3119             changed = true;
3120         }
3121 
3122         // (1 / x) * y = y / x
3123         if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV && isConstant(*node.left->left, 1.0f)) {
3124             node.op = ExprOpType::DIV;
3125             replaceNode(*node.left, *node.left->right);
3126             std::swap(node.left, node.right);
3127             changed = true;
3128         }
3129 
3130         // x * (1 / y) = x / y
3131         if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV && isConstant(*node.right->left, 1.0f)) {
3132             node.op = ExprOpType::DIV;
3133             replaceNode(*node.right, *node.right->right);
3134             changed = true;
3135         }
3136 
3137         // (a / b) * c = (a * c) / b
3138         if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV) {
3139             node.op = ExprOpType::DIV;
3140             node.left->op = ExprOpType::MUL;
3141             swapNodeContents(*node.left->right, *node.right);
3142             changed = true;
3143         }
3144 
3145         // a * (b / c) = (a * b) / c
3146         if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV) {
3147             node.op = ExprOpType::DIV;
3148             node.right->op = ExprOpType::MUL;
3149             std::swap(node.left, node.right); // (b * c) / a
3150             swapNodeContents(*node.left->left, *node.left->right); // (c * b) / a
3151             swapNodeContents(*node.left->left, *node.right); // (a * b) / c
3152             changed = true;
3153         }
3154 
3155         // a / (b / c) = (a * c) / b
3156         if (node.op == ExprOpType::DIV && node.right->op == ExprOpType::DIV) {
3157             node.right->op = ExprOpType::MUL; // a / (b * c)
3158             std::swap(node.left, node.right); // (b * c) / a
3159             swapNodeContents(*node.left->left, *node.right); // (a * c) / b
3160             changed = true;
3161         }
3162 
3163         // (a / b) / c = a / (b * c)
3164         if (node.op == ExprOpType::DIV && node.left->op == ExprOpType::DIV) {
3165             node.left->op = ExprOpType::MUL; // (a * b) / c
3166             std::swap(node.left, node.right); // c / (a * b)
3167             swapNodeContents(*node.left, *node.right->left); // a / (c * b)
3168             swapNodeContents(*node.right->left, *node.right->right); // a / (b * c)
3169             changed = true;
3170         }
3171 
3172         // x ** (n / 2) = sqrt(x ** n)
3173         if (node.op == ExprOpType::POW && isConstant(*node.right) && !isInteger(node.right->op.imm.f) && isInteger(node.right->op.imm.f * 2.0f)) {
3174             ExpressionTreeNode *dup = tree.clone(&node);
3175             replaceNode(node, ExpressionTreeNode{ ExprOpType::SQRT });
3176             node.setLeft(dup);
3177             node.left->right->op.imm.f *= 2.0f;
3178             changed = true;
3179         }
3180 
3181         // x ** -N = 1 / (x ** N)
3182         if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f < 0) {
3183             ExpressionTreeNode *dup = tree.clone(&node);
3184             replaceNode(node, ExpressionTreeNode{ ExprOpType::DIV });
3185             node.setLeft(tree.makeNode({ ExprOpType::CONSTANT, 1.0f }));
3186             node.setRight(dup);
3187             node.right->right->op.imm.f = -node.right->right->op.imm.f;
3188             changed = true;
3189         }
3190 
3191         // x ** N = x * x * x * ...
3192         if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f > 0) {
3193             ExpressionTreeNode *replacement = emitIntegerPow(tree, *node.left, static_cast<int>(node.right->op.imm.f));
3194             replaceNode(node, *replacement);
3195             changed = true;
3196         }
3197     });
3198 
3199     return changed;
3200 }
3201 
/*
 * Fusion pass: turns add/subtract-of-multiply patterns into a single FMA node and pushes
 * negations into the FMA variant.
 *
 * An FMA node produced here keeps the addend as its left child. Its right child is relabelled
 * MUX, and that MUX node's two children are the multiplicands (compile() reads them as
 * src2/src3). The FMAType meanings implied by the rewrites below are:
 *   FMADD  = b*c + a    FNMADD = -(b*c) + a
 *   FMSUB  = b*c - a    FNMSUB = -(b*c) - a
 *
 * Value numbering runs first. The number of times each value number occurs in the tree is used to
 * avoid fusing a multiply whose result is also needed elsewhere.
 *
 * Returns true if any rewrite was applied.
 */
bool applyOpFusion(ExpressionTree &tree)
{
    // value number -> number of (non-MUX) tree nodes carrying it
    std::unordered_map<int, size_t> refCount;
    bool changed = false;

    applyValueNumbering(tree);

    tree.getRoot()->postorder([&](ExpressionTreeNode &node)
    {
        if (node.op == ExprOpType::MUX)
            return;

        refCount[node.valueNum]++;
    });

    tree.getRoot()->postorder([&](ExpressionTreeNode &node)
    {
        if (node.op == ExprOpType::MUX)
            return;

        // A child may be folded into `node` when the child's value is used only once, or when
        // `node` itself occurs more than once. Presumably the latter case fuses every copy the same
        // way, so no separate instance of the child remains (NOTE(review): confirm intent).
        // Counts are computed before this traversal and are not updated by the rewrites below.
        auto canElide = [&](ExpressionTreeNode &candidate)
        {
            return refCount[node.valueNum] > 1 || refCount[candidate.valueNum] <= 1;
        };

        // a + (b * c)    (b * c) + a    a - (b * c)    (b * c) - a
        // In every case the MUL ends up as the right child (relabelled MUX) and the addend as the left.
        if (node.op == ExprOpType::ADD && node.right->op == ExprOpType::MUL && canElide(*node.right)) {
            node.right->op = ExprOpType::MUX;
            node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
            changed = true;
        }
        if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::MUL && canElide(*node.left)) {
            std::swap(node.left, node.right);
            node.right->op = ExprOpType::MUX;
            node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
            changed = true;
        }
        if (node.op == ExprOpType::SUB && node.right->op == ExprOpType::MUL && canElide(*node.right)) {
            node.right->op = ExprOpType::MUX;
            node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FNMADD) };
            changed = true;
        }
        if (node.op == ExprOpType::SUB && node.left->op == ExprOpType::MUL && canElide(*node.left)) {
            std::swap(node.left, node.right);
            node.right->op = ExprOpType::MUX;
            node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMSUB) };
            changed = true;
        }

        // (a +/- k1) * k2 = (a * k2) +/- (k1 * k2)    for constants k1, k2
        // The ADD/SUB is lifted above the MUL, so the result can fuse into an FMA on a later iteration.
        if (node.op == ExprOpType::MUL && isOpCode(*node.left, { ExprOpType::ADD, ExprOpType::SUB }) &&
            isConstant(*node.right) && isConstant(*node.left->right) && canElide(*node.left))
        {
            std::swap(node.op, node.left->op);
            swapNodeContents(*node.right, *node.left->right);
            node.right->op.imm.f *= node.left->right->op.imm.f;
            changed = true;
        }

        // Negative FMA: -(fma) is absorbed by switching to the FMA variant with both signs flipped.
        if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::FMA && canElide(*node.left)) {
            replaceNode(node, *node.left);

            switch (static_cast<FMAType>(node.op.imm.u)) {
            case FMAType::FMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FNMSUB); break;
            case FMAType::FMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FNMADD); break;
            case FMAType::FNMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FMSUB); break;
            case FMAType::FNMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FMADD); break;
            }

            changed = true;
        }
    });

    return changed;
}
3278 
/*
 * Renumbers the registers of a linear instruction stream so that registers whose values are dead
 * get reused, keeping the number of distinct registers small.
 *
 * On entry, dst/src1/src2/src3 hold the numbers assigned by compile() (value numbers). Negative
 * numbers are skipped and treated as "no operand". Instructions are processed in program order:
 *   - each source is translated through `table` (original -> renamed); without a mapping the
 *     original number is kept;
 *   - a source whose ORIGINAL number is not read by any later instruction is dead, and its renamed
 *     register is put on `freeList`;
 *   - the destination then takes the lowest-numbered free register (std::set is ordered). This may
 *     be a register just freed by one of the same instruction's sources. The destination's original
 *     number is released to `freeList` in exchange. NOTE(review): this assumes each original number
 *     is defined by exactly one instruction and is not otherwise live here — confirm;
 *   - if nothing is free, the destination keeps its original number and no mapping is recorded.
 *
 * The liveness check rescans the rest of the program for every source operand, so the function is
 * quadratic in the number of instructions.
 */
void renameRegisters(std::vector<ExprInstruction> &code)
{
    std::unordered_map<int, int> table;   // original register -> renamed register
    std::set<int> freeList;                // registers available for reuse, lowest first

    for (size_t i = 0; i < code.size(); ++i) {
        ExprInstruction &insn = code[i];
        // Index 0 is the destination, 1..3 are the sources.
        int origRegs[4] = { insn.dst, insn.src1, insn.src2, insn.src3 };
        int renamed[4] = { insn.dst, insn.src1, insn.src2, insn.src3 };

        for (int n = 1; n < 4; ++n) {
            if (origRegs[n] < 0)
                continue;

            auto it = table.find(origRegs[n]);
            if (it != table.end())
                renamed[n] = it->second;

            // Dead if no later instruction reads the original register.
            bool dead = true;

            for (size_t j = i + 1; j < code.size(); ++j) {
                const ExprInstruction &insn2 = code[j];
                if (insn2.src1 == origRegs[n] || insn2.src2 == origRegs[n] || insn2.src3 == origRegs[n]) {
                    dead = false;
                    break;
                }
            }

            if (dead)
                freeList.insert(renamed[n]);
        }

        // Sources were released first, so the destination may reuse one of them.
        if (origRegs[0] >= 0 && !freeList.empty()) {
            renamed[0] = *freeList.begin();
            table[origRegs[0]] = renamed[0];
            freeList.erase(freeList.begin());
            freeList.insert(origRegs[0]);
        }

        insn.dst = renamed[0];
        insn.src1 = renamed[1];
        insn.src2 = renamed[2];
        insn.src3 = renamed[3];
    }
}
3324 
compile(ExpressionTree & tree,const VSFormat * format)3325 std::vector<ExprInstruction> compile(ExpressionTree &tree, const VSFormat *format)
3326 {
3327     std::vector<ExprInstruction> code;
3328     std::unordered_set<int> found;
3329 
3330     if (!tree.getRoot())
3331         return code;
3332 
3333     while (applyLocalOptimizations(tree) || applyAlgebraicOptimizations(tree) || applyComparisonOptimizations(tree)) {
3334         // ...
3335     }
3336 
3337     while (applyAlgebraicCleanup(tree) || applyStrengthReduction(tree) || applyOpFusion(tree)) {
3338         // ...
3339     }
3340 
3341     applyValueNumbering(tree);
3342 
3343     tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3344     {
3345         if (node.op.type == ExprOpType::MUX)
3346             return;
3347         if (found.find(node.valueNum) != found.end())
3348             return;
3349 
3350         ExprInstruction opcode(node.op);
3351         opcode.dst = node.valueNum;
3352 
3353         if (node.left) {
3354             assert(node.left->valueNum >= 0);
3355             opcode.src1 = node.left->valueNum;
3356         }
3357         if (node.right) {
3358             if (node.right->op.type == ExprOpType::MUX) {
3359                 assert(node.right->left->valueNum >= 0);
3360                 assert(node.right->right->valueNum >= 0);
3361                 opcode.src2 = node.right->left->valueNum;
3362                 opcode.src3 = node.right->right->valueNum;
3363             } else {
3364                 assert(node.right->valueNum >= 0);
3365                 opcode.src2 = node.right->valueNum;
3366             }
3367         }
3368 
3369         code.push_back(opcode);
3370         found.insert(node.valueNum);
3371     });
3372 
3373     ExprInstruction store(ExprOpType::MEM_STORE_U8);
3374 
3375     if (format->sampleType == stInteger && format->bytesPerSample == 1)
3376         store.op.type = ExprOpType::MEM_STORE_U8;
3377     else if (format->sampleType == stInteger && format->bytesPerSample == 2)
3378         store.op.type = ExprOpType::MEM_STORE_U16;
3379     else if (format->sampleType == stFloat && format->bytesPerSample == 2)
3380         store.op.type = ExprOpType::MEM_STORE_F16;
3381     else if (format->sampleType == stFloat && format->bytesPerSample == 4)
3382         store.op.type = ExprOpType::MEM_STORE_F32;
3383 
3384     if (store.op.type == ExprOpType::MEM_STORE_U16)
3385         store.op.imm.u = format->bitsPerSample;
3386 
3387     store.src1 = code.back().dst;
3388     code.push_back(store);
3389 
3390     renameRegisters(code);
3391     return code;
3392 }
3393 
exprInit(VSMap * in,VSMap * out,void ** instanceData,VSNode * node,VSCore * core,const VSAPI * vsapi)3394 static void VS_CC exprInit(VSMap *in, VSMap *out, void **instanceData, VSNode *node, VSCore *core, const VSAPI *vsapi) {
3395     ExprData *d = static_cast<ExprData *>(*instanceData);
3396     vsapi->setVideoInfo(&d->vi, 1, node);
3397 }
3398 
exprGetFrame(int n,int activationReason,void ** instanceData,void ** frameData,VSFrameContext * frameCtx,VSCore * core,const VSAPI * vsapi)3399 static const VSFrameRef *VS_CC exprGetFrame(int n, int activationReason, void **instanceData, void **frameData, VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi) {
3400     ExprData *d = static_cast<ExprData *>(*instanceData);
3401     int numInputs = d->numInputs;
3402 
3403     if (activationReason == arInitial) {
3404         for (int i = 0; i < numInputs; i++)
3405             vsapi->requestFrameFilter(n, d->node[i], frameCtx);
3406     } else if (activationReason == arAllFramesReady) {
3407         const VSFrameRef *src[MAX_EXPR_INPUTS] = {};
3408         for (int i = 0; i < numInputs; i++)
3409             src[i] = vsapi->getFrameFilter(n, d->node[i], frameCtx);
3410 
3411         const VSFormat *fi = d->vi.format;
3412         int height = vsapi->getFrameHeight(src[0], 0);
3413         int width = vsapi->getFrameWidth(src[0], 0);
3414         int planes[3] = { 0, 1, 2 };
3415         const VSFrameRef *srcf[3] = { d->plane[0] != poCopy ? nullptr : src[0], d->plane[1] != poCopy ? nullptr : src[0], d->plane[2] != poCopy ? nullptr : src[0] };
3416         VSFrameRef *dst = vsapi->newVideoFrame2(fi, width, height, srcf, planes, src[0], core);
3417 
3418         const uint8_t *srcp[MAX_EXPR_INPUTS] = {};
3419         int src_stride[MAX_EXPR_INPUTS] = {};
3420         alignas(32) intptr_t ptroffsets[((MAX_EXPR_INPUTS + 1) + 7) & ~7] = { d->vi.format->bytesPerSample * 8 };
3421 
3422         for (int plane = 0; plane < d->vi.format->numPlanes; plane++) {
3423             if (d->plane[plane] != poProcess)
3424                 continue;
3425 
3426             for (int i = 0; i < numInputs; i++) {
3427                 if (d->node[i]) {
3428                     srcp[i] = vsapi->getReadPtr(src[i], plane);
3429                     src_stride[i] = vsapi->getStride(src[i], plane);
3430                     ptroffsets[i + 1] = vsapi->getFrameFormat(src[i])->bytesPerSample * 8;
3431                 }
3432             }
3433 
3434             uint8_t *dstp = vsapi->getWritePtr(dst, plane);
3435             int dst_stride = vsapi->getStride(dst, plane);
3436             int h = vsapi->getFrameHeight(dst, plane);
3437             int w = vsapi->getFrameWidth(dst, plane);
3438 
3439             if (d->proc[plane]) {
3440                 ExprData::ProcessLineProc proc = d->proc[plane];
3441                 int niterations = (w + 7) / 8;
3442 
3443                 for (int i = 0; i < numInputs; i++) {
3444                     if (d->node[i])
3445                         ptroffsets[i + 1] = vsapi->getFrameFormat(src[i])->bytesPerSample * 8;
3446                 }
3447 
3448                 for (int y = 0; y < h; y++) {
3449                     alignas(32) uint8_t *rwptrs[((MAX_EXPR_INPUTS + 1) + 7) & ~7] = { dstp + dst_stride * y };
3450                     for (int i = 0; i < numInputs; i++) {
3451                         rwptrs[i + 1] = const_cast<uint8_t *>(srcp[i] + src_stride[i] * y);
3452                     }
3453                     proc(rwptrs, ptroffsets, niterations);
3454                 }
3455             } else {
3456                 ExprInterpreter interpreter(d->bytecode[plane].data(), d->bytecode[plane].size());
3457 
3458                 for (int y = 0; y < h; y++) {
3459                     for (int x = 0; x < w; x++) {
3460                         interpreter.eval(srcp, dstp, x);
3461                     }
3462 
3463                     for (int i = 0; i < numInputs; i++) {
3464                         srcp[i] += src_stride[i];
3465                     }
3466                     dstp += dst_stride;
3467                 }
3468             }
3469         }
3470 
3471         for (int i = 0; i < MAX_EXPR_INPUTS; i++) {
3472             vsapi->freeFrame(src[i]);
3473         }
3474         return dst;
3475     }
3476 
3477     return nullptr;
3478 }
3479 
exprFree(void * instanceData,VSCore * core,const VSAPI * vsapi)3480 static void VS_CC exprFree(void *instanceData, VSCore *core, const VSAPI *vsapi) {
3481     ExprData *d = static_cast<ExprData *>(instanceData);
3482     for (int i = 0; i < MAX_EXPR_INPUTS; i++)
3483         vsapi->freeNode(d->node[i]);
3484     delete d;
3485 }
3486 
exprCreate(const VSMap * in,VSMap * out,void * userData,VSCore * core,const VSAPI * vsapi)3487 static void VS_CC exprCreate(const VSMap *in, VSMap *out, void *userData, VSCore *core, const VSAPI *vsapi) {
3488     std::unique_ptr<ExprData> d(new ExprData);
3489     int err;
3490 
3491 #ifdef VS_TARGET_CPU_X86
3492     const CPUFeatures &f = *getCPUFeatures();
3493 #   define EXPR_F16C_TEST (f.f16c)
3494 #else
3495 #   define EXPR_F16C_TEST (false)
3496 #endif
3497 
3498     try {
3499         d->numInputs = vsapi->propNumElements(in, "clips");
3500         if (d->numInputs > 26)
3501             throw std::runtime_error("More than 26 input clips provided");
3502 
3503         for (int i = 0; i < d->numInputs; i++) {
3504             d->node[i] = vsapi->propGetNode(in, "clips", i, &err);
3505         }
3506 
3507         const VSVideoInfo *vi[MAX_EXPR_INPUTS] = {};
3508         for (int i = 0; i < d->numInputs; i++) {
3509             if (d->node[i])
3510                 vi[i] = vsapi->getVideoInfo(d->node[i]);
3511         }
3512 
3513         for (int i = 0; i < d->numInputs; i++) {
3514             if (!isConstantFormat(vi[i]))
3515                 throw std::runtime_error("Only clips with constant format and dimensions allowed");
3516             if (vi[0]->format->numPlanes != vi[i]->format->numPlanes
3517                 || vi[0]->format->subSamplingW != vi[i]->format->subSamplingW
3518                 || vi[0]->format->subSamplingH != vi[i]->format->subSamplingH
3519                 || vi[0]->width != vi[i]->width
3520                 || vi[0]->height != vi[i]->height)
3521             {
3522                 throw std::runtime_error("All inputs must have the same number of planes and the same dimensions, subsampling included");
3523             }
3524 
3525             if (EXPR_F16C_TEST) {
3526                 if ((vi[i]->format->bitsPerSample > 16 && vi[i]->format->sampleType == stInteger)
3527                     || (vi[i]->format->bitsPerSample != 16 && vi[i]->format->bitsPerSample != 32 && vi[i]->format->sampleType == stFloat))
3528                     throw std::runtime_error("Input clips must be 8-16 bit integer or 16/32 bit float format");
3529             } else {
3530                 if ((vi[i]->format->bitsPerSample > 16 && vi[i]->format->sampleType == stInteger)
3531                     || (vi[i]->format->bitsPerSample != 32 && vi[i]->format->sampleType == stFloat))
3532                     throw std::runtime_error("Input clips must be 8-16 bit integer or 32 bit float format");
3533             }
3534         }
3535 
3536         d->vi = *vi[0];
3537         int format = int64ToIntS(vsapi->propGetInt(in, "format", 0, &err));
3538         if (!err) {
3539             const VSFormat *f = vsapi->getFormatPreset(format, core);
3540             if (f) {
3541                 if (d->vi.format->colorFamily == cmCompat)
3542                     throw std::runtime_error("No compat formats allowed");
3543                 if (d->vi.format->numPlanes != f->numPlanes)
3544                     throw std::runtime_error("The number of planes in the inputs and output must match");
3545                 d->vi.format = vsapi->registerFormat(d->vi.format->colorFamily, f->sampleType, f->bitsPerSample, d->vi.format->subSamplingW, d->vi.format->subSamplingH, core);
3546             }
3547         }
3548 
3549         int nexpr = vsapi->propNumElements(in, "expr");
3550         if (nexpr > d->vi.format->numPlanes)
3551             throw std::runtime_error("More expressions given than there are planes");
3552 
3553         std::string expr[3];
3554         for (int i = 0; i < nexpr; i++) {
3555             expr[i] = vsapi->propGetData(in, "expr", i, nullptr);
3556         }
3557         for (int i = nexpr; i < 3; ++i) {
3558             expr[i] = expr[nexpr - 1];
3559         }
3560 
3561         for (int i = 0; i < d->vi.format->numPlanes; i++) {
3562             if (!expr[i].empty()) {
3563                 d->plane[i] = poProcess;
3564             } else {
3565                 if (d->vi.format->bitsPerSample == vi[0]->format->bitsPerSample && d->vi.format->sampleType == vi[0]->format->sampleType)
3566                     d->plane[i] = poCopy;
3567                 else
3568                     d->plane[i] = poUndefined;
3569             }
3570 
3571             if (d->plane[i] != poProcess)
3572                 continue;
3573 
3574             auto tree = parseExpr(expr[i], vi, d->numInputs);
3575             d->bytecode[i] = compile(tree, d->vi.format);
3576 
3577             int cpulevel = vs_get_cpulevel(core);
3578             if (cpulevel > VS_CPU_LEVEL_NONE) {
3579 #ifdef VS_TARGET_CPU_X86
3580                 std::unique_ptr<ExprCompiler> compiler = make_compiler(d->numInputs, cpulevel);
3581                 for (auto op : d->bytecode[i]) {
3582                     compiler->addInstruction(op);
3583                 }
3584 
3585                 std::tie(d->proc[i], d->procSize[i]) = compiler->getCode();
3586 #endif
3587             }
3588         }
3589 #ifdef VS_TARGET_OS_WINDOWS
3590         FlushInstructionCache(GetCurrentProcess(), nullptr, 0);
3591 #endif
3592     } catch (std::runtime_error &e) {
3593         for (int i = 0; i < MAX_EXPR_INPUTS; i++) {
3594             vsapi->freeNode(d->node[i]);
3595         }
3596         vsapi->setError(out, (std::string{ "Expr: " } + e.what()).c_str());
3597         return;
3598     }
3599 
3600     vsapi->createFilter(in, out, "Expr", exprInit, exprGetFrame, exprFree, fmParallel, 0, d.release(), core);
3601 }
3602 
3603 } // namespace
3604 
3605 
3606 //////////////////////////////////////////
3607 // Init
3608 
exprInitialize(VSConfigPlugin configFunc,VSRegisterFunction registerFunc,VSPlugin * plugin)3609 void VS_CC exprInitialize(VSConfigPlugin configFunc, VSRegisterFunction registerFunc, VSPlugin *plugin) {
3610     //configFunc("com.vapoursynth.expr", "expr", "VapourSynth Expr Filter", VAPOURSYNTH_API_VERSION, 1, plugin);
3611     registerFunc("Expr", "clips:clip[];expr:data[];format:int:opt;", exprCreate, nullptr, plugin);
3612 }
3613