1 /*
2 * Copyright (c) 2012-2019 Fredrik Mellbin
3 *
4 * This file is part of VapourSynth.
5 *
6 * VapourSynth is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * VapourSynth is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with VapourSynth; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21 #include <algorithm>
22 #include <cmath>
23 #include <functional>
24 #include <iostream>
25 #include <limits>
26 #include <locale>
27 #include <map>
28 #include <memory>
29 #include <set>
30 #include <sstream>
31 #include <stdexcept>
32 #include <string>
33 #include <tuple>
34 #include <unordered_map>
35 #include <unordered_set>
36 #include <vector>
37 #include "VapourSynth.h"
38 #include "VSHelper.h"
39 #include "cpufeatures.h"
40 #include "internalfilters.h"
41 #include "vslog.h"
42 #include "kernel/cpulevel.h"
43
44 #ifdef VS_TARGET_CPU_X86
45 #include <immintrin.h>
46 #define NOMINMAX
47 #include "jitasm.h"
48 #ifndef VS_TARGET_OS_WINDOWS
49 #include <sys/mman.h>
50 #endif
51 #endif
52
53 namespace {
54
55 #define MAX_EXPR_INPUTS 26
56
// Opcodes of the expression bytecode.
enum class ExprOpType {
    // Terminals.
    MEM_LOAD_U8,
    MEM_LOAD_U16,
    MEM_LOAD_F16,
    MEM_LOAD_F32,
    CONSTANT,
    MEM_STORE_U8,
    MEM_STORE_U16,
    MEM_STORE_F16,
    MEM_STORE_F32,

    // Arithmetic primitives.
    ADD,
    SUB,
    MUL,
    DIV,
    FMA,
    SQRT,
    ABS,
    NEG,
    MAX,
    MIN,
    CMP,

    // Logical operators.
    AND,
    OR,
    XOR,
    NOT,

    // Transcendental functions.
    EXP,
    LOG,
    POW,
    SIN,
    COS,

    // Ternary operator.
    TERNARY,

    // Meta-node holding true/false branches of ternary.
    MUX,

    // Stack helpers.
    DUP,
    SWAP,
};
80
// Flavour of a fused multiply-add. In the formulas a, b and c stand for the
// first, second and third source operand of the instruction.
enum class FMAType {
    // (b * c) + a
    FMADD = 0,
    // (b * c) - a
    FMSUB = 1,
    // -(b * c) + a
    FNMADD = 2,
    // -(b * c) - a
    FNMSUB = 3,
};
87
// Comparison predicates. The numeric values are significant: on x86 they are
// checked against the corresponding _CMP_* encodings right after this enum.
enum class ComparisonType {
    EQ  = 0,
    LT  = 1,
    LE  = 2,
    // 3 is intentionally unused.
    NEQ = 4,
    NLT = 5,
    NLE = 6,
};
96
#ifdef VS_TARGET_CPU_X86
// The cmp emitter hands insn.op.imm.u to cmpps as the predicate immediate, so
// each ComparisonType value has to equal the matching _CMP_* encoding
// (presumably the enum value is what gets stored in that immediate).
static_assert(_CMP_EQ_OQ == static_cast<int>(ComparisonType::EQ), "");
static_assert(_CMP_LT_OS == static_cast<int>(ComparisonType::LT), "");
static_assert(_CMP_LE_OS == static_cast<int>(ComparisonType::LE), "");
static_assert(_CMP_NEQ_UQ == static_cast<int>(ComparisonType::NEQ), "");
static_assert(_CMP_NLT_US == static_cast<int>(ComparisonType::NLT), "");
static_assert(_CMP_NLE_US == static_cast<int>(ComparisonType::NLE), "");
#endif
105
// 32-bit immediate that can be viewed as a signed integer, an unsigned integer
// or a float.
union ExprUnion {
    int32_t i;
    uint32_t u;
    float f;

    // Default: all bits zero, set through the unsigned view.
    constexpr ExprUnion() : u(0) {}

    // One converting constructor per view; the argument type picks the member.
    constexpr ExprUnion(int32_t value) : i(value) {}
    constexpr ExprUnion(uint32_t value) : u(value) {}
    constexpr ExprUnion(float value) : f(value) {}
};
117
118 struct ExprOp {
119 ExprOpType type;
120 ExprUnion imm;
121
ExprOp__anon2a0524d50111::ExprOp122 ExprOp(ExprOpType type, ExprUnion param = {}) : type(type), imm(param) {}
123 };
124
operator ==(const ExprOp & lhs,const ExprOp & rhs)125 bool operator==(const ExprOp &lhs, const ExprOp &rhs) { return lhs.type == rhs.type && lhs.imm.u == rhs.imm.u; }
operator !=(const ExprOp & lhs,const ExprOp & rhs)126 bool operator!=(const ExprOp &lhs, const ExprOp &rhs) { return !(lhs == rhs); }
127
128 struct ExprInstruction {
129 ExprOp op;
130 int dst;
131 int src1;
132 int src2;
133 int src3;
134
ExprInstruction__anon2a0524d50111::ExprInstruction135 ExprInstruction(ExprOp op) : op(op), dst(-1), src1(-1), src2(-1), src3(-1) {}
136 };
137
// What to do with an output plane.
enum PlaneOp {
    poProcess,
    poCopy,
    poUndefined
};
141
142 struct ExprData {
143 VSNodeRef *node[MAX_EXPR_INPUTS];
144 VSVideoInfo vi;
145 std::vector<ExprInstruction> bytecode[3];
146 int plane[3];
147 int numInputs;
148 typedef void (*ProcessLineProc)(void *rwptrs, intptr_t ptroff[MAX_EXPR_INPUTS + 1], intptr_t niter);
149 ProcessLineProc proc[3];
150 size_t procSize[3];
151
ExprData__anon2a0524d50111::ExprData152 ExprData() : node(), vi(), plane(), numInputs(), proc() {}
153
~ExprData__anon2a0524d50111::ExprData154 ~ExprData() {
155 #ifdef VS_TARGET_CPU_X86
156 for (int i = 0; i < 3; i++) {
157 if (proc[i]) {
158 #ifdef VS_TARGET_OS_WINDOWS
159 VirtualFree((LPVOID)proc[i], 0, MEM_RELEASE);
160 #else
161 munmap((void *)proc[i], procSize[i]);
162 #endif
163 }
164 }
165 #endif
166 }
167 };
168
169 #ifdef VS_TARGET_CPU_X86
170 class ExprCompiler {
171 virtual void load8(const ExprInstruction &insn) = 0;
172 virtual void load16(const ExprInstruction &insn) = 0;
173 virtual void loadF16(const ExprInstruction &insn) = 0;
174 virtual void loadF32(const ExprInstruction &insn) = 0;
175 virtual void loadConst(const ExprInstruction &insn) = 0;
176 virtual void store8(const ExprInstruction &insn) = 0;
177 virtual void store16(const ExprInstruction &insn) = 0;
178 virtual void storeF16(const ExprInstruction &insn) = 0;
179 virtual void storeF32(const ExprInstruction &insn) = 0;
180 virtual void add(const ExprInstruction &insn) = 0;
181 virtual void sub(const ExprInstruction &insn) = 0;
182 virtual void mul(const ExprInstruction &insn) = 0;
183 virtual void div(const ExprInstruction &insn) = 0;
184 virtual void fma(const ExprInstruction &insn) = 0;
185 virtual void max(const ExprInstruction &insn) = 0;
186 virtual void min(const ExprInstruction &insn) = 0;
187 virtual void sqrt(const ExprInstruction &insn) = 0;
188 virtual void abs(const ExprInstruction &insn) = 0;
189 virtual void neg(const ExprInstruction &insn) = 0;
190 virtual void not_(const ExprInstruction &insn) = 0;
191 virtual void and_(const ExprInstruction &insn) = 0;
192 virtual void or_(const ExprInstruction &insn) = 0;
193 virtual void xor_(const ExprInstruction &insn) = 0;
194 virtual void cmp(const ExprInstruction &insn) = 0;
195 virtual void ternary(const ExprInstruction &insn) = 0;
196 virtual void exp(const ExprInstruction &insn) = 0;
197 virtual void log(const ExprInstruction &insn) = 0;
198 virtual void pow(const ExprInstruction &insn) = 0;
199 virtual void sin(const ExprInstruction &insn) = 0;
200 virtual void cos(const ExprInstruction &insn) = 0;
201 public:
addInstruction(const ExprInstruction & insn)202 void addInstruction(const ExprInstruction &insn)
203 {
204 switch (insn.op.type) {
205 case ExprOpType::MEM_LOAD_U8: load8(insn); break;
206 case ExprOpType::MEM_LOAD_U16: load16(insn); break;
207 case ExprOpType::MEM_LOAD_F16: loadF16(insn); break;
208 case ExprOpType::MEM_LOAD_F32: loadF32(insn); break;
209 case ExprOpType::CONSTANT: loadConst(insn); break;
210 case ExprOpType::MEM_STORE_U8: store8(insn); break;
211 case ExprOpType::MEM_STORE_U16: store16(insn); break;
212 case ExprOpType::MEM_STORE_F16: storeF16(insn); break;
213 case ExprOpType::MEM_STORE_F32: storeF32(insn); break;
214 case ExprOpType::ADD: add(insn); break;
215 case ExprOpType::SUB: sub(insn); break;
216 case ExprOpType::MUL: mul(insn); break;
217 case ExprOpType::DIV: div(insn); break;
218 case ExprOpType::FMA: fma(insn); break;
219 case ExprOpType::MAX: max(insn); break;
220 case ExprOpType::MIN: min(insn); break;
221 case ExprOpType::SQRT: sqrt(insn); break;
222 case ExprOpType::ABS: abs(insn); break;
223 case ExprOpType::NEG: neg(insn); break;
224 case ExprOpType::NOT: not_(insn); break;
225 case ExprOpType::AND: and_(insn); break;
226 case ExprOpType::OR: or_(insn); break;
227 case ExprOpType::XOR: xor_(insn); break;
228 case ExprOpType::CMP: cmp(insn); break;
229 case ExprOpType::TERNARY: ternary(insn); break;
230 case ExprOpType::EXP: exp(insn); break;
231 case ExprOpType::LOG: log(insn); break;
232 case ExprOpType::POW: pow(insn); break;
233 case ExprOpType::SIN: sin(insn); break;
234 case ExprOpType::COS: cos(insn); break;
235 default: vsFatal("illegal opcode"); break;
236 }
237 }
238
~ExprCompiler()239 virtual ~ExprCompiler() {}
240 virtual std::pair<ExprData::ProcessLineProc, size_t> getCode() = 0;
241 };
242
243 class ExprCompiler128 : public ExprCompiler, private jitasm::function<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t> {
244 typedef jitasm::function<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t> jit;
245 friend struct jitasm::function<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t>;
246 friend struct jitasm::function_cdecl<void, ExprCompiler128, uint8_t *, const intptr_t *, intptr_t>;
247
248 #define SPLAT(x) { (x), (x), (x), (x) }
249 static constexpr ExprUnion constData alignas(16)[53][4] = {
250 SPLAT(0x7FFFFFFF), // absmask
251 SPLAT(0x80000000), // negmask
252 SPLAT(0x7F), // x7F
253 SPLAT(0x00800000), // min_norm_pos
254 SPLAT(~0x7F800000), // inv_mant_mask
255 SPLAT(1.0f), // float_one
256 SPLAT(0.5f), // float_half
257 SPLAT(255.0f), // float_255
258 SPLAT(511.0f), // float_511
259 SPLAT(1023.0f), // float_1023
260 SPLAT(2047.0f), // float_2047
261 SPLAT(4095.0f), // float_4095
262 SPLAT(8191.0f), // float_8191
263 SPLAT(16383.0f), // float_16383
264 SPLAT(32767.0f), // float_32767
265 SPLAT(65535.0f), // float_65535
266 SPLAT(static_cast<int32_t>(0x80008000)), // i16min_epi16
267 SPLAT(static_cast<int32_t>(0xFFFF8000)), // i16min_epi32
268 SPLAT(88.3762626647949f), // exp_hi
269 SPLAT(-88.3762626647949f), // exp_lo
270 SPLAT(1.44269504088896341f), // log2e
271 SPLAT(0.693359375f), // exp_c1
272 SPLAT(-2.12194440e-4f), // exp_c2
273 SPLAT(1.9875691500E-4f), // exp_p0
274 SPLAT(1.3981999507E-3f), // exp_p1
275 SPLAT(8.3334519073E-3f), // exp_p2
276 SPLAT(4.1665795894E-2f), // exp_p3
277 SPLAT(1.6666665459E-1f), // exp_p4
278 SPLAT(5.0000001201E-1f), // exp_p5
279 SPLAT(0.707106781186547524f), // sqrt_1_2
280 SPLAT(7.0376836292E-2f), // log_p0
281 SPLAT(-1.1514610310E-1f), // log_p1
282 SPLAT(1.1676998740E-1f), // log_p2
283 SPLAT(-1.2420140846E-1f), // log_p3
284 SPLAT(+1.4249322787E-1f), // log_p4
285 SPLAT(-1.6668057665E-1f), // log_p5
286 SPLAT(+2.0000714765E-1f), // log_p6
287 SPLAT(-2.4999993993E-1f), // log_p7
288 SPLAT(+3.3333331174E-1f), // log_p8
289 SPLAT(0x3ea2f983), // float_invpi, 1/pi
290 SPLAT(0x4b400000), // float_rintf
291 SPLAT(0x40490000), // float_pi1
292 SPLAT(0x3a7da000), // float_pi2
293 SPLAT(0x34222000), // float_pi3
294 SPLAT(0x2cb4611a), // float_pi4
295 SPLAT(0xbe2aaaa6), // float_sinC3
296 SPLAT(0x3c08876a), // float_sinC5
297 SPLAT(0xb94fb7ff), // float_sinC7
298 SPLAT(0x362edef8), // float_sinC9
299 SPLAT(static_cast<int32_t>(0xBEFFFFE2)), // float_cosC2
300 SPLAT(0x3D2AA73C), // float_cosC4
301 SPLAT(static_cast<int32_t>(0XBAB58D50)), // float_cosC6
302 SPLAT(0x37C1AD76), // float_cosC8
303 };
304
// Row numbers of the constant table, one name per row.
struct ConstantIndex {
    // Bit masks and integer helpers.
    static constexpr int absmask = 0;
    static constexpr int negmask = 1;
    static constexpr int x7F = 2;
    static constexpr int min_norm_pos = 3;
    static constexpr int inv_mant_mask = 4;

    // Float constants.
    static constexpr int float_one = 5;
    static constexpr int float_half = 6;

    // Store limits. These must stay consecutive: store16 addresses them as
    // float_255 + depth - 8.
    static constexpr int float_255 = 7;
    static constexpr int float_511 = 8;
    static constexpr int float_1023 = 9;
    static constexpr int float_2047 = 10;
    static constexpr int float_4095 = 11;
    static constexpr int float_8191 = 12;
    static constexpr int float_16383 = 13;
    static constexpr int float_32767 = 14;
    static constexpr int float_65535 = 15;

    static constexpr int i16min_epi16 = 16;
    static constexpr int i16min_epi32 = 17;

    // exp_() constants; exp_p0..exp_p5 are consecutive.
    static constexpr int exp_hi = 18;
    static constexpr int exp_lo = 19;
    static constexpr int log2e = 20;
    static constexpr int exp_c1 = 21;
    static constexpr int exp_c2 = 22;
    static constexpr int exp_p0 = 23;
    static constexpr int exp_p1 = 24;
    static constexpr int exp_p2 = 25;
    static constexpr int exp_p3 = 26;
    static constexpr int exp_p4 = 27;
    static constexpr int exp_p5 = 28;

    // log_() constants; log_p0..log_p8 are consecutive.
    static constexpr int sqrt_1_2 = 29;
    static constexpr int log_p0 = 30;
    static constexpr int log_p1 = 31;
    static constexpr int log_p2 = 32;
    static constexpr int log_p3 = 33;
    static constexpr int log_p4 = 34;
    static constexpr int log_p5 = 35;
    static constexpr int log_p6 = 36;
    static constexpr int log_p7 = 37;
    static constexpr int log_p8 = 38;
    // log_q1 / log_q2 share their rows with the exp constants.
    static constexpr int log_q1 = exp_c2;
    static constexpr int log_q2 = exp_c1;

    // sin/cos constants.
    static constexpr int float_invpi = 39;
    static constexpr int float_rintf = 40;
    static constexpr int float_pi1 = 41;
    static constexpr int float_pi2 = 42;
    static constexpr int float_pi3 = 43;
    static constexpr int float_pi4 = 44;
    static constexpr int float_sinC3 = 45;
    static constexpr int float_sinC5 = 46;
    static constexpr int float_sinC7 = 47;
    static constexpr int float_sinC9 = 48;
    static constexpr int float_cosC2 = 49;
    static constexpr int float_cosC4 = 50;
    static constexpr int float_cosC6 = 51;
    static constexpr int float_cosC8 = 52;
};
#undef SPLAT
363
364 // JitASM compiles everything from main(), so record the operations for later.
365 std::vector<std::function<void(Reg, XmmReg, Reg, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &)>> deferred;
366
367 CPUFeatures cpuFeatures;
368 int numInputs;
369 int curLabel;
370
// Opens a deferred-emission lambda. It captures the compiler and a copy of the
// instruction; its parameters are the pointer-table register, the zero vector,
// the constant-table base register and the slot -> register-pair map.
#define EMIT() [this, insn](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)

// Two-operand instruction: the v-prefixed form when AVX is available, the
// legacy form otherwise.
#define VEX1(op, arg1, arg2) \
    do { \
        if (cpuFeatures.avx) { \
            v##op(arg1, arg2); \
        } else { \
            op(arg1, arg2); \
        } \
    } while (0)

// Destination, source and immediate. Without AVX the source is first copied
// into the destination unless both are the same register.
#define VEX1IMM(op, arg1, arg2, imm) \
    do { \
        if (cpuFeatures.avx) { \
            v##op(arg1, arg2, imm); \
        } else if (arg1 == arg2) { \
            op(arg2, imm); \
        } else { \
            movdqa(arg1, arg2); \
            op(arg1, imm); \
        } \
    } while (0)

// Three-operand form arg1 = arg2 <op> arg3. Without AVX the operation is
// destructive, so arg2 is copied into arg1 first; when arg1 is also arg3 the
// work is done in a scratch register so arg3 is not clobbered by that copy.
#define VEX2(op, arg1, arg2, arg3) \
    do { \
        if (cpuFeatures.avx) { \
            v##op(arg1, arg2, arg3); \
        } else if (arg1 == arg2) { \
            op(arg2, arg3); \
        } else if (arg1 != arg3) { \
            movdqa(arg1, arg2); \
            op(arg1, arg3); \
        } else { \
            XmmReg tmp; \
            movdqa(tmp, arg2); \
            op(tmp, arg3); \
            movdqa(arg1, tmp); \
        } \
    } while (0)

// Same as VEX2 with a trailing immediate operand.
#define VEX2IMM(op, arg1, arg2, arg3, imm) \
    do { \
        if (cpuFeatures.avx) { \
            v##op(arg1, arg2, arg3, imm); \
        } else if (arg1 == arg2) { \
            op(arg2, arg3, imm); \
        } else if (arg1 != arg3) { \
            movdqa(arg1, arg2); \
            op(arg1, arg3, imm); \
        } else { \
            XmmReg tmp; \
            movdqa(tmp, arg2); \
            op(tmp, arg3, imm); \
            movdqa(arg1, tmp); \
        } \
    } while (0)
422
load8(const ExprInstruction & insn)423 void load8(const ExprInstruction &insn) override
424 {
425 deferred.push_back(EMIT()
426 {
427 auto t1 = bytecodeRegs[insn.dst];
428 Reg a;
429 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
430 VEX1(movq, t1.first, mmword_ptr[a]);
431 VEX2(punpcklbw, t1.first, t1.first, zero);
432 VEX2(punpckhwd, t1.second, t1.first, zero);
433 VEX2(punpcklwd, t1.first, t1.first, zero);
434 VEX1(cvtdq2ps, t1.first, t1.first);
435 VEX1(cvtdq2ps, t1.second, t1.second);
436 });
437 }
438
load16(const ExprInstruction & insn)439 void load16(const ExprInstruction &insn) override
440 {
441 deferred.push_back(EMIT()
442 {
443 auto t1 = bytecodeRegs[insn.dst];
444 Reg a;
445 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
446 VEX1(movdqa, t1.first, xmmword_ptr[a]);
447 VEX2(punpckhwd, t1.second, t1.first, zero);
448 VEX2(punpcklwd, t1.first, t1.first, zero);
449 VEX1(cvtdq2ps, t1.first, t1.first);
450 VEX1(cvtdq2ps, t1.second, t1.second);
451 });
452 }
453
loadF16(const ExprInstruction & insn)454 void loadF16(const ExprInstruction &insn) override
455 {
456 deferred.push_back(EMIT()
457 {
458 auto t1 = bytecodeRegs[insn.dst];
459 Reg a;
460 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
461 vcvtph2ps(t1.first, qword_ptr[a]);
462 vcvtph2ps(t1.second, qword_ptr[a + 8]);
463 });
464 }
465
loadF32(const ExprInstruction & insn)466 void loadF32(const ExprInstruction &insn) override
467 {
468 deferred.push_back(EMIT()
469 {
470 auto t1 = bytecodeRegs[insn.dst];
471 Reg a;
472 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
473 VEX1(movdqa, t1.first, xmmword_ptr[a]);
474 VEX1(movdqa, t1.second, xmmword_ptr[a + 16]);
475 });
476 }
477
loadConst(const ExprInstruction & insn)478 void loadConst(const ExprInstruction &insn) override
479 {
480 deferred.push_back(EMIT()
481 {
482 auto t1 = bytecodeRegs[insn.dst];
483
484 if (insn.op.imm.f == 0.0f) {
485 VEX1(movaps, t1.first, zero);
486 VEX1(movaps, t1.second, zero);
487 return;
488 }
489
490 Reg32 a;
491 mov(a, insn.op.imm.u);
492 VEX1(movd, t1.first, a);
493 VEX2IMM(shufps, t1.first, t1.first, t1.first, 0);
494 VEX1(movaps, t1.second, t1.first);
495 });
496 }
497
store8(const ExprInstruction & insn)498 void store8(const ExprInstruction &insn) override
499 {
500 deferred.push_back(EMIT()
501 {
502 auto t1 = bytecodeRegs[insn.src1];
503 XmmReg r1, r2, limit;
504 Reg a;
505 VEX1(movaps, limit, xmmword_ptr[constants + ConstantIndex::float_255 * 16]);
506 VEX2(minps, r1, t1.first, limit);
507 VEX2(minps, r2, t1.second, limit);
508 VEX1(cvtps2dq, r1, r1);
509 VEX1(cvtps2dq, r2, r2);
510 VEX2(packssdw, r1, r1, r2);
511 VEX2(packuswb, r1, r1, zero);
512 mov(a, ptr[regptrs]);
513 VEX1(movq, mmword_ptr[a], r1);
514 });
515 }
516
store16(const ExprInstruction & insn)517 void store16(const ExprInstruction &insn) override
518 {
519 deferred.push_back(EMIT()
520 {
521 int depth = insn.op.imm.u;
522 auto t1 = bytecodeRegs[insn.src1];
523 XmmReg r1, r2, limit;
524 Reg a;
525 VEX1(movaps, limit, xmmword_ptr[constants + (ConstantIndex::float_255 + depth - 8) * 16]);
526 VEX2IMM(shufps, limit, limit, limit, 0);
527 VEX2(minps, r1, t1.first, limit);
528 VEX2(minps, r2, t1.second, limit);
529 VEX1(cvtps2dq, r1, r1);
530 VEX1(cvtps2dq, r2, r2);
531
532 if (cpuFeatures.sse4_1) {
533 VEX2(packusdw, r1, r1, r2);
534 } else {
535 if (depth >= 16) {
536 VEX1(movaps, limit, xmmword_ptr[constants + ConstantIndex::i16min_epi32 * 16]);
537 VEX2(paddd, r1, r1, limit);
538 VEX2(paddd, r2, r2, limit);
539 }
540 VEX2(packssdw, r1, r1, r2);
541 if (depth >= 16)
542 VEX2(psubw, r1, r1, xmmword_ptr[constants + ConstantIndex::i16min_epi16 * 16]);
543 }
544 mov(a, ptr[regptrs]);
545 VEX1(movaps, xmmword_ptr[a], r1);
546 });
547 }
548
storeF16(const ExprInstruction & insn)549 void storeF16(const ExprInstruction &insn) override
550 {
551 deferred.push_back(EMIT()
552 {
553 auto t1 = bytecodeRegs[insn.src1];
554
555 Reg a;
556 mov(a, ptr[regptrs]);
557 vcvtps2ph(qword_ptr[a], t1.first, 0);
558 vcvtps2ph(qword_ptr[a + 8], t1.second, 0);
559 });
560 }
561
storeF32(const ExprInstruction & insn)562 void storeF32(const ExprInstruction &insn) override
563 {
564 deferred.push_back(EMIT()
565 {
566 auto t1 = bytecodeRegs[insn.src1];
567
568 Reg a;
569 mov(a, ptr[regptrs]);
570 VEX1(movaps, xmmword_ptr[a], t1.first);
571 VEX1(movaps, xmmword_ptr[a + 16], t1.second);
572 });
573 }
574
// Emits dst = src1 <op> src2 for both halves of the register pairs.
#define BINARYOP(op) \
    do { \
        auto lhs = bytecodeRegs[insn.src1]; \
        auto rhs = bytecodeRegs[insn.src2]; \
        auto res = bytecodeRegs[insn.dst]; \
        VEX2(op, res.first, lhs.first, rhs.first); \
        VEX2(op, res.second, lhs.second, rhs.second); \
    } while (0)
add(const ExprInstruction & insn)583 void add(const ExprInstruction &insn) override
584 {
585 deferred.push_back(EMIT()
586 {
587 BINARYOP(addps);
588 });
589 }
590
sub(const ExprInstruction & insn)591 void sub(const ExprInstruction &insn) override
592 {
593 deferred.push_back(EMIT()
594 {
595 BINARYOP(subps);
596 });
597 }
598
mul(const ExprInstruction & insn)599 void mul(const ExprInstruction &insn) override
600 {
601 deferred.push_back(EMIT()
602 {
603 BINARYOP(mulps);
604 });
605 }
606
div(const ExprInstruction & insn)607 void div(const ExprInstruction &insn) override
608 {
609 deferred.push_back(EMIT()
610 {
611 BINARYOP(divps);
612 });
613 }
614
fma(const ExprInstruction & insn)615 void fma(const ExprInstruction &insn) override
616 {
617 deferred.push_back(EMIT()
618 {
619 FMAType type = static_cast<FMAType>(insn.op.imm.u);
620
621 // t1 + t2 * t3
622 auto t1 = bytecodeRegs[insn.src1];
623 auto t2 = bytecodeRegs[insn.src2];
624 auto t3 = bytecodeRegs[insn.src3];
625 auto t4 = bytecodeRegs[insn.dst];
626
627 if (cpuFeatures.fma3) {
628 #define FMA3(op) \
629 do { \
630 if (insn.dst == insn.src1) { \
631 v##op##231ps(t1.first, t2.first, t3.first); \
632 v##op##231ps(t1.second, t2.second, t3.second); \
633 } else if (insn.dst == insn.src2) { \
634 v##op##132ps(t2.first, t1.first, t3.first); \
635 v##op##132ps(t2.second, t1.second, t3.second); \
636 } else if (insn.dst == insn.src3) { \
637 v##op##132ps(t3.first, t1.first, t2.first); \
638 v##op##132ps(t3.second, t1.second, t2.second); \
639 } else { \
640 vmovaps(t4.first, t1.first); \
641 vmovaps(t4.second, t1.second); \
642 v##op##231ps(t4.first, t2.first, t3.first); \
643 v##op##231ps(t4.second, t2.second, t3.second); \
644 } \
645 } while (0)
646 switch (type) {
647 case FMAType::FMADD: FMA3(fmadd); break;
648 case FMAType::FMSUB: FMA3(fmsub); break;
649 case FMAType::FNMADD: FMA3(fnmadd); break;
650 case FMAType::FNMSUB: FMA3(fnmsub); break;
651 }
652 #undef FMA3
653 } else {
654 XmmReg r1, r2;
655 VEX2(mulps, r1, t2.first, t3.first);
656 VEX2(mulps, r2, t2.second, t3.second);
657
658 if (type == FMAType::FMADD || type == FMAType::FNMSUB) {
659 VEX2(addps, t4.first, r1, t1.first);
660 VEX2(addps, t4.second, r2, t1.second);
661 } else if (type == FMAType::FMSUB) {
662 VEX2(subps, t4.first, r1, t1.first);
663 VEX2(subps, t4.second, r2, t1.second);
664 } else if (type == FMAType::FNMADD) {
665 VEX2(subps, t4.first, t1.first, r1);
666 VEX2(subps, t4.second, t1.second, r2);
667 }
668
669 if (type == FMAType::FNMSUB) {
670 VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::negmask * 16]);
671 VEX2(xorps, t4.first, t4.first, r1);
672 VEX2(xorps, t4.second, t4.second, r2);
673 }
674 }
675 });
676 }
677
max(const ExprInstruction & insn)678 void max(const ExprInstruction &insn) override
679 {
680 deferred.push_back(EMIT()
681 {
682 BINARYOP(maxps);
683 });
684 }
685
min(const ExprInstruction & insn)686 void min(const ExprInstruction &insn) override
687 {
688 deferred.push_back(EMIT()
689 {
690 BINARYOP(minps);
691 });
692 }
693 #undef BINARYOP
694
sqrt(const ExprInstruction & insn)695 void sqrt(const ExprInstruction &insn) override
696 {
697 deferred.push_back(EMIT()
698 {
699 auto t1 = bytecodeRegs[insn.src1];
700 auto t2 = bytecodeRegs[insn.dst];
701 VEX2(maxps, t2.first, t1.first, zero);
702 VEX2(maxps, t2.second, t1.second, zero);
703 VEX1(sqrtps, t2.first, t2.first);
704 VEX1(sqrtps, t2.second, t2.second);
705 });
706 }
707
abs(const ExprInstruction & insn)708 void abs(const ExprInstruction &insn) override
709 {
710 deferred.push_back(EMIT()
711 {
712 auto t1 = bytecodeRegs[insn.src1];
713 auto t2 = bytecodeRegs[insn.dst];
714 XmmReg r1;
715 VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::absmask * 16]);
716 VEX2(andps, t2.first, t1.first, r1);
717 VEX2(andps, t2.second, t1.second, r1);
718 });
719 }
720
neg(const ExprInstruction & insn)721 void neg(const ExprInstruction &insn) override
722 {
723 deferred.push_back(EMIT()
724 {
725 auto t1 = bytecodeRegs[insn.src1];
726 auto t2 = bytecodeRegs[insn.dst];
727 XmmReg r1;
728 VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::negmask * 16]);
729 VEX2(xorps, t2.first, t1.first, r1);
730 VEX2(xorps, t2.second, t1.second, r1);
731 });
732 }
733
not_(const ExprInstruction & insn)734 void not_(const ExprInstruction &insn) override
735 {
736 deferred.push_back(EMIT()
737 {
738 auto t1 = bytecodeRegs[insn.src1];
739 auto t2 = bytecodeRegs[insn.dst];
740 XmmReg r1;
741 VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
742 VEX2IMM(cmpps, t2.first, t1.first, zero, _CMP_LE_OS);
743 VEX2IMM(cmpps, t2.second, t1.second, zero, _CMP_LE_OS);
744 VEX2(andps, t2.first, t2.first, r1);
745 VEX2(andps, t2.second, t2.second, r1);
746 });
747 }
748
// Emits dst = ((src1 > 0) <op> (src2 > 0)) ? 1.0f : 0.0f for both halves.
// Each operand is first turned into a lane mask via "not less-or-equal zero",
// the masks are combined with <op>, and the result is ANDed with 1.0f.
#define LOGICOP(op) \
    do { \
        auto lhs = bytecodeRegs[insn.src1]; \
        auto rhs = bytecodeRegs[insn.src2]; \
        auto res = bytecodeRegs[insn.dst]; \
        XmmReg ones, lhsMaskLo, lhsMaskHi; \
        VEX1(movaps, ones, xmmword_ptr[constants + ConstantIndex::float_one * 16]); \
        VEX2IMM(cmpps, lhsMaskLo, lhs.first, zero, _CMP_NLE_US); \
        VEX2IMM(cmpps, lhsMaskHi, lhs.second, zero, _CMP_NLE_US); \
        VEX2IMM(cmpps, res.first, rhs.first, zero, _CMP_NLE_US); \
        VEX2IMM(cmpps, res.second, rhs.second, zero, _CMP_NLE_US); \
        VEX2(op, res.first, res.first, lhsMaskLo); \
        VEX2(op, res.second, res.second, lhsMaskHi); \
        VEX2(andps, res.first, res.first, ones); \
        VEX2(andps, res.second, res.second, ones); \
    } while (0)

765
and_(const ExprInstruction & insn)766 void and_(const ExprInstruction &insn) override
767 {
768 deferred.push_back(EMIT()
769 {
770 LOGICOP(andps);
771 });
772 }
773
or_(const ExprInstruction & insn)774 void or_(const ExprInstruction &insn) override
775 {
776 deferred.push_back(EMIT()
777 {
778 LOGICOP(orps);
779 });
780 }
781
xor_(const ExprInstruction & insn)782 void xor_(const ExprInstruction &insn) override
783 {
784 deferred.push_back(EMIT()
785 {
786 LOGICOP(xorps);
787 });
788 }
789 #undef LOGICOP
790
cmp(const ExprInstruction & insn)791 void cmp(const ExprInstruction &insn) override
792 {
793 deferred.push_back(EMIT()
794 {
795 auto t1 = bytecodeRegs[insn.src1];
796 auto t2 = bytecodeRegs[insn.src2];
797 auto t3 = bytecodeRegs[insn.dst];
798 XmmReg r1;
799 VEX1(movaps, r1, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
800 VEX2IMM(cmpps, t3.first, t1.first, t2.first, insn.op.imm.u);
801 VEX2IMM(cmpps, t3.second, t1.second, t2.second, insn.op.imm.u);
802 VEX2(andps, t3.first, t3.first, r1);
803 VEX2(andps, t3.second, t3.second, r1);
804 });
805 }
806
ternary(const ExprInstruction & insn)807 void ternary(const ExprInstruction &insn) override
808 {
809 deferred.push_back(EMIT()
810 {
811 auto t1 = bytecodeRegs[insn.src1];
812 auto t2 = bytecodeRegs[insn.src2];
813 auto t3 = bytecodeRegs[insn.src3];
814 auto t4 = bytecodeRegs[insn.dst];
815
816 XmmReg r1, r2;
817 VEX2IMM(cmpps, r1, t1.first, zero, _CMP_NLE_US);
818 VEX2IMM(cmpps, r2, t1.second, zero, _CMP_NLE_US);
819
820 if (cpuFeatures.sse4_1) {
821 VEX2IMM(blendvps, t4.first, t3.first, t2.first, r1);
822 VEX2IMM(blendvps, t4.second, t3.second, t2.second, r2);
823 } else {
824 VEX2(andps, t4.first, t3.first, r1);
825 VEX2(andps, t4.second, t3.second, r2);
826 VEX2(andnps, r1, r1, t2.first);
827 VEX2(andnps, r2, r2, t2.second);
828 VEX2(orps, t4.first, t4.first, r1);
829 VEX2(orps, t4.second, t4.second, r2);
830 }
831 });
832 }
833
exp_(XmmReg x,XmmReg one,Reg constants)834 void exp_(XmmReg x, XmmReg one, Reg constants)
835 {
836 XmmReg fx, emm0, etmp, y, mask, z;
837 VEX2(minps, x, x, xmmword_ptr[constants + ConstantIndex::exp_hi * 16]);
838 VEX2(maxps, x, x, xmmword_ptr[constants + ConstantIndex::exp_lo * 16]);
839 VEX2(mulps, fx, x, xmmword_ptr[constants + ConstantIndex::log2e * 16]);
840 VEX2(addps, fx, fx, xmmword_ptr[constants + ConstantIndex::float_half * 16]);
841 VEX1(cvttps2dq, emm0, fx);
842 VEX1(cvtdq2ps, etmp, emm0);
843 VEX2IMM(cmpps, mask, etmp, fx, _CMP_NLE_US);
844 VEX2(andps, mask, mask, one);
845 VEX2(subps, fx, etmp, mask);
846 VEX2(mulps, etmp, fx, xmmword_ptr[constants + ConstantIndex::exp_c1 * 16]);
847 VEX2(mulps, z, fx, xmmword_ptr[constants + ConstantIndex::exp_c2 * 16]);
848 VEX2(subps, x, x, etmp);
849 VEX2(subps, x, x, z);
850 VEX2(mulps, z, x, x);
851 VEX2(mulps, y, x, xmmword_ptr[constants + ConstantIndex::exp_p0 * 16]);
852 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p1 * 16]);
853 VEX2(mulps, y, y, x);
854 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p2 * 16]);
855 VEX2(mulps, y, y, x);
856 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p3 * 16]);
857 VEX2(mulps, y, y, x);
858 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p4 * 16]);
859 VEX2(mulps, y, y, x);
860 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::exp_p5 * 16]);
861 VEX2(mulps, y, y, z);
862 VEX2(addps, y, y, x);
863 VEX2(addps, y, y, one);
864 VEX1(cvttps2dq, emm0, fx);
865 VEX2(paddd, emm0, emm0, xmmword_ptr[constants + ConstantIndex::x7F * 16]);
866 VEX1IMM(pslld, emm0, emm0, 23);
867 VEX2(mulps, x, y, emm0);
868 }
869
log_(XmmReg x,XmmReg zero,XmmReg one,Reg constants)870 void log_(XmmReg x, XmmReg zero, XmmReg one, Reg constants)
871 {
872 XmmReg emm0, invalid_mask, mask, y, etmp, z;
873 VEX2IMM(cmpps, invalid_mask, zero, x, _CMP_NLT_US);
874 VEX2(maxps, x, x, xmmword_ptr[constants + ConstantIndex::min_norm_pos * 16]);
875 VEX1IMM(psrld, emm0, x, 23);
876 VEX2(andps, x, x, xmmword_ptr[constants + ConstantIndex::inv_mant_mask * 16]);
877 VEX2(orps, x, x, xmmword_ptr[constants + ConstantIndex::float_half * 16]);
878 VEX2(psubd, emm0, emm0, xmmword_ptr[constants + ConstantIndex::x7F * 16]);
879 VEX1(cvtdq2ps, emm0, emm0);
880 VEX2(addps, emm0, emm0, one);
881 VEX2IMM(cmpps, mask, x, xmmword_ptr[constants + ConstantIndex::sqrt_1_2 * 16], _CMP_LT_OS);
882 VEX2(andps, etmp, x, mask);
883 VEX2(subps, x, x, one);
884 VEX2(andps, mask, mask, one);
885 VEX2(subps, emm0, emm0, mask);
886 VEX2(addps, x, x, etmp);
887 VEX2(mulps, z, x, x);
888 VEX2(mulps, y, x, xmmword_ptr[constants + ConstantIndex::log_p0 * 16]);
889 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p1 * 16]);
890 VEX2(mulps, y, y, x);
891 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p2 * 16]);
892 VEX2(mulps, y, y, x);
893 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p3 * 16]);
894 VEX2(mulps, y, y, x);
895 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p4 * 16]);
896 VEX2(mulps, y, y, x);
897 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p5 * 16]);
898 VEX2(mulps, y, y, x);
899 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p6 * 16]);
900 VEX2(mulps, y, y, x);
901 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p7 * 16]);
902 VEX2(mulps, y, y, x);
903 VEX2(addps, y, y, xmmword_ptr[constants + ConstantIndex::log_p8 * 16]);
904 VEX2(mulps, y, y, x);
905 VEX2(mulps, y, y, z);
906 VEX2(mulps, etmp, emm0, xmmword_ptr[constants + ConstantIndex::log_q1 * 16]);
907 VEX2(addps, y, y, etmp);
908 VEX2(mulps, z, z, xmmword_ptr[constants + ConstantIndex::float_half * 16]);
909 VEX2(subps, y, y, z);
910 VEX2(mulps, emm0, emm0, xmmword_ptr[constants + ConstantIndex::log_q2 * 16]);
911 VEX2(addps, x, x, y);
912 VEX2(addps, x, x, emm0);
913 VEX2(orps, x, x, invalid_mask);
914 }
915
exp(const ExprInstruction & insn)916 void exp(const ExprInstruction &insn) override
917 {
918 int l = curLabel++;
919
920 deferred.push_back([this, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
921 {
922 char label[] = "label-0000";
923 sprintf(label, "label-%04d", l);
924
925 auto t1 = bytecodeRegs[insn.src1];
926 auto t2 = bytecodeRegs[insn.dst];
927 XmmReg r1, r2, one;
928 Reg a;
929 mov(a, 2);
930 VEX1(movaps, r1, t1.first);
931 VEX1(movaps, r2, t1.second);
932 VEX1(movaps, one, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
933
934 L(label);
935
936 exp_(r1, one, constants);
937 VEX1(movaps, t2.first, t2.second);
938 VEX1(movaps, t2.second, r1);
939 VEX1(movaps, r1, r2);
940
941 jit::sub(a, 1);
942 jnz(label);
943 });
944 }
945
log(const ExprInstruction & insn)946 void log(const ExprInstruction &insn) override
947 {
948 int l = curLabel++;
949
950 deferred.push_back([this, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
951 {
952 char label[] = "label-0000";
953 sprintf(label, "label-%04d", l);
954
955 auto t1 = bytecodeRegs[insn.src1];
956 auto t2 = bytecodeRegs[insn.dst];
957 XmmReg r1, r2, one;
958 Reg a;
959 mov(a, 2);
960 VEX1(movaps, r1, t1.first);
961 VEX1(movaps, r2, t1.second);
962 VEX1(movaps, one, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
963
964 L(label);
965
966 log_(r1, zero, one, constants);
967 VEX1(movaps, t2.first, t2.second);
968 VEX1(movaps, t2.second, r1);
969 VEX1(movaps, r1, r2);
970
971 jit::sub(a, 1);
972 jnz(label);
973 });
974 }
975
pow(const ExprInstruction & insn)976 void pow(const ExprInstruction &insn) override
977 {
978 int l = curLabel++;
979
980 deferred.push_back([this, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
981 {
982 char label[] = "label-0000";
983 sprintf(label, "label-%04d", l);
984
985 auto t1 = bytecodeRegs[insn.src1];
986 auto t2 = bytecodeRegs[insn.src2];
987 auto t3 = bytecodeRegs[insn.dst];
988
989 XmmReg r1, r2, r3, r4, one;
990 Reg a;
991 mov(a, 2);
992 VEX1(movaps, r1, t1.first);
993 VEX1(movaps, r2, t1.second);
994 VEX1(movaps, r3, t2.first);
995 VEX1(movaps, r4, t2.second);
996 VEX1(movaps, one, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
997
998 L(label);
999
1000 log_(r1, zero, one, constants);
1001 VEX2(mulps, r1, r1, r3);
1002 exp_(r1, one, constants);
1003
1004 VEX1(movaps, t3.first, t3.second);
1005 VEX1(movaps, t3.second, r1);
1006 VEX1(movaps, r1, r2);
1007 VEX1(movaps, r3, r4);
1008
1009 jit::sub(a, 1);
1010 jnz(label);
1011 });
1012 }
1013
// Emits 128-bit code that writes sin(x) (issin == true) or cos(x) (issin == false) to y.
// x is only read before y is written (y is assigned once, by the final xorps), so the
// caller may pass the same register for both, as sincos() does.
// FMA instructions are used when cpuFeatures.fma3 is set; otherwise the same
// arithmetic is emitted as separate multiplies and adds/subtracts.
void sincos_(bool issin, XmmReg y, XmmReg x, Reg constants)
{
    XmmReg t1, sign, t2, t3, t4;
    // Remove sign
    VEX1(movaps, t1, xmmword_ptr[constants + ConstantIndex::absmask * 16]);
    if (issin) {
        // sign = ~absmask & x, i.e. only the sign bit of x.
        VEX1(movaps, sign, t1);
        VEX2(andnps, sign, sign, x);
    } else {
        // The cosine path starts with an all-zero sign mask.
        VEX2(pxor, sign, sign, sign);
    }
    // t1 = |x|
    VEX2(andps, t1, t1, x);
    // Range reduction
    // t2 = |x| * (1/pi) + float_rintf. Presumably float_rintf is the usual
    // round-to-integer bias, leaving the rounded quotient in the low mantissa
    // bits -- TODO confirm against the constant table.
    VEX1(movaps, t3, xmmword_ptr[constants + ConstantIndex::float_rintf * 16]);
    VEX2(mulps, t2, t1, xmmword_ptr[constants + ConstantIndex::float_invpi * 16]);
    VEX2(addps, t2, t2, t3);
    // Shift the lowest bit of each lane into the sign-bit position and fold it into sign.
    VEX1IMM(pslld, t4, t2, 31);
    VEX2(xorps, sign, sign, t4);
    // Remove the bias again, then subtract t2 * (pi1 + pi2 + pi3 + pi4) from |x| in four steps.
    VEX2(subps, t2, t2, t3);
    if (cpuFeatures.fma3) {
        vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi1 * 16]);
        vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi2 * 16]);
        vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi3 * 16]);
        vfnmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_pi4 * 16]);
    } else {
        VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi1 * 16]);
        VEX2(subps, t1, t1, t4);
        VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi2 * 16]);
        VEX2(subps, t1, t1, t4);
        VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi3 * 16]);
        VEX2(subps, t1, t1, t4);
        VEX2(mulps, t4, t2, xmmword_ptr[constants + ConstantIndex::float_pi4 * 16]);
        VEX2(subps, t1, t1, t4);
    }
    if (issin) {
        // Evaluate minimax polynomial for sin(x) in [-pi/2, pi/2] interval
        // Y <- X + X * X^2 * (C3 + X^2 * (C5 + X^2 * (C7 + X^2 * C9)))
        VEX2(mulps, t2, t1, t1); // t2 = X^2
        if (cpuFeatures.fma3) {
            vmovaps(t3, xmmword_ptr[constants + ConstantIndex::float_sinC7 * 16]);
            vfmadd231ps(t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC9 * 16]);
            vfmadd213ps(t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC5 * 16]);
            vfmadd213ps(t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC3 * 16]);
            VEX2(mulps, t3, t3, t2);
            vfmadd231ps(t1, t1, t3); // t1 = X * t3 + X
        } else {
            VEX2(mulps, t3, t2, xmmword_ptr[constants + ConstantIndex::float_sinC9 * 16]);
            VEX2(addps, t3, t3, xmmword_ptr[constants + ConstantIndex::float_sinC7 * 16]);
            VEX2(mulps, t3, t3, t2);
            VEX2(addps, t3, t3, xmmword_ptr[constants + ConstantIndex::float_sinC5 * 16]);
            VEX2(mulps, t3, t3, t2);
            VEX2(addps, t3, t3, xmmword_ptr[constants + ConstantIndex::float_sinC3 * 16]);
            VEX2(mulps, t3, t3, t2);
            VEX2(mulps, t3, t3, t1);
            VEX2(addps, t1, t1, t3);
        }
    } else {
        // Evaluate minimax polynomial for cos(x) in [-pi/2, pi/2] interval
        // Y <- 1 + X^2 * (C2 + X^2 * (C4 + X^2 * (C6 + X^2 * C8)))
        VEX2(mulps, t2, t1, t1); // t2 = X^2; t1 is reused as the accumulator below
        if (cpuFeatures.fma3) {
            vmovaps(t1, xmmword_ptr[constants + ConstantIndex::float_cosC6 * 16]);
            vfmadd231ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC8 * 16]);
            vfmadd213ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC4 * 16]);
            vfmadd213ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC2 * 16]);
            vfmadd213ps(t1, t2, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
        } else {
            VEX2(mulps, t1, t2, xmmword_ptr[constants + ConstantIndex::float_cosC8 * 16]);
            VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_cosC6 * 16]);
            VEX2(mulps, t1, t1, t2);
            VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_cosC4 * 16]);
            VEX2(mulps, t1, t1, t2);
            VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_cosC2 * 16]);
            VEX2(mulps, t1, t1, t2);
            VEX2(addps, t1, t1, xmmword_ptr[constants + ConstantIndex::float_one * 16]);
        }
    }
    // Apply sign
    VEX2(xorps, y, t1, sign);
}
1094
sincos(bool issin,const ExprInstruction & insn)1095 void sincos(bool issin, const ExprInstruction &insn)
1096 {
1097 int l = curLabel++;
1098
1099 deferred.push_back([this, issin, insn, l](Reg regptrs, XmmReg zero, Reg constants, std::unordered_map<int, std::pair<XmmReg, XmmReg>> &bytecodeRegs)
1100 {
1101 char label[] = "label-0000";
1102 sprintf(label, "label-%04d", l);
1103
1104 auto t1 = bytecodeRegs[insn.src1];
1105 auto t3 = bytecodeRegs[insn.dst];
1106
1107 XmmReg r1, r2;
1108 Reg a;
1109 mov(a, 2);
1110 VEX1(movaps, r1, t1.first);
1111 VEX1(movaps, r2, t1.second);
1112
1113 L(label);
1114
1115 sincos_(issin, r1, r1, constants);
1116 VEX1(movaps, t3.first, t3.second);
1117 VEX1(movaps, t3.second, r1);
1118 VEX1(movaps, r1, r2);
1119
1120 jit::sub(a, 1);
1121 jnz(label);
1122 });
1123 }
1124
sin(const ExprInstruction & insn)1125 void sin(const ExprInstruction &insn) override
1126 {
1127 sincos(true, insn);
1128 }
1129
cos(const ExprInstruction & insn)1130 void cos(const ExprInstruction &insn) override
1131 {
1132 sincos(false, insn);
1133 }
1134
main(Reg regptrs,Reg regoffs,Reg niter)1135 void main(Reg regptrs, Reg regoffs, Reg niter)
1136 {
1137 std::unordered_map<int, std::pair<XmmReg, XmmReg>> bytecodeRegs;
1138 XmmReg zero;
1139 VEX2(pxor, zero, zero, zero);
1140 Reg constants;
1141 mov(constants, (uintptr_t)constData);
1142
1143 L("wloop");
1144
1145 for (const auto &f : deferred) {
1146 f(regptrs, zero, constants, bytecodeRegs);
1147 }
1148
1149 #if UINTPTR_MAX > UINT32_MAX
1150 for (int i = 0; i < numInputs / 2 + 1; i++) {
1151 XmmReg r1, r2;
1152 VEX1(movdqu, r1, xmmword_ptr[regptrs + 16 * i]);
1153 VEX1(movdqu, r2, xmmword_ptr[regoffs + 16 * i]);
1154 VEX2(paddq, r1, r1, r2);
1155 VEX1(movdqu, xmmword_ptr[regptrs + 16 * i], r1);
1156 }
1157 #else
1158 for (int i = 0; i < numInputs / 4 + 1; i++) {
1159 XmmReg r1, r2;
1160 VEX1(movdqu, r1, xmmword_ptr[regptrs + 16 * i]);
1161 VEX1(movdqu, r2, xmmword_ptr[regoffs + 16 * i]);
1162 VEX2(paddd, r1, r1, r2);
1163 VEX1(movdqu, xmmword_ptr[regptrs + 16 * i], r1);
1164 }
1165 #endif
1166
1167 jit::sub(niter, 1);
1168 jnz("wloop");
1169 }
1170
1171 public:
ExprCompiler128(int numInputs)1172 explicit ExprCompiler128(int numInputs) : cpuFeatures(*getCPUFeatures()), numInputs(numInputs), curLabel() {}
1173
getCode()1174 std::pair<ExprData::ProcessLineProc, size_t> getCode() override
1175 {
1176 size_t size;
1177 if (jit::GetCode() && (size = GetCodeSize())) {
1178 #ifdef VS_TARGET_OS_WINDOWS
1179 void *ptr = VirtualAlloc(nullptr, size, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
1180 #else
1181 void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_ANON | MAP_PRIVATE, 0, 0);
1182 #endif
1183 memcpy(ptr, jit::GetCode(), size);
1184 return {reinterpret_cast<ExprData::ProcessLineProc>(ptr), size};
1185 }
1186 return {nullptr, 0};
1187 }
1188 #undef VEX2IMM
1189 #undef VEX2
1190 #undef VEX1IMM
1191 #undef VEX1
1192 #undef EMIT
1193 };
1194
// Namespace-scope definition of the ExprCompiler128::constData static member.
// main() takes the table's address, so (before C++17 inline variables) the member
// needs this definition in addition to its in-class declaration.
constexpr ExprUnion ExprCompiler128::constData alignas(16)[53][4];
1196
1197 class ExprCompiler256 : public ExprCompiler, private jitasm::function<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t> {
1198 typedef jitasm::function<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t> jit;
1199 friend struct jitasm::function<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t>;
1200 friend struct jitasm::function_cdecl<void, ExprCompiler256, uint8_t *, const intptr_t *, intptr_t>;
1201
// SPLAT replicates one value across the eight 32-bit lanes of a 256-bit row.
#define SPLAT(x) { (x), (x), (x), (x), (x), (x), (x), (x) }
// Constant table used by the generated code. Each of the 53 rows is 32 bytes and is
// addressed as constants + ConstantIndex::<name> * 32. Row order must stay in sync
// with ConstantIndex below.
static constexpr ExprUnion constData alignas(32)[53][8] = {
    SPLAT(0x7FFFFFFF), // absmask
    SPLAT(0x80000000), // negmask
    SPLAT(0x7F), // x7F
    SPLAT(0x00800000), // min_norm_pos
    SPLAT(~0x7F800000), // inv_mant_mask
    SPLAT(1.0f), // float_one
    SPLAT(0.5f), // float_half
    SPLAT(255.0f), // float_255
    SPLAT(511.0f), // float_511
    SPLAT(1023.0f), // float_1023
    SPLAT(2047.0f), // float_2047
    SPLAT(4095.0f), // float_4095
    SPLAT(8191.0f), // float_8191
    SPLAT(16383.0f), // float_16383
    SPLAT(32767.0f), // float_32767
    SPLAT(65535.0f), // float_65535
    SPLAT(static_cast<int32_t>(0x80008000)), // i16min_epi16
    SPLAT(static_cast<int32_t>(0xFFFF8000)), // i16min_epi32
    SPLAT(88.3762626647949f), // exp_hi
    SPLAT(-88.3762626647949f), // exp_lo
    SPLAT(1.44269504088896341f), // log2e
    SPLAT(0.693359375f), // exp_c1
    SPLAT(-2.12194440e-4f), // exp_c2
    SPLAT(1.9875691500E-4f), // exp_p0
    SPLAT(1.3981999507E-3f), // exp_p1
    SPLAT(8.3334519073E-3f), // exp_p2
    SPLAT(4.1665795894E-2f), // exp_p3
    SPLAT(1.6666665459E-1f), // exp_p4
    SPLAT(5.0000001201E-1f), // exp_p5
    SPLAT(0.707106781186547524f), // sqrt_1_2
    SPLAT(7.0376836292E-2f), // log_p0
    SPLAT(-1.1514610310E-1f), // log_p1
    SPLAT(1.1676998740E-1f), // log_p2
    SPLAT(-1.2420140846E-1f), // log_p3
    SPLAT(+1.4249322787E-1f), // log_p4
    SPLAT(-1.6668057665E-1f), // log_p5
    SPLAT(+2.0000714765E-1f), // log_p6
    SPLAT(-2.4999993993E-1f), // log_p7
    SPLAT(+3.3333331174E-1f), // log_p8
    SPLAT(0x3ea2f983), // float_invpi, 1/pi
    SPLAT(0x4b400000), // float_rintf
    SPLAT(0x40490000), // float_pi1
    SPLAT(0x3a7da000), // float_pi2
    SPLAT(0x34222000), // float_pi3
    SPLAT(0x2cb4611a), // float_pi4
    SPLAT(0xbe2aaaa6), // float_sinC3
    SPLAT(0x3c08876a), // float_sinC5
    SPLAT(0xb94fb7ff), // float_sinC7
    SPLAT(0x362edef8), // float_sinC9
    SPLAT(static_cast<int32_t>(0xBEFFFFE2)), // float_cosC2
    SPLAT(0x3D2AA73C), // float_cosC4
    SPLAT(static_cast<int32_t>(0XBAB58D50)), // float_cosC6
    SPLAT(0x37C1AD76), // float_cosC8
};

// Row indices into constData. The float_255 .. float_65535 rows are consecutive so
// that store16 can select the clamp limit as float_255 + depth - 8.
struct ConstantIndex {
    static constexpr int absmask = 0;
    static constexpr int negmask = 1;
    static constexpr int x7F = 2;
    static constexpr int min_norm_pos = 3;
    static constexpr int inv_mant_mask = 4;
    static constexpr int float_one = 5;
    static constexpr int float_half = 6;
    static constexpr int float_255 = 7;
    static constexpr int float_511 = 8;
    static constexpr int float_1023 = 9;
    static constexpr int float_2047 = 10;
    static constexpr int float_4095 = 11;
    static constexpr int float_8191 = 12;
    static constexpr int float_16383 = 13;
    static constexpr int float_32767 = 14;
    static constexpr int float_65535 = 15;
    static constexpr int i16min_epi16 = 16;
    static constexpr int i16min_epi32 = 17;
    static constexpr int exp_hi = 18;
    static constexpr int exp_lo = 19;
    static constexpr int log2e = 20;
    static constexpr int exp_c1 = 21;
    static constexpr int exp_c2 = 22;
    static constexpr int exp_p0 = 23;
    static constexpr int exp_p1 = 24;
    static constexpr int exp_p2 = 25;
    static constexpr int exp_p3 = 26;
    static constexpr int exp_p4 = 27;
    static constexpr int exp_p5 = 28;
    static constexpr int sqrt_1_2 = 29;
    static constexpr int log_p0 = 30;
    static constexpr int log_p1 = 31;
    static constexpr int log_p2 = 32;
    static constexpr int log_p3 = 33;
    static constexpr int log_p4 = 34;
    static constexpr int log_p5 = 35;
    static constexpr int log_p6 = 36;
    static constexpr int log_p7 = 37;
    static constexpr int log_p8 = 38;
    // The log kernel reuses the exp_c2 / exp_c1 rows instead of having its own.
    static constexpr int log_q1 = exp_c2;
    static constexpr int log_q2 = exp_c1;
    static constexpr int float_invpi = 39;
    static constexpr int float_rintf = 40;
    // pi1..pi4, sinC3..sinC9 and cosC2..cosC8 each occupy four consecutive rows.
    static constexpr int float_pi1 = 41;
    static constexpr int float_pi2 = float_pi1 + 1;
    static constexpr int float_pi3 = float_pi1 + 2;
    static constexpr int float_pi4 = float_pi1 + 3;
    static constexpr int float_sinC3 = 45;
    static constexpr int float_sinC5 = float_sinC3 + 1;
    static constexpr int float_sinC7 = float_sinC3 + 2;
    static constexpr int float_sinC9 = float_sinC3 + 3;
    static constexpr int float_cosC2 = 49;
    static constexpr int float_cosC4 = float_cosC2 + 1;
    static constexpr int float_cosC6 = float_cosC2 + 2;
    static constexpr int float_cosC8 = float_cosC2 + 3;
};
#undef SPLAT
1317
// JitASM compiles everything from main(), so record the operations for later.
// Each entry receives (regptrs, zero, constants, bytecodeRegs) when main() replays it.
std::vector<std::function<void(Reg, YmmReg, Reg, std::unordered_map<int, YmmReg> &)>> deferred;

CPUFeatures cpuFeatures; // copy of *getCPUFeatures(), taken in the constructor
int numInputs; // determines how many pointer-table vectors main() advances per iteration
int curLabel;

// Lambda header shared by the emitters below: captures this and the instruction by
// value and matches the signature stored in `deferred`.
#define EMIT() [this, insn](Reg regptrs, YmmReg zero, Reg constants, std::unordered_map<int, YmmReg> &bytecodeRegs)
1326
load8(const ExprInstruction & insn)1327 void load8(const ExprInstruction &insn) override
1328 {
1329 deferred.push_back(EMIT()
1330 {
1331 auto t1 = bytecodeRegs[insn.dst];
1332 Reg a;
1333 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1334 vpmovzxbd(t1, mmword_ptr[a]);
1335 vcvtdq2ps(t1, t1);
1336 });
1337 }
1338
load16(const ExprInstruction & insn)1339 void load16(const ExprInstruction &insn) override
1340 {
1341 deferred.push_back(EMIT()
1342 {
1343 auto t1 = bytecodeRegs[insn.dst];
1344 Reg a;
1345 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1346 vpmovzxwd(t1, xmmword_ptr[a]);
1347 vcvtdq2ps(t1, t1);
1348 });
1349 }
1350
loadF16(const ExprInstruction & insn)1351 void loadF16(const ExprInstruction &insn) override
1352 {
1353 deferred.push_back(EMIT()
1354 {
1355 auto t1 = bytecodeRegs[insn.dst];
1356 Reg a;
1357 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1358 vcvtph2ps(t1, xmmword_ptr[a]);
1359 });
1360 }
1361
loadF32(const ExprInstruction & insn)1362 void loadF32(const ExprInstruction &insn) override
1363 {
1364 deferred.push_back(EMIT()
1365 {
1366 auto t1 = bytecodeRegs[insn.dst];
1367 Reg a;
1368 mov(a, ptr[regptrs + sizeof(void *) * (insn.op.imm.u + 1)]);
1369 vmovaps(t1, ymmword_ptr[a]);
1370 });
1371 }
1372
loadConst(const ExprInstruction & insn)1373 void loadConst(const ExprInstruction &insn) override
1374 {
1375 deferred.push_back(EMIT()
1376 {
1377 auto t1 = bytecodeRegs[insn.dst];
1378
1379 if (insn.op.imm.f == 0.0f) {
1380 vmovaps(t1, zero);
1381 return;
1382 }
1383
1384 XmmReg r1;
1385 Reg32 a;
1386 mov(a, insn.op.imm.u);
1387 vmovd(r1, a);
1388 vbroadcastss(t1, r1);
1389 });
1390 }
1391
store8(const ExprInstruction & insn)1392 void store8(const ExprInstruction &insn) override
1393 {
1394 deferred.push_back(EMIT()
1395 {
1396 auto t1 = bytecodeRegs[insn.src1];
1397 YmmReg r1;
1398 Reg a;
1399 vminps(r1, t1, ymmword_ptr[constants + ConstantIndex::float_255 * 32]);
1400 vcvtps2dq(r1, r1);
1401 vpackssdw(r1, r1, r1);
1402 vpermq(r1, r1, 0x08);
1403 vpackuswb(r1, r1, zero);
1404 mov(a, ptr[regptrs]);
1405 vmovq(qword_ptr[a], r1.as128());
1406 });
1407 }
1408
store16(const ExprInstruction & insn)1409 void store16(const ExprInstruction &insn) override
1410 {
1411 deferred.push_back(EMIT()
1412 {
1413 int depth = insn.op.imm.u;
1414 auto t1 = bytecodeRegs[insn.src1];
1415 YmmReg r1, limit;
1416 Reg a;
1417 vminps(r1, t1, ymmword_ptr[constants + (ConstantIndex::float_255 + depth - 8) * 32]);
1418 vcvtps2dq(r1, r1);
1419 vpackusdw(r1, r1, r1);
1420 vpermq(r1, r1, 0x08);
1421 mov(a, ptr[regptrs]);
1422 vmovaps(xmmword_ptr[a], r1.as128());
1423 });
1424 }
1425
storeF16(const ExprInstruction & insn)1426 void storeF16(const ExprInstruction &insn) override
1427 {
1428 deferred.push_back(EMIT()
1429 {
1430 auto t1 = bytecodeRegs[insn.src1];
1431 Reg a;
1432 mov(a, ptr[regptrs]);
1433 vcvtps2ph(xmmword_ptr[a], t1, 0);
1434 });
1435 }
1436
storeF32(const ExprInstruction & insn)1437 void storeF32(const ExprInstruction &insn) override
1438 {
1439 deferred.push_back(EMIT()
1440 {
1441 auto t1 = bytecodeRegs[insn.src1];
1442 Reg a;
1443 mov(a, ptr[regptrs]);
1444 vmovaps(ymmword_ptr[a], t1);
1445 });
1446 }
1447
1448 #define BINARYOP(op) \
1449 do { \
1450 auto t1 = bytecodeRegs[insn.src1]; \
1451 auto t2 = bytecodeRegs[insn.src2]; \
1452 auto t3 = bytecodeRegs[insn.dst]; \
1453 op(t3, t1, t2); \
1454 } while (0)
add(const ExprInstruction & insn)1455 void add(const ExprInstruction &insn) override
1456 {
1457 deferred.push_back(EMIT()
1458 {
1459 BINARYOP(vaddps);
1460 });
1461 }
1462
sub(const ExprInstruction & insn)1463 void sub(const ExprInstruction &insn) override
1464 {
1465 deferred.push_back(EMIT()
1466 {
1467 BINARYOP(vsubps);
1468 });
1469 }
1470
mul(const ExprInstruction & insn)1471 void mul(const ExprInstruction &insn) override
1472 {
1473 deferred.push_back(EMIT()
1474 {
1475 BINARYOP(vmulps);
1476 });
1477 }
1478
div(const ExprInstruction & insn)1479 void div(const ExprInstruction &insn) override
1480 {
1481 deferred.push_back(EMIT()
1482 {
1483 BINARYOP(vdivps);
1484 });
1485 }
1486
fma(const ExprInstruction & insn)1487 void fma(const ExprInstruction &insn) override
1488 {
1489 deferred.push_back(EMIT()
1490 {
1491 FMAType type = static_cast<FMAType>(insn.op.imm.u);
1492
1493 // t1 + t2 * t3
1494 auto t1 = bytecodeRegs[insn.src1];
1495 auto t2 = bytecodeRegs[insn.src2];
1496 auto t3 = bytecodeRegs[insn.src3];
1497 auto t4 = bytecodeRegs[insn.dst];
1498
1499 #define FMA3(op) \
1500 do { \
1501 if (insn.dst == insn.src1) { \
1502 op##231ps(t1, t2, t3); \
1503 } else if (insn.dst == insn.src2) { \
1504 op##132ps(t2, t1, t3); \
1505 } else if (insn.dst == insn.src3) { \
1506 op##132ps(t3, t1, t2); \
1507 } else { \
1508 vmovaps(t4, t1); \
1509 op##231ps(t4, t2, t3); \
1510 } \
1511 } while (0)
1512 switch (type) {
1513 case FMAType::FMADD: FMA3(vfmadd); break;
1514 case FMAType::FMSUB: FMA3(vfmsub); break;
1515 case FMAType::FNMADD: FMA3(vfnmadd); break;
1516 case FMAType::FNMSUB: FMA3(vfnmsub); break;
1517 }
1518 #undef FMA3
1519 });
1520 }
1521
max(const ExprInstruction & insn)1522 void max(const ExprInstruction &insn) override
1523 {
1524 deferred.push_back(EMIT()
1525 {
1526 BINARYOP(vmaxps);
1527 });
1528 }
1529
min(const ExprInstruction & insn)1530 void min(const ExprInstruction &insn) override
1531 {
1532 deferred.push_back(EMIT()
1533 {
1534 BINARYOP(vminps);
1535 });
1536 }
1537 #undef BINARYOP
1538
sqrt(const ExprInstruction & insn)1539 void sqrt(const ExprInstruction &insn) override
1540 {
1541 deferred.push_back(EMIT()
1542 {
1543 auto t1 = bytecodeRegs[insn.src1];
1544 auto t2 = bytecodeRegs[insn.dst];
1545 vmaxps(t2, t1, zero);
1546 vsqrtps(t2, t2);
1547 });
1548 }
1549
abs(const ExprInstruction & insn)1550 void abs(const ExprInstruction &insn) override
1551 {
1552 deferred.push_back(EMIT()
1553 {
1554 auto t1 = bytecodeRegs[insn.src1];
1555 auto t2 = bytecodeRegs[insn.dst];
1556 vandps(t2, t1, ymmword_ptr[constants + ConstantIndex::absmask * 32]);
1557 });
1558 }
1559
neg(const ExprInstruction & insn)1560 void neg(const ExprInstruction &insn) override
1561 {
1562 deferred.push_back(EMIT()
1563 {
1564 auto t1 = bytecodeRegs[insn.src1];
1565 auto t2 = bytecodeRegs[insn.dst];
1566 vxorps(t2, t1, ymmword_ptr[constants + ConstantIndex::negmask * 32]);
1567 });
1568 }
1569
not_(const ExprInstruction & insn)1570 void not_(const ExprInstruction &insn) override
1571 {
1572 deferred.push_back(EMIT()
1573 {
1574 auto t1 = bytecodeRegs[insn.src1];
1575 auto t2 = bytecodeRegs[insn.dst];
1576 YmmReg r1;
1577 vcmpps(t2, t1, zero, _CMP_LE_OS);
1578 vandps(t2, t2, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1579 });
1580 }
1581
1582 #define LOGICOP(op) \
1583 do { \
1584 auto t1 = bytecodeRegs[insn.src1]; \
1585 auto t2 = bytecodeRegs[insn.src2]; \
1586 auto t3 = bytecodeRegs[insn.dst]; \
1587 YmmReg tmp; \
1588 vcmpps(tmp, t1, zero, _CMP_NLE_US); \
1589 vcmpps(t3, t2, zero, _CMP_NLE_US); \
1590 op(t3, t3, tmp); \
1591 vandps(t3, t3, ymmword_ptr[constants + ConstantIndex::float_one * 32]); \
1592 } while (0)
1593
and_(const ExprInstruction & insn)1594 void and_(const ExprInstruction &insn) override
1595 {
1596 deferred.push_back(EMIT()
1597 {
1598 LOGICOP(vandps);
1599 });
1600 }
1601
or_(const ExprInstruction & insn)1602 void or_(const ExprInstruction &insn) override
1603 {
1604 deferred.push_back(EMIT()
1605 {
1606 LOGICOP(vorps);
1607 });
1608 }
1609
xor_(const ExprInstruction & insn)1610 void xor_(const ExprInstruction &insn) override
1611 {
1612 deferred.push_back(EMIT()
1613 {
1614 LOGICOP(vxorps);
1615 });
1616 }
1617 #undef LOGICOP
1618
cmp(const ExprInstruction & insn)1619 void cmp(const ExprInstruction &insn) override
1620 {
1621 deferred.push_back(EMIT()
1622 {
1623 auto t1 = bytecodeRegs[insn.src1];
1624 auto t2 = bytecodeRegs[insn.src2];
1625 auto t3 = bytecodeRegs[insn.dst];
1626 vcmpps(t3, t1, t2, insn.op.imm.u);
1627 vandps(t3, t3, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1628 });
1629 }
1630
ternary(const ExprInstruction & insn)1631 void ternary(const ExprInstruction &insn) override
1632 {
1633 deferred.push_back(EMIT()
1634 {
1635 auto t1 = bytecodeRegs[insn.src1];
1636 auto t2 = bytecodeRegs[insn.src2];
1637 auto t3 = bytecodeRegs[insn.src3];
1638 auto t4 = bytecodeRegs[insn.dst];
1639 YmmReg r1;
1640 vcmpps(r1, t1, zero, _CMP_NLE_US);
1641 vblendvps(t4, t3, t2, r1);
1642 });
1643 }
1644
// Emits the in-place exponential kernel: on exit x holds exp(x).
// `one` must hold 1.0f in every lane; `constants` is the base address of constData.
void exp_(YmmReg x, YmmReg one, Reg constants)
{
    YmmReg fx, emm0, etmp, y, mask, z;
    // Clamp the argument to [exp_lo, exp_hi].
    vminps(x, x, ymmword_ptr[constants + ConstantIndex::exp_hi * 32]);
    vmaxps(x, x, ymmword_ptr[constants + ConstantIndex::exp_lo * 32]);
    // fx = x * log2e + 0.5
    vmovaps(fx, ymmword_ptr[constants + ConstantIndex::log2e * 32]);
    vfmadd213ps(fx, x, ymmword_ptr[constants + ConstantIndex::float_half * 32]);
    // Floor of fx: truncate, then subtract 1 in lanes where truncation went above fx.
    vcvttps2dq(emm0, fx);
    vcvtdq2ps(etmp, emm0);
    vcmpps(mask, etmp, fx, _CMP_NLE_US);
    vandps(mask, mask, one);
    vsubps(fx, etmp, mask);
    // x -= fx * exp_c1; x -= fx * exp_c2
    vfnmadd231ps(x, fx, ymmword_ptr[constants + ConstantIndex::exp_c1 * 32]);
    vfnmadd231ps(x, fx, ymmword_ptr[constants + ConstantIndex::exp_c2 * 32]);
    vmulps(z, x, x);
    // Horner evaluation of the polynomial with coefficients exp_p0 .. exp_p5 in x.
    vmovaps(y, ymmword_ptr[constants + ConstantIndex::exp_p0 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p1 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p2 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p3 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p4 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::exp_p5 * 32]);
    // y = y * x^2 + x + 1
    vfmadd213ps(y, z, x);
    vaddps(y, y, one);
    // Build the power-of-two scale: (int(fx) + 0x7F) shifted into the exponent field.
    vcvttps2dq(emm0, fx);
    vpaddd(emm0, emm0, ymmword_ptr[constants + ConstantIndex::x7F * 32]);
    vpslld(emm0, emm0, 23);
    vmulps(x, y, emm0);
}
1673
// Emits the in-place natural-logarithm kernel: on exit x holds log(x).
// Lanes where "0 < x" does not hold end up with all bits set (the final OR with
// invalid_mask). `zero` and `one` must hold 0.0f / 1.0f in every lane; `constants`
// is the base address of constData.
void log_(YmmReg x, YmmReg zero, YmmReg one, Reg constants)
{
    YmmReg emm0, invalid_mask, mask, y, etmp, z;
    // invalid_mask = !(0 < x)
    vcmpps(invalid_mask, zero, x, _CMP_NLT_US);
    vmaxps(x, x, ymmword_ptr[constants + ConstantIndex::min_norm_pos * 32]);
    // Split x: emm0 receives the biased exponent, x keeps the mantissa with the
    // bit pattern of 0.5 ORed in.
    vpsrld(emm0, x, 23);
    vandps(x, x, ymmword_ptr[constants + ConstantIndex::inv_mant_mask * 32]);
    vorps(x, x, ymmword_ptr[constants + ConstantIndex::float_half * 32]);
    // emm0 = float(exponent - 0x7F) + 1
    vpsubd(emm0, emm0, ymmword_ptr[constants + ConstantIndex::x7F * 32]);
    vcvtdq2ps(emm0, emm0);
    vaddps(emm0, emm0, one);
    // Where x < sqrt(1/2): x = 2x - 1 and the exponent drops by one; elsewhere x = x - 1.
    vcmpps(mask, x, ymmword_ptr[constants + ConstantIndex::sqrt_1_2 * 32], _CMP_LT_OS);
    vandps(etmp, x, mask);
    vsubps(x, x, one);
    vandps(mask, mask, one);
    vsubps(emm0, emm0, mask);
    vaddps(x, x, etmp);
    vmulps(z, x, x);
    // Horner evaluation of the polynomial with coefficients log_p0 .. log_p8 in x.
    vmovaps(y, ymmword_ptr[constants + ConstantIndex::log_p0 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p1 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p2 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p3 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p4 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p5 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p6 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p7 * 32]);
    vfmadd213ps(y, x, ymmword_ptr[constants + ConstantIndex::log_p8 * 32]);
    // y = y * x^3
    vmulps(y, y, x);
    vmulps(y, y, z);
    // y += emm0 * log_q1; y -= 0.5 * x^2
    vfmadd231ps(y, emm0, ymmword_ptr[constants + ConstantIndex::log_q1 * 32]);
    vfnmadd231ps(y, z, ymmword_ptr[constants + ConstantIndex::float_half * 32]);
    // x = x + y + emm0 * log_q2
    vaddps(x, x, y);
    vfmadd231ps(x, emm0, ymmword_ptr[constants + ConstantIndex::log_q2 * 32]);
    vorps(x, x, invalid_mask);
}
1709
exp(const ExprInstruction & insn)1710 void exp(const ExprInstruction &insn) override
1711 {
1712 deferred.push_back(EMIT()
1713 {
1714 auto t1 = bytecodeRegs[insn.src1];
1715 auto t2 = bytecodeRegs[insn.dst];
1716 YmmReg one;
1717 vmovaps(one, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1718 vmovaps(t1, t2);
1719 exp_(t1, one, constants);
1720 });
1721 }
1722
log(const ExprInstruction & insn)1723 void log(const ExprInstruction &insn) override
1724 {
1725 deferred.push_back(EMIT()
1726 {
1727 auto t1 = bytecodeRegs[insn.src1];
1728 auto t2 = bytecodeRegs[insn.dst];
1729 YmmReg one;
1730 vmovaps(one, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1731 vmovaps(t1, t2);
1732 log_(t1, zero, one, constants);
1733 });
1734 }
1735
pow(const ExprInstruction & insn)1736 void pow(const ExprInstruction &insn) override
1737 {
1738 deferred.push_back(EMIT()
1739 {
1740 auto t1 = bytecodeRegs[insn.src1];
1741 auto t2 = bytecodeRegs[insn.src2];
1742 auto t3 = bytecodeRegs[insn.dst];
1743
1744 YmmReg r1, one;
1745 vmovaps(one, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
1746 vmovaps(r1, t1);
1747 log_(r1, zero, one, constants);
1748 vmulps(r1, r1, t2);
1749 exp_(r1, one, constants);
1750 vmovaps(t3, r1);
1751 });
1752 }
1753
// Emits 256-bit code that writes sin(src1) (issin == true) or cos(src1)
// (issin == false) to dst. The source register is only read before the destination
// is written (dst is assigned once, by the final vxorps).
void sincos_(bool issin, const ExprInstruction &insn, Reg constants, std::unordered_map<int, YmmReg> &bytecodeRegs)
{
    auto x = bytecodeRegs[insn.src1];
    auto y = bytecodeRegs[insn.dst];
    YmmReg t1, sign, t2, t3, t4;
    // Remove sign
    vmovaps(t1, ymmword_ptr[constants + ConstantIndex::absmask * 32]);
    if (issin) {
        // sign = ~absmask & x, i.e. only the sign bit of x.
        vmovaps(sign, t1);
        vandnps(sign, sign, x);
    } else {
        // The cosine path starts with an all-zero sign mask.
        vxorps(sign, sign, sign);
    }
    // t1 = |x|
    vandps(t1, t1, x);
    // Range reduction
    // t2 = |x| * (1/pi) + float_rintf. Presumably float_rintf is the usual
    // round-to-integer bias, leaving the rounded quotient in the low mantissa
    // bits -- TODO confirm.
    vmovaps(t3, ymmword_ptr[constants + ConstantIndex::float_rintf * 32]);
    vmulps(t2, t1, ymmword_ptr[constants + ConstantIndex::float_invpi * 32]);
    vaddps(t2, t2, t3);
    // Shift the lowest bit of each lane into the sign-bit position and fold it into sign.
    vpslld(t4, t2, 31);
    vxorps(sign, sign, t4);
    // Remove the bias again, then subtract t2 * (pi1 + pi2 + pi3 + pi4) from |x| in four steps.
    vsubps(t2, t2, t3);
    vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi1 * 32]);
    vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi2 * 32]);
    vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi3 * 32]);
    vfnmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_pi4 * 32]);
    if (issin) {
        // Evaluate minimax polynomial for sin(x) in [-pi/2, pi/2] interval
        // Y <- X + X * X^2 * (C3 + X^2 * (C5 + X^2 * (C7 + X^2 * C9)))
        vmulps(t2, t1, t1); // t2 = X^2
        vmovaps(t3, ymmword_ptr[constants + ConstantIndex::float_sinC7 * 32]);
        vfmadd231ps(t3, t2, ymmword_ptr[constants + ConstantIndex::float_sinC9 * 32]);
        vfmadd213ps(t3, t2, ymmword_ptr[constants + ConstantIndex::float_sinC5 * 32]);
        vfmadd213ps(t3, t2, ymmword_ptr[constants + ConstantIndex::float_sinC3 * 32]);
        vmulps(t3, t3, t2);
        vfmadd231ps(t1, t1, t3); // t1 = X * t3 + X
    } else {
        // Evaluate minimax polynomial for cos(x) in [-pi/2, pi/2] interval
        // Y <- 1 + X^2 * (C2 + X^2 * (C4 + X^2 * (C6 + X^2 * C8)))
        vmulps(t2, t1, t1); // t2 = X^2; t1 is reused as the accumulator below
        vmovaps(t1, ymmword_ptr[constants + ConstantIndex::float_cosC6 * 32]);
        vfmadd231ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_cosC8 * 32]);
        vfmadd213ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_cosC4 * 32]);
        vfmadd213ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_cosC2 * 32]);
        vfmadd213ps(t1, t2, ymmword_ptr[constants + ConstantIndex::float_one * 32]);
    }
    // Apply sign
    vxorps(y, t1, sign);
}
1802
sin(const ExprInstruction & insn)1803 void sin(const ExprInstruction &insn) override
1804 {
1805 deferred.push_back(EMIT()
1806 {
1807 sincos_(true, insn, constants, bytecodeRegs);
1808 });
1809 }
1810
cos(const ExprInstruction & insn)1811 void cos(const ExprInstruction &insn) override
1812 {
1813 deferred.push_back(EMIT()
1814 {
1815 sincos_(false, insn, constants, bytecodeRegs);
1816 });
1817 }
1818
main(Reg regptrs,Reg regoffs,Reg niter)1819 void main(Reg regptrs, Reg regoffs, Reg niter)
1820 {
1821 std::unordered_map<int, YmmReg> bytecodeRegs;
1822 YmmReg zero;
1823 vpxor(zero, zero, zero);
1824 Reg constants;
1825 mov(constants, (uintptr_t)constData);
1826
1827 L("wloop");
1828
1829 for (const auto &f : deferred) {
1830 f(regptrs, zero, constants, bytecodeRegs);
1831 }
1832
1833 #if UINTPTR_MAX > UINT32_MAX
1834 for (int i = 0; i < numInputs / 4 + 1; i++) {
1835 YmmReg r1, r2;
1836 vmovdqu(r1, ymmword_ptr[regptrs + 32 * i]);
1837 vmovdqu(r2, ymmword_ptr[regoffs + 32 * i]);
1838 vpaddq(r1, r1, r2);
1839 vmovdqu(ymmword_ptr[regptrs + 32 * i], r1);
1840 }
1841 #else
1842 for (int i = 0; i < numInputs / 8 + 1; i++) {
1843 YmmReg r1, r2;
1844 vmovdqu(r1, ymmword_ptr[regptrs + 32 * i]);
1845 vmovdqu(r2, ymmword_ptr[regoffs + 32 * i]);
1846 vpaddd(r1, r1, r2);
1847 vmovdqu(ymmword_ptr[regptrs + 32 * i], r1);
1848 }
1849 #endif
1850
1851 jit::sub(niter, 1);
1852 jnz("wloop");
1853 }
1854
1855 public:
ExprCompiler256(int numInputs)1856 explicit ExprCompiler256(int numInputs) : cpuFeatures(*getCPUFeatures()), numInputs(numInputs) {}
1857
getCode()1858 std::pair<ExprData::ProcessLineProc, size_t> getCode() override
1859 {
1860 size_t size;
1861 if (jit::GetCode(true) && (size = GetCodeSize())) {
1862 #ifdef VS_TARGET_OS_WINDOWS
1863 void *ptr = VirtualAlloc(nullptr, size, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
1864 #else
1865 void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE | PROT_EXEC, MAP_ANON | MAP_PRIVATE, 0, 0);
1866 #endif
1867 memcpy(ptr, jit::GetCode(true), size);
1868 return {reinterpret_cast<ExprData::ProcessLineProc>(ptr), size};
1869 }
1870 return {nullptr, 0};
1871 }
1872 #undef EMIT
1873 };
1874
// Namespace-scope definition of the ExprCompiler256::constData static member.
// main() takes the table's address, so (before C++17 inline variables) the in-class
// declaration alone is not sufficient.
constexpr ExprUnion ExprCompiler256::constData alignas(32)[53][8];
1876
make_compiler(int numInputs,int cpulevel)1877 std::unique_ptr<ExprCompiler> make_compiler(int numInputs, int cpulevel)
1878 {
1879 if (getCPUFeatures()->avx2 && cpulevel >= VS_CPU_LEVEL_AVX2)
1880 return std::unique_ptr<ExprCompiler>(new ExprCompiler256(numInputs));
1881 else
1882 return std::unique_ptr<ExprCompiler>(new ExprCompiler128(numInputs));
1883 }
1884 #endif
1885
1886 class ExprInterpreter {
1887 const ExprInstruction *bytecode;
1888 size_t numInsns;
1889 std::vector<float> registers;
1890
1891 template <class T>
clamp_int(float x,int depth=std::numeric_limits<T>::digits)1892 static T clamp_int(float x, int depth = std::numeric_limits<T>::digits)
1893 {
1894 float maxval = static_cast<float>((1U << depth) - 1);
1895 return static_cast<T>(std::lrint(std::min(std::max(x, static_cast<float>(std::numeric_limits<T>::min())), maxval)));
1896 }
1897
bool2float(bool x)1898 static float bool2float(bool x) { return x ? 1.0f : 0.0f; }
float2bool(float x)1899 static bool float2bool(float x) { return x > 0.0f; }
1900 public:
ExprInterpreter(const ExprInstruction * bytecode,size_t numInsns)1901 ExprInterpreter(const ExprInstruction *bytecode, size_t numInsns) : bytecode(bytecode), numInsns(numInsns)
1902 {
1903 int maxreg = 0;
1904 for (size_t i = 0; i < numInsns; ++i) {
1905 maxreg = std::max(maxreg, bytecode[i].dst);
1906 }
1907 registers.resize(maxreg + 1);
1908 }
1909
eval(const uint8_t * const * srcp,uint8_t * dstp,int x)1910 void eval(const uint8_t * const *srcp, uint8_t *dstp, int x)
1911 {
1912 for (size_t i = 0; i < numInsns; ++i) {
1913 const ExprInstruction &insn = bytecode[i];
1914
1915 #define SRC1 registers[insn.src1]
1916 #define SRC2 registers[insn.src2]
1917 #define SRC3 registers[insn.src3]
1918 #define DST registers[insn.dst]
1919 switch (insn.op.type) {
1920 case ExprOpType::MEM_LOAD_U8: DST = reinterpret_cast<const uint8_t *>(srcp[insn.op.imm.u])[x]; break;
1921 case ExprOpType::MEM_LOAD_U16: DST = reinterpret_cast<const uint16_t *>(srcp[insn.op.imm.u])[x]; break;
1922 case ExprOpType::MEM_LOAD_F16: DST = 0; break;
1923 case ExprOpType::MEM_LOAD_F32: DST = reinterpret_cast<const float *>(srcp[insn.op.imm.u])[x]; break;
1924 case ExprOpType::CONSTANT: DST = insn.op.imm.f; break;
1925 case ExprOpType::ADD: DST = SRC1 + SRC2; break;
1926 case ExprOpType::SUB: DST = SRC1 - SRC2; break;
1927 case ExprOpType::MUL: DST = SRC1 * SRC2; break;
1928 case ExprOpType::DIV: DST = SRC1 / SRC2; break;
1929 case ExprOpType::FMA:
1930 switch (static_cast<FMAType>(insn.op.imm.u)) {
1931 case FMAType::FMADD: DST = SRC2 * SRC3 + SRC1; break;
1932 case FMAType::FMSUB: DST = SRC2 * SRC3 - SRC1; break;
1933 case FMAType::FNMADD: DST = -(SRC2 * SRC3) + SRC1; break;
1934 case FMAType::FNMSUB: DST = -(SRC2 * SRC3) - SRC1; break;
1935 };
1936 break;
1937 case ExprOpType::MAX: DST = std::max(SRC1, SRC2); break;
1938 case ExprOpType::MIN: DST = std::min(SRC1, SRC2); break;
1939 case ExprOpType::EXP: DST = std::exp(SRC1); break;
1940 case ExprOpType::LOG: DST = std::log(SRC1); break;
1941 case ExprOpType::POW: DST = std::pow(SRC1, SRC2); break;
1942 case ExprOpType::SQRT: DST = std::sqrt(SRC1); break;
1943 case ExprOpType::SIN: DST = std::sin(SRC1); break;
1944 case ExprOpType::COS: DST = std::cos(SRC1); break;
1945 case ExprOpType::ABS: DST = std::fabs(SRC1); break;
1946 case ExprOpType::NEG: DST = -SRC1; break;
1947 case ExprOpType::CMP:
1948 switch (static_cast<ComparisonType>(insn.op.imm.u)) {
1949 case ComparisonType::EQ: DST = bool2float(SRC1 == SRC2); break;
1950 case ComparisonType::LT: DST = bool2float(SRC1 < SRC2); break;
1951 case ComparisonType::LE: DST = bool2float(SRC1 <= SRC2); break;
1952 case ComparisonType::NEQ: DST = bool2float(SRC1 != SRC2); break;
1953 case ComparisonType::NLT: DST = bool2float(SRC1 >= SRC2); break;
1954 case ComparisonType::NLE: DST = bool2float(SRC1 > SRC2); break;
1955 }
1956 break;
1957 case ExprOpType::TERNARY: DST = float2bool(SRC1) ? SRC2 : SRC3; break;
1958 case ExprOpType::AND: DST = bool2float((float2bool(SRC1) && float2bool(SRC2))); break;
1959 case ExprOpType::OR: DST = bool2float((float2bool(SRC1) || float2bool(SRC2))); break;
1960 case ExprOpType::XOR: DST = bool2float((float2bool(SRC1) != float2bool(SRC2))); break;
1961 case ExprOpType::NOT: DST = bool2float(!float2bool(SRC1)); break;
1962 case ExprOpType::MEM_STORE_U8: reinterpret_cast<uint8_t *>(dstp)[x] = clamp_int<uint8_t>(SRC1); return;
1963 case ExprOpType::MEM_STORE_U16: reinterpret_cast<uint16_t *>(dstp)[x] = clamp_int<uint16_t>(SRC1, insn.op.imm.u); return;
1964 case ExprOpType::MEM_STORE_F16: reinterpret_cast<uint16_t *>(dstp)[x] = 0; return;
1965 case ExprOpType::MEM_STORE_F32: reinterpret_cast<float *>(dstp)[x] = SRC1; return;
1966 default: vsFatal("illegal opcode"); return;
1967 }
1968 #undef DST
1969 #undef SRC3
1970 #undef SRC2
1971 #undef SRC1
1972 }
1973 }
1974 };
1975
1976 struct ExpressionTreeNode {
1977 ExpressionTreeNode *parent;
1978 ExpressionTreeNode *left;
1979 ExpressionTreeNode *right;
1980 ExprOp op;
1981 int valueNum;
1982
ExpressionTreeNode__anon2a0524d50111::ExpressionTreeNode1983 explicit ExpressionTreeNode(ExprOp op) : parent(), left(), right(), op(op), valueNum(-1) {}
1984
setLeft__anon2a0524d50111::ExpressionTreeNode1985 void setLeft(ExpressionTreeNode *node)
1986 {
1987 if (left)
1988 left->parent = nullptr;
1989
1990 left = node;
1991
1992 if (left)
1993 left->parent = this;
1994 }
1995
setRight__anon2a0524d50111::ExpressionTreeNode1996 void setRight(ExpressionTreeNode *node)
1997 {
1998 if (right)
1999 right->parent = nullptr;
2000
2001 right = node;
2002
2003 if (right)
2004 right->parent = this;
2005 }
2006
2007 template <class T>
preorder__anon2a0524d50111::ExpressionTreeNode2008 void preorder(T visitor)
2009 {
2010 if (visitor(*this))
2011 return;
2012
2013 if (left)
2014 left->preorder(visitor);
2015 if (right)
2016 right->preorder(visitor);
2017 }
2018
2019 template <class T>
postorder__anon2a0524d50111::ExpressionTreeNode2020 void postorder(T visitor)
2021 {
2022 if (left)
2023 left->postorder(visitor);
2024 if (right)
2025 right->postorder(visitor);
2026 visitor(*this);
2027 }
2028 };
2029
2030 class ExpressionTree {
2031 std::vector<std::unique_ptr<ExpressionTreeNode>> nodes;
2032 ExpressionTreeNode *root;
2033 public:
ExpressionTree()2034 ExpressionTree() : root() {}
2035
getRoot()2036 ExpressionTreeNode *getRoot() { return root; }
getRoot() const2037 const ExpressionTreeNode *getRoot() const { return root; }
2038
setRoot(ExpressionTreeNode * node)2039 void setRoot(ExpressionTreeNode *node) { root = node; }
2040
makeNode(ExprOp data)2041 ExpressionTreeNode *makeNode(ExprOp data)
2042 {
2043 nodes.push_back(std::unique_ptr<ExpressionTreeNode>(new ExpressionTreeNode(data)));
2044 return nodes.back().get();
2045 }
2046
clone(const ExpressionTreeNode * node)2047 ExpressionTreeNode *clone(const ExpressionTreeNode *node)
2048 {
2049 if (!node)
2050 return nullptr;
2051
2052 ExpressionTreeNode *newnode = makeNode(node->op);
2053 newnode->setLeft(clone(node->left));
2054 newnode->setRight(clone(node->right));
2055 return newnode;
2056 }
2057 };
2058
makeTreeNode(ExprOp data)2059 std::unique_ptr<ExpressionTreeNode> makeTreeNode(ExprOp data)
2060 {
2061 return std::unique_ptr<ExpressionTreeNode>(new ExpressionTreeNode(data));
2062 }
2063
equalSubTree(const ExpressionTreeNode * lhs,const ExpressionTreeNode * rhs)2064 bool equalSubTree(const ExpressionTreeNode *lhs, const ExpressionTreeNode *rhs)
2065 {
2066 if (lhs->valueNum >= 0 && rhs->valueNum >= 0)
2067 return lhs->valueNum == rhs->valueNum;
2068 if (lhs->op.type != rhs->op.type || lhs->op.imm.u != rhs->op.imm.u)
2069 return false;
2070 if (!!lhs->left != !!rhs->left || !!lhs->right != !!rhs->right)
2071 return false;
2072 if (lhs->left && !equalSubTree(lhs->left, rhs->left))
2073 return false;
2074 if (lhs->right && !equalSubTree(lhs->right, rhs->right))
2075 return false;
2076 return true;
2077 }
2078
/**
 * Splits an expression string into whitespace-separated tokens.
 *
 * Runs of whitespace never produce empty tokens; an empty or all-whitespace
 * input yields an empty vector.
 */
std::vector<std::string> tokenize(const std::string &expr)
{
    std::vector<std::string> tokens;
    auto it = expr.begin();
    auto prev = expr.begin();

    while (it != expr.end()) {
        // std::isspace is only defined for values representable as unsigned char (or EOF),
        // so bytes >= 0x80 must not be passed as a possibly negative plain char.
        if (std::isspace(static_cast<unsigned char>(*it))) {
            if (it != prev)
                tokens.emplace_back(prev, it);
            prev = it + 1;
        }
        ++it;
    }

    // Trailing token that is not followed by whitespace.
    if (prev != expr.end())
        tokens.emplace_back(prev, expr.end());

    return tokens;
}
2100
decodeToken(const std::string & token)2101 ExprOp decodeToken(const std::string &token)
2102 {
2103 static const std::unordered_map<std::string, ExprOp> simple{
2104 { "+", { ExprOpType::ADD } },
2105 { "-", { ExprOpType::SUB } },
2106 { "*", { ExprOpType::MUL } },
2107 { "/", { ExprOpType::DIV } } ,
2108 { "sqrt", { ExprOpType::SQRT } },
2109 { "abs", { ExprOpType::ABS } },
2110 { "max", { ExprOpType::MAX } },
2111 { "min", { ExprOpType::MIN } },
2112 { "<", { ExprOpType::CMP, static_cast<int>(ComparisonType::LT) } },
2113 { ">", { ExprOpType::CMP, static_cast<int>(ComparisonType::NLE) } },
2114 { "=", { ExprOpType::CMP, static_cast<int>(ComparisonType::EQ) } },
2115 { ">=", { ExprOpType::CMP, static_cast<int>(ComparisonType::NLT) } },
2116 { "<=", { ExprOpType::CMP, static_cast<int>(ComparisonType::LE) } },
2117 { "and", { ExprOpType::AND } },
2118 { "or", { ExprOpType::OR } },
2119 { "xor", { ExprOpType::XOR } },
2120 { "not", { ExprOpType::NOT } },
2121 { "?", { ExprOpType::TERNARY } },
2122 { "exp", { ExprOpType::EXP } },
2123 { "log", { ExprOpType::LOG } },
2124 { "pow", { ExprOpType::POW } },
2125 { "sin", { ExprOpType::SIN } },
2126 { "cos", { ExprOpType::COS } },
2127 { "dup", { ExprOpType::DUP, 0 } },
2128 { "swap", { ExprOpType::SWAP, 1 } },
2129 };
2130
2131 auto it = simple.find(token);
2132 if (it != simple.end()) {
2133 return it->second;
2134 } else if (token.size() == 1 && token[0] >= 'a' && token[0] <= 'z') {
2135 return{ ExprOpType::MEM_LOAD_U8, token[0] >= 'x' ? token[0] - 'x' : token[0] - 'a' + 3 };
2136 } else if (token.substr(0, 3) == "dup" || token.substr(0, 4) == "swap") {
2137 size_t prefix = token[0] == 'd' ? 3 : 4;
2138 size_t count = 0;
2139 int idx = -1;
2140
2141 try {
2142 idx = std::stoi(token.substr(prefix), &count);
2143 } catch (...) {
2144 // ...
2145 }
2146
2147 if (idx < 0 || prefix + count != token.size())
2148 throw std::runtime_error("illegal token: " + token);
2149 return{ token[0] == 'd' ? ExprOpType::DUP : ExprOpType::SWAP, idx };
2150 } else {
2151 float f;
2152 std::string s;
2153 std::istringstream numStream(token);
2154 numStream.imbue(std::locale::classic());
2155 if (!(numStream >> f))
2156 throw std::runtime_error("failed to convert '" + token + "' to float");
2157 if (numStream >> s)
2158 throw std::runtime_error("failed to convert '" + token + "' to float, not the whole token could be converted");
2159 return{ ExprOpType::CONSTANT, f };
2160 }
2161 }
2162
parseExpr(const std::string & expr,const VSVideoInfo * const * vi,int numInputs)2163 ExpressionTree parseExpr(const std::string &expr, const VSVideoInfo * const *vi, int numInputs)
2164 {
2165 constexpr unsigned char numOperands[] = {
2166 0, // MEM_LOAD_U8
2167 0, // MEM_LOAD_U16
2168 0, // MEM_LOAD_F16
2169 0, // MEM_LOAD_F32
2170 0, // CONSTANT
2171 0, // MEM_STORE_U8
2172 0, // MEM_STORE_U16
2173 0, // MEM_STORE_F16
2174 0, // MEM_STORE_F32
2175 2, // ADD
2176 2, // SUB
2177 2, // MUL
2178 2, // DIV
2179 3, // FMA
2180 1, // SQRT
2181 1, // ABS
2182 1, // NEG
2183 2, // MAX
2184 2, // MIN
2185 2, // CMP
2186 2, // AND
2187 2, // OR
2188 2, // XOR
2189 1, // NOT
2190 1, // EXP
2191 1, // LOG
2192 2, // POW
2193 1, // SIN
2194 1, // COS
2195 3, // TERNARY
2196 0, // MUX
2197 0, // DUP
2198 0, // SWAP
2199 };
2200 static_assert(sizeof(numOperands) == static_cast<unsigned>(ExprOpType::SWAP) + 1, "invalid table");
2201
2202 auto tokens = tokenize(expr);
2203
2204 ExpressionTree tree;
2205 std::vector<ExpressionTreeNode *> stack;
2206
2207 for (const std::string &tok : tokens) {
2208 ExprOp op = decodeToken(tok);
2209
2210 // Check validity.
2211 if (op.type == ExprOpType::MEM_LOAD_U8 && op.imm.i >= numInputs)
2212 throw std::runtime_error("reference to undefined clip: " + tok);
2213 if ((op.type == ExprOpType::DUP || op.type == ExprOpType::SWAP) && op.imm.u >= stack.size())
2214 throw std::runtime_error("insufficient values on stack: " + tok);
2215 if (stack.size() < numOperands[static_cast<size_t>(op.type)])
2216 throw std::runtime_error("insufficient values on stack: " + tok);
2217
2218 // Rename load operations with the correct data type.
2219 if (op.type == ExprOpType::MEM_LOAD_U8) {
2220 const VSFormat *format = vi[op.imm.i]->format;
2221
2222 if (format->sampleType == stInteger && format->bytesPerSample == 1)
2223 op.type = ExprOpType::MEM_LOAD_U8;
2224 else if (format->sampleType == stInteger && format->bytesPerSample == 2)
2225 op.type = ExprOpType::MEM_LOAD_U16;
2226 else if (format->sampleType == stFloat && format->bytesPerSample == 2)
2227 op.type = ExprOpType::MEM_LOAD_F16;
2228 else if (format->sampleType == stFloat && format->bytesPerSample == 4)
2229 op.type = ExprOpType::MEM_LOAD_F32;
2230 }
2231
2232 // Apply DUP and SWAP in the frontend.
2233 if (op.type == ExprOpType::DUP) {
2234 stack.push_back(tree.clone(stack[stack.size() - 1 - op.imm.u]));
2235 } else if (op.type == ExprOpType::SWAP) {
2236 std::swap(stack.back(), stack[stack.size() - 1 - op.imm.u]);
2237 } else {
2238 size_t operands = numOperands[static_cast<size_t>(op.type)];
2239
2240 if (operands == 0) {
2241 stack.push_back(tree.makeNode(op));
2242 } else if (operands == 1) {
2243 ExpressionTreeNode *child = stack.back();
2244 stack.pop_back();
2245
2246 ExpressionTreeNode *node = tree.makeNode(op);
2247 node->setLeft(child);
2248 stack.push_back(node);
2249 } else if (operands == 2) {
2250 ExpressionTreeNode *left = stack[stack.size() - 2];
2251 ExpressionTreeNode *right = stack[stack.size() - 1];
2252 stack.resize(stack.size() - 2);
2253
2254 ExpressionTreeNode *node = tree.makeNode(op);
2255 node->setLeft(left);
2256 node->setRight(right);
2257 stack.push_back(node);
2258 } else if (operands == 3) {
2259 ExpressionTreeNode *arg1 = stack[stack.size() - 3];
2260 ExpressionTreeNode *arg2 = stack[stack.size() - 2];
2261 ExpressionTreeNode *arg3 = stack[stack.size() - 1];
2262 stack.resize(stack.size() - 3);
2263
2264 ExpressionTreeNode *mux = tree.makeNode(ExprOpType::MUX);
2265 mux->setLeft(arg2);
2266 mux->setRight(arg3);
2267
2268 ExpressionTreeNode *node = tree.makeNode(op);
2269 node->setLeft(arg1);
2270 node->setRight(mux);
2271 stack.push_back(node);
2272 }
2273 }
2274 }
2275
2276 if (stack.empty())
2277 throw std::runtime_error("empty expression: " + expr);
2278 if (stack.size() > 1)
2279 throw std::runtime_error("unconsumed values on stack: " + expr);
2280
2281 tree.setRoot(stack.back());
2282 return tree;
2283 }
2284
isConstantExpr(const ExpressionTreeNode & node)2285 bool isConstantExpr(const ExpressionTreeNode &node)
2286 {
2287 switch (node.op.type) {
2288 case ExprOpType::MEM_LOAD_U8:
2289 case ExprOpType::MEM_LOAD_U16:
2290 case ExprOpType::MEM_LOAD_F16:
2291 case ExprOpType::MEM_LOAD_F32:
2292 return false;
2293 case ExprOpType::CONSTANT:
2294 return true;
2295 default:
2296 return (!node.left || isConstantExpr(*node.left)) && (!node.right || isConstantExpr(*node.right));
2297 }
2298 }
2299
isConstant(const ExpressionTreeNode & node)2300 bool isConstant(const ExpressionTreeNode &node)
2301 {
2302 return node.op.type == ExprOpType::CONSTANT;
2303 }
2304
isConstant(const ExpressionTreeNode & node,float val)2305 bool isConstant(const ExpressionTreeNode &node, float val)
2306 {
2307 return node.op.type == ExprOpType::CONSTANT && node.op.imm.f == val;
2308 }
2309
evalConstantExpr(const ExpressionTreeNode & node)2310 float evalConstantExpr(const ExpressionTreeNode &node)
2311 {
2312 auto bool2float = [](bool x) { return x ? 1.0f : 0.0f; };
2313 auto float2bool = [](float x) { return x > 0.0f; };
2314
2315 #define LEFT evalConstantExpr(*node.left)
2316 #define RIGHT evalConstantExpr(*node.right)
2317 #define RIGHTLEFT evalConstantExpr(*node.right->left)
2318 #define RIGHTRIGHT evalConstantExpr(*node.right->right)
2319 switch (node.op.type) {
2320 case ExprOpType::CONSTANT: return node.op.imm.f;
2321 case ExprOpType::ADD: return LEFT + RIGHT;
2322 case ExprOpType::SUB: return LEFT - RIGHT;
2323 case ExprOpType::MUL: return LEFT * RIGHT;
2324 case ExprOpType::DIV: return LEFT / RIGHT;
2325 case ExprOpType::FMA:
2326 switch (static_cast<FMAType>(node.op.imm.u)) {
2327 case FMAType::FMADD: return RIGHTLEFT * RIGHTRIGHT + LEFT;
2328 case FMAType::FMSUB: return RIGHTLEFT * RIGHTRIGHT - LEFT;
2329 case FMAType::FNMADD: return -(RIGHTLEFT * RIGHTRIGHT) + LEFT;
2330 case FMAType::FNMSUB: return -(RIGHTLEFT * RIGHTRIGHT) - LEFT;
2331 }
2332 return NAN;
2333 case ExprOpType::SQRT: return std::sqrt(LEFT);
2334 case ExprOpType::ABS: return std::fabs(LEFT);
2335 case ExprOpType::NEG: return -LEFT;
2336 case ExprOpType::MAX: return std::max(LEFT, RIGHT);
2337 case ExprOpType::MIN: return std::min(LEFT, RIGHT);
2338 case ExprOpType::CMP:
2339 switch (static_cast<ComparisonType>(node.op.imm.u)) {
2340 case ComparisonType::EQ: return bool2float(LEFT == RIGHT);
2341 case ComparisonType::LT: return bool2float(LEFT < RIGHT);
2342 case ComparisonType::LE: return bool2float(LEFT <= RIGHT);
2343 case ComparisonType::NEQ: return bool2float(LEFT != RIGHT);
2344 case ComparisonType::NLT: return bool2float(LEFT >= RIGHT);
2345 case ComparisonType::NLE: return bool2float(LEFT > RIGHT);
2346 }
2347 return NAN;
2348 case ExprOpType::AND: return bool2float(float2bool(LEFT) && float2bool(RIGHT));
2349 case ExprOpType::OR: return bool2float(float2bool(LEFT) || float2bool(RIGHT));
2350 case ExprOpType::XOR: return bool2float(float2bool(LEFT) != float2bool(RIGHT));
2351 case ExprOpType::NOT: return bool2float(!float2bool(LEFT));
2352 case ExprOpType::EXP: return std::exp(LEFT);
2353 case ExprOpType::LOG: return std::log(LEFT);
2354 case ExprOpType::POW: return std::pow(LEFT, RIGHT);
2355 case ExprOpType::SIN: return std::sin(LEFT);
2356 case ExprOpType::COS: return std::cos(LEFT);
2357 case ExprOpType::TERNARY: return float2bool(LEFT) ? RIGHTLEFT : RIGHTRIGHT;
2358 default: return NAN;
2359 }
2360 #undef RIGHTRIGHT
2361 #undef RIGHTLEFT
2362 #undef RIGHT
2363 #undef LEFT
2364 }
2365
isOpCode(const ExpressionTreeNode & node,std::initializer_list<ExprOpType> types)2366 bool isOpCode(const ExpressionTreeNode &node, std::initializer_list<ExprOpType> types)
2367 {
2368 for (ExprOpType type : types) {
2369 if (node.op.type == type)
2370 return true;
2371 }
2372 return false;
2373 }
2374
// True when x has no fractional part (NaN is never an integer).
bool isInteger(float x)
{
    const float roundedDown = std::floor(x);
    return roundedDown == x;
}
2379
replaceNode(ExpressionTreeNode & node,const ExpressionTreeNode & replacement)2380 void replaceNode(ExpressionTreeNode &node, const ExpressionTreeNode &replacement)
2381 {
2382 node.op = replacement.op;
2383 node.setLeft(replacement.left);
2384 node.setRight(replacement.right);
2385 }
2386
swapNodeContents(ExpressionTreeNode & lhs,ExpressionTreeNode & rhs)2387 void swapNodeContents(ExpressionTreeNode &lhs, ExpressionTreeNode &rhs)
2388 {
2389 std::swap(lhs, rhs);
2390 std::swap(lhs.parent, rhs.parent);
2391 }
2392
applyValueNumbering(ExpressionTree & tree)2393 void applyValueNumbering(ExpressionTree &tree)
2394 {
2395 std::vector<ExpressionTreeNode *> numbered;
2396 int valueNum = 0;
2397
2398 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
2399 {
2400 node.valueNum = -1;
2401 });
2402
2403 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
2404 {
2405 if (node.op.type == ExprOpType::MUX)
2406 return;
2407
2408 for (ExpressionTreeNode *testnode : numbered) {
2409 if (equalSubTree(&node, testnode)) {
2410 node.valueNum = testnode->valueNum;
2411 return;
2412 }
2413 }
2414
2415 node.valueNum = valueNum++;
2416 numbered.push_back(&node);
2417 });
2418 }
2419
emitIntegerPow(ExpressionTree & tree,const ExpressionTreeNode & node,int exponent)2420 ExpressionTreeNode *emitIntegerPow(ExpressionTree &tree, const ExpressionTreeNode &node, int exponent)
2421 {
2422 if (exponent == 1)
2423 return tree.clone(&node);
2424
2425 ExpressionTreeNode *mulNode = tree.makeNode({ ExprOpType::MUL });
2426 mulNode->setLeft(emitIntegerPow(tree, node, (exponent + 1) / 2));
2427 mulNode->setRight(emitIntegerPow(tree, node, exponent - (exponent + 1) / 2));
2428 return mulNode;
2429 }
2430
2431 typedef std::unordered_map<int, const ExpressionTreeNode *> ValueIndex;
2432
2433 class ExponentMap {
2434 struct CanonicalCompare {
2435 const ValueIndex &index;
2436
operator ()__anon2a0524d50111::ExponentMap::CanonicalCompare2437 bool operator()(const std::pair<int, float> &lhs, const std::pair<int, float> &rhs) const
2438 {
2439 const std::initializer_list<ExprOpType> memOpCodes = { ExprOpType::MEM_LOAD_U8, ExprOpType::MEM_LOAD_U16, ExprOpType::MEM_LOAD_F16, ExprOpType::MEM_LOAD_F32 };
2440
2441 // Order equivalent terms by exponent.
2442 if (lhs.first == rhs.first)
2443 return lhs.second < rhs.second;
2444
2445 const ExpressionTreeNode *lhsNode = index.at(lhs.first);
2446 const ExpressionTreeNode *rhsNode = index.at(rhs.first);
2447
2448 // Ordering: complex values, memory, constants
2449 int lhsCategory = isConstant(*lhsNode) ? 2 : isOpCode(*lhsNode, memOpCodes) ? 1 : 0;
2450 int rhsCategory = isConstant(*rhsNode) ? 2 : isOpCode(*rhsNode, memOpCodes) ? 1 : 0;
2451
2452 if (lhsCategory != rhsCategory)
2453 return lhsCategory < rhsCategory;
2454
2455 // Ordering criteria for each category:
2456 //
2457 // constants: order by value
2458 // memory: order by variable name
2459 // other: order by value number (unstable)
2460 if (lhsCategory == 2)
2461 return lhsNode->op.imm.f < rhsNode->op.imm.f;
2462 else if (lhsCategory == 1)
2463 return lhsNode->op.imm.u < rhsNode->op.imm.f;
2464 else
2465 return lhs.first < rhs.first;
2466 };
2467 };
2468
2469 // e.g. 3 * v0^2 * v1^3
2470 // map = { 0: 2, 1: 3 }, coeff = 3
2471 std::map<int, float> map; // key = valueNum, value = exponent
2472 std::vector<int> origSequence;
2473 float coeff;
2474
expandOrigSequence(ValueIndex & index)2475 bool expandOrigSequence(ValueIndex &index)
2476 {
2477 bool changed = false;
2478
2479 for (size_t i = 0; i < origSequence.size(); ++i) {
2480 const ExpressionTreeNode *value = index.at(origSequence[i]);
2481
2482 if (value->op == ExprOpType::POW && isConstant(*value->right)) {
2483 origSequence[i] = value->left->valueNum;
2484 changed = true;
2485 } else if (value->op == ExprOpType::MUL || value->op == ExprOpType::DIV) {
2486 origSequence[i] = value->left->valueNum;
2487 origSequence.insert(origSequence.begin() + i + 1, value->right->valueNum);
2488 changed = true;
2489 }
2490 }
2491
2492 return changed;
2493 }
2494
expandOnePass(ValueIndex & index)2495 bool expandOnePass(ValueIndex &index)
2496 {
2497 bool changed = false;
2498
2499 for (auto it = map.begin(); it != map.end();) {
2500 const ExpressionTreeNode *value = index.at(it->first);
2501 bool erase = false;
2502
2503 if (value->op == ExprOpType::POW && isConstant(*value->right)) {
2504 index[value->left->valueNum] = value->left;
2505
2506 map[value->left->valueNum] += it->second * value->right->op.imm.f;
2507 erase = true;
2508 } else if (value->op == ExprOpType::MUL) {
2509 index[value->left->valueNum] = value->left;
2510 index[value->right->valueNum] = value->right;
2511
2512 map[value->left->valueNum] += it->second;
2513 map[value->right->valueNum] += it->second;
2514 erase = true;
2515 } else if (value->op == ExprOpType::DIV) {
2516 index[value->left->valueNum] = value->left;
2517 index[value->right->valueNum] = value->right;
2518
2519 map[value->left->valueNum] += it->second;
2520 map[value->right->valueNum] -= it->second;
2521 erase = true;
2522 }
2523
2524 if (erase) {
2525 it = map.erase(it);
2526 changed = true;
2527 continue;
2528 }
2529
2530 ++it;
2531 }
2532
2533 return changed;
2534 }
2535
combineConstants(const ValueIndex & index)2536 void combineConstants(const ValueIndex &index)
2537 {
2538 for (auto it = map.begin(); it != map.end();) {
2539 const ExpressionTreeNode *node = index.at(it->first);
2540 if (isConstant(*node)) {
2541 coeff *= std::pow(node->op.imm.f, it->second);
2542 it = map.erase(it);
2543 continue;
2544 }
2545 ++it;
2546 }
2547 }
2548 public:
ExponentMap()2549 ExponentMap() : coeff(1.0f) {}
2550
addTerm(int valueNum,float exp)2551 void addTerm(int valueNum, float exp)
2552 {
2553 map[valueNum] += exp;
2554 origSequence.push_back(valueNum);
2555 }
2556
addCoeff(float val)2557 void addCoeff(float val) { coeff += val; }
2558
mulCoeff(float val)2559 void mulCoeff(float val) { coeff *= val; }
2560
getCoeff() const2561 float getCoeff() const { return coeff; }
2562
isScalar() const2563 bool isScalar() const { return map.empty(); }
2564
numTerms() const2565 size_t numTerms() const { return map.size() + 1; }
2566
isSameTerm(const ExponentMap & other) const2567 bool isSameTerm(const ExponentMap &other) const
2568 {
2569 auto it1 = map.begin();
2570 auto it2 = other.map.begin();
2571
2572 while (it1 != map.end() && it2 != other.map.end()) {
2573 if (it1->first != it2->first || it1->second != it2->second)
2574 return false;
2575
2576 ++it1;
2577 ++it2;
2578 }
2579
2580 return it1 == map.end() && it2 == other.map.end();
2581 }
2582
expand(ValueIndex & index)2583 void expand(ValueIndex &index)
2584 {
2585 while (expandOnePass(index)) {
2586 // ...
2587 }
2588 combineConstants(index);
2589
2590 while (expandOrigSequence(index)) {
2591 // ...
2592 }
2593 }
2594
isCanonical(const ValueIndex & index) const2595 bool isCanonical(const ValueIndex &index) const
2596 {
2597 std::vector<std::pair<int, float>> tmp;
2598 for (int x : origSequence) {
2599 tmp.push_back({ x, 1.0f });
2600 }
2601 return std::is_sorted(tmp.begin(), tmp.end(), CanonicalCompare{ index });
2602 }
2603
emit(ExpressionTree & tree,const ValueIndex & index) const2604 ExpressionTreeNode *emit(ExpressionTree &tree, const ValueIndex &index) const
2605 {
2606 std::vector<std::pair<int, float>> flat(map.begin(), map.end());
2607 std::sort(flat.begin(), flat.end(), CanonicalCompare{ index });
2608
2609 ExpressionTreeNode *node = nullptr;
2610
2611 for (auto &term : flat) {
2612 ExpressionTreeNode *powNode = tree.makeNode(ExprOpType::POW);
2613 powNode->setLeft(tree.clone(index.at(term.first)));
2614 powNode->setRight(tree.makeNode({ ExprOpType::CONSTANT, term.second }));
2615
2616 if (node) {
2617 ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
2618 mulNode->setLeft(node);
2619 mulNode->setRight(powNode);
2620 node = mulNode;
2621 } else {
2622 node = powNode;
2623 }
2624 }
2625
2626 if (node) {
2627 ExpressionTreeNode *mulNode = tree.makeNode(ExprOpType::MUL);
2628 mulNode->setLeft(node);
2629 mulNode->setRight(tree.makeNode({ ExprOpType::CONSTANT, coeff }));
2630 node = mulNode;
2631 } else {
2632 node = tree.makeNode({ ExprOpType::CONSTANT, coeff });
2633 }
2634
2635 return node;
2636 }
2637
canonicalOrder(const ExponentMap & other,const ValueIndex & index) const2638 bool canonicalOrder(const ExponentMap &other, const ValueIndex &index) const
2639 {
2640 // Convert map to flat array, as canonical order is different from value numbering.
2641 std::vector<std::pair<int, float>> lhsFlat(map.begin(), map.end());
2642 std::vector<std::pair<int, float>> rhsFlat(other.map.begin(), other.map.end());
2643
2644 CanonicalCompare pred{ index };
2645 std::sort(lhsFlat.begin(), lhsFlat.end(), pred);
2646 std::sort(rhsFlat.begin(), rhsFlat.end(), pred);
2647 return std::lexicographical_compare(lhsFlat.begin(), lhsFlat.end(), rhsFlat.begin(), rhsFlat.end(), pred);
2648 }
2649 };
2650
2651 class AdditiveSequence {
2652 std::vector<ExponentMap> terms;
2653 float scalarTerm;
2654 public:
AdditiveSequence()2655 AdditiveSequence() : scalarTerm() {}
2656
addTerm(int valueNum,int sign)2657 void addTerm(int valueNum, int sign)
2658 {
2659 ExponentMap map;
2660 map.addTerm(valueNum, 1.0f);
2661 map.mulCoeff(static_cast<float>(sign));
2662 terms.push_back(std::move(map));
2663 }
2664
numTerms() const2665 size_t numTerms() const { return terms.size() + 1; }
2666
expand(ValueIndex & index)2667 void expand(ValueIndex &index)
2668 {
2669 for (auto &term : terms) {
2670 term.expand(index);
2671 }
2672
2673 for (auto it = terms.begin(); it != terms.end();) {
2674 if (it->isScalar()) {
2675 scalarTerm += it->getCoeff();
2676 it = terms.erase(it);
2677 continue;
2678 }
2679
2680 ++it;
2681 }
2682
2683 for (auto it1 = terms.begin(); it1 != terms.end();) {
2684 for (auto it2 = it1 + 1; it2 != terms.end(); ++it2) {
2685 if (it1->isSameTerm(*it2)) {
2686 it1->addCoeff(it2->getCoeff());
2687 it2->mulCoeff(0.0f);
2688 }
2689 }
2690
2691 if (it1->getCoeff() == 0.0f) {
2692 it1 = terms.erase(it1);
2693 continue;
2694 }
2695
2696 ++it1;
2697 }
2698 }
2699
canonicalize(const ValueIndex & index)2700 bool canonicalize(const ValueIndex &index)
2701 {
2702 auto pred = [&](const ExponentMap &lhs, const ExponentMap &rhs)
2703 {
2704 return lhs.canonicalOrder(rhs, index);
2705 };
2706
2707 if (std::is_sorted(terms.begin(), terms.end(), pred))
2708 return true;
2709
2710 std::sort(terms.begin(), terms.end(), pred);
2711 return false;
2712 }
2713
emit(ExpressionTree & tree,const ValueIndex & index) const2714 ExpressionTreeNode *emit(ExpressionTree &tree, const ValueIndex &index) const
2715 {
2716 ExpressionTreeNode *head = nullptr;
2717
2718 for (const auto &term : terms) {
2719 ExpressionTreeNode *node = term.emit(tree, index);
2720
2721 if (head) {
2722 ExpressionTreeNode *addNode = tree.makeNode(ExprOpType::ADD);
2723 addNode->setLeft(head);
2724 addNode->setRight(node);
2725 head = addNode;
2726 } else {
2727 head = node;
2728 }
2729 }
2730
2731 if (head) {
2732 ExpressionTreeNode *addNode = tree.makeNode(scalarTerm < 0 ? ExprOpType::SUB : ExprOpType::ADD);
2733 addNode->setLeft(head);
2734 addNode->setRight(tree.makeNode({ ExprOpType::CONSTANT, std::fabs(scalarTerm) }));
2735 head = addNode;
2736 } else {
2737 head = tree.makeNode({ ExprOpType::CONSTANT, 0.0f });
2738 }
2739
2740 return head;
2741 }
2742 };
2743
analyzeAdditiveExpression(ExpressionTree & tree,ExpressionTreeNode & node)2744 bool analyzeAdditiveExpression(ExpressionTree &tree, ExpressionTreeNode &node)
2745 {
2746 size_t origNumTerms = 0;
2747 AdditiveSequence expr;
2748 ValueIndex index;
2749
2750 node.preorder([&](ExpressionTreeNode &node)
2751 {
2752 if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }))
2753 return false;
2754
2755 // Deduce net sign of term.
2756 const ExpressionTreeNode *parent = node.parent;
2757 const ExpressionTreeNode *cur = &node;
2758 int polarity = 1;
2759
2760 while (parent && isOpCode(*parent, { ExprOpType::ADD, ExprOpType::SUB })) {
2761 if (parent->op == ExprOpType::SUB && cur == parent->right)
2762 polarity = -polarity;
2763
2764 cur = parent;
2765 parent = parent->parent;
2766 }
2767
2768 ++origNumTerms;
2769 expr.addTerm(node.valueNum, polarity);
2770 index[node.valueNum] = &node;
2771 return true;
2772 });
2773
2774 expr.expand(index);
2775 bool canonical = expr.canonicalize(index);
2776
2777 if (expr.numTerms() < origNumTerms || !canonical) {
2778 ExpressionTreeNode *seq = expr.emit(tree, index);
2779 replaceNode(node, *seq);
2780 return true;
2781 }
2782
2783 return false;
2784 }
2785
analyzeMultiplicativeExpression(ExpressionTree & tree,ExpressionTreeNode & node)2786 bool analyzeMultiplicativeExpression(ExpressionTree &tree, ExpressionTreeNode &node)
2787 {
2788 std::unordered_map<int, const ExpressionTreeNode *> index;
2789
2790 ExponentMap expr;
2791 size_t origNumTerms = 0;
2792 size_t numDivs = 0;
2793
2794 node.preorder([&](ExpressionTreeNode &node)
2795 {
2796 if (node.op == ExprOpType::DIV)
2797 ++numDivs;
2798
2799 if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }))
2800 return false;
2801
2802 // Deduce net sign of term.
2803 const ExpressionTreeNode *parent = node.parent;
2804 const ExpressionTreeNode *cur = &node;
2805 int polarity = 1;
2806
2807 while (parent && isOpCode(*parent, { ExprOpType::MUL, ExprOpType::DIV })) {
2808 if (parent->op == ExprOpType::DIV && cur == parent->right)
2809 polarity = -polarity;
2810
2811 cur = parent;
2812 parent = parent->parent;
2813 }
2814
2815 expr.addTerm(node.valueNum, static_cast<float>(polarity));
2816 index[node.valueNum] = &node;
2817 ++origNumTerms;
2818 return true;
2819 });
2820
2821 expr.expand(index);
2822
2823 if (expr.numTerms() < origNumTerms || !expr.isCanonical(index) || numDivs) {
2824 ExpressionTreeNode *seq = expr.emit(tree, index);
2825 replaceNode(node, *seq);
2826 return true;
2827 }
2828
2829 return false;
2830 }
2831
applyAlgebraicOptimizations(ExpressionTree & tree)2832 bool applyAlgebraicOptimizations(ExpressionTree &tree)
2833 {
2834 bool changed = false;
2835
2836 applyValueNumbering(tree);
2837
2838 tree.getRoot()->preorder([&](ExpressionTreeNode &node)
2839 {
2840 if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::ADD, ExprOpType::SUB }))) {
2841 changed = changed || analyzeAdditiveExpression(tree, node);
2842 return changed;
2843 }
2844
2845 if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && (!node.parent || !isOpCode(*node.parent, { ExprOpType::MUL, ExprOpType::DIV }))) {
2846 changed = changed || analyzeMultiplicativeExpression(tree, node);
2847 return changed;
2848 }
2849
2850 return false;
2851 });
2852
2853 return changed;
2854 }
2855
applyComparisonOptimizations(ExpressionTree & tree)2856 bool applyComparisonOptimizations(ExpressionTree &tree)
2857 {
2858 bool changed = false;
2859
2860 applyValueNumbering(tree);
2861
2862 tree.getRoot()->preorder([&](ExpressionTreeNode &node)
2863 {
2864 // Eliminate constant conditions.
2865 if (node.op.type == ExprOpType::CMP && node.left->valueNum == node.right->valueNum) {
2866 ComparisonType type = static_cast<ComparisonType>(node.op.imm.u);
2867 if (type == ComparisonType::EQ || type == ComparisonType::LE || type == ComparisonType::NLT)
2868 replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
2869 else
2870 replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
2871
2872 changed = true;
2873 return changed;
2874 }
2875
2876 // Eliminate identical branches.
2877 if (node.op == ExprOpType::TERNARY && node.right->left->valueNum == node.right->right->valueNum) {
2878 replaceNode(node, *node.right->left);
2879 changed = true;
2880 return changed;
2881 }
2882
2883 // MIN/MAX detection.
2884 if (node.op == ExprOpType::TERNARY && node.left->op.type == ExprOpType::CMP) {
2885 ComparisonType type = static_cast<ComparisonType>(node.left->op.imm.u);
2886 int cmpTerms[2] = { node.left->left->valueNum, node.left->right->valueNum };
2887 int muxTerms[2] = { node.right->left->valueNum, node.right->right->valueNum };
2888
2889 bool isSameTerms = (cmpTerms[0] == muxTerms[0] && cmpTerms[1] == muxTerms[1]) || (cmpTerms[0] == muxTerms[1] && cmpTerms[1] == muxTerms[0]);
2890 bool isLessOrGreater = type == ComparisonType::LT || type == ComparisonType::LE || type == ComparisonType::NLE || type == ComparisonType::NLT;
2891
2892 if (isSameTerms && isLessOrGreater) {
2893 // a < b ? a : b --> min(a, b) a > b ? b : a --> min(a, b)
2894 // a > b ? a : b --> max(a, b) a < b ? b : a --> max(a, b)
2895 bool min = (type == ComparisonType::LT || type == ComparisonType::LE) ? cmpTerms[0] == muxTerms[0] : cmpTerms[0] != muxTerms[0];
2896 ExpressionTreeNode *a = node.left->left;
2897 ExpressionTreeNode *b = node.left->right;
2898
2899 replaceNode(node, ExpressionTreeNode{ min ? ExprOpType::MIN : ExprOpType::MAX });
2900 node.setLeft(a);
2901 node.setRight(b);
2902
2903 changed = true;
2904 return changed;
2905 }
2906 }
2907
2908 // CMP to SUB conversion. It has lower priority than other comparison transformations.
2909 if (node.op.type == ExprOpType::CMP && node.parent && isOpCode(*node.parent, { ExprOpType::AND, ExprOpType::OR, ExprOpType::XOR, ExprOpType::TERNARY })) {
2910 ComparisonType type = static_cast<ComparisonType>(node.op.imm.u);
2911
2912 // a < b --> b - a a > b --> a - b
2913 if (type == ComparisonType::LT || type == ComparisonType::NLE) {
2914 if (type == ComparisonType::LT)
2915 std::swap(node.left, node.right);
2916
2917 node.op = ExprOpType::SUB;
2918 changed = true;
2919 return changed;
2920 }
2921 }
2922
2923 return false;
2924 });
2925
2926 return changed;
2927 }
2928
applyLocalOptimizations(ExpressionTree & tree)2929 bool applyLocalOptimizations(ExpressionTree &tree)
2930 {
2931 bool changed = false;
2932
2933 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
2934 {
2935 if (node.op.type == ExprOpType::MUX)
2936 return;
2937
2938 // Constant folding.
2939 if (node.op.type != ExprOpType::CONSTANT && isConstantExpr(node)) {
2940 float val = evalConstantExpr(node);
2941 replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, val } });
2942 changed = true;
2943 }
2944
2945 // Move constants to right-hand side to simplify identities.
2946 if (isOpCode(node, { ExprOpType::ADD, ExprOpType::MUL }) && isConstant(*node.left) && !isConstant(*node.right)) {
2947 std::swap(node.left, node.right);
2948 changed = true;
2949 }
2950
2951 // x * 0 = 0 0 / x = 0
2952 if ((node.op == ExprOpType::MUL && isConstant(*node.right, 0.0f)) || (node.op == ExprOpType::DIV && isConstant(*node.left, 0.0f))) {
2953 replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 0.0f } });
2954 changed = true;
2955 }
2956
2957 // sqrt(x) = x ** 0.5
2958 if (node.op == ExprOpType::SQRT) {
2959 node.op = ExprOpType::POW;
2960 node.setRight(tree.makeNode({ ExprOpType::CONSTANT, 0.5f }));
2961 changed = true;
2962 }
2963
2964 // log(exp(x)) = x exp(log(x)) = x
2965 if ((node.op == ExprOpType::LOG && node.left->op == ExprOpType::EXP) || (node.op == ExprOpType::EXP && node.left->op == ExprOpType::LOG)) {
2966 replaceNode(node, *node.left->left);
2967 changed = true;
2968 }
2969
2970 // x ** 0 = 1
2971 if (node.op == ExprOpType::POW && isConstant(*node.right, 0.0f)) {
2972 replaceNode(node, ExpressionTreeNode{ { ExprOpType::CONSTANT, 1.0f } });
2973 changed = true;
2974 }
2975
2976 // (a ** b) ** c = a ** (b * c)
2977 if (node.op == ExprOpType::POW && node.left->op == ExprOpType::POW) {
2978 ExpressionTreeNode *a = node.left->left;
2979 ExpressionTreeNode *b = node.left->right;
2980 ExpressionTreeNode *c = node.right;
2981 replaceNode(*node.left, *a);
2982 node.setRight(tree.makeNode(ExprOpType::MUL));
2983 node.right->setLeft(b);
2984 node.right->setRight(c);
2985 changed = true;
2986 }
2987
2988 // 0 ? x : y = y 1 ? x : y = x
2989 if (node.op == ExprOpType::TERNARY && isConstant(*node.left)) {
2990 ExpressionTreeNode *replacement = node.left->op.imm.f > 0.0f ? node.right->left : node.right->right;
2991 replaceNode(node, *replacement);
2992 changed = true;
2993 }
2994
2995 // a <= b ? x : y --> a > b ? y : x a >= b ? x : y --> a < b ? y : x
2996 if (node.op == ExprOpType::TERNARY && node.left->op.type == ExprOpType::CMP) {
2997 ComparisonType type = static_cast<ComparisonType>(node.left->op.imm.u);
2998
2999 if (type == ComparisonType::LE || type == ComparisonType::NLT) {
3000 node.left->op.imm.u = static_cast<unsigned>(type == ComparisonType::LE ? ComparisonType::NLE : ComparisonType::LT);
3001 std::swap(node.right->left, node.right->right);
3002 changed = true;
3003 }
3004 }
3005
3006 // !a ? b : c --> a ? c : b
3007 if (node.op == ExprOpType::TERNARY && node.left->op == ExprOpType::NOT) {
3008 replaceNode(*node.left, *node.left->left);
3009 std::swap(node.right->left, node.right->right);
3010 changed = true;
3011 }
3012
3013 // !(a < b) --> a >= b
3014 if (node.op == ExprOpType::NOT && node.left->op.type == ExprOpType::CMP) {
3015 switch (static_cast<ComparisonType>(node.left->op.imm.u)) {
3016 case ComparisonType::EQ: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::NEQ); break;
3017 case ComparisonType::LT: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::NLT); break;
3018 case ComparisonType::LE: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::NLE); break;
3019 case ComparisonType::NEQ: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::EQ); break;
3020 case ComparisonType::NLT: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::LT); break;
3021 case ComparisonType::NLE: node.left->op.imm.u = static_cast<unsigned>(ComparisonType::LE); break;
3022 }
3023 replaceNode(node, *node.left);
3024 changed = true;
3025 }
3026 });
3027
3028 return changed;
3029 }
3030
applyAlgebraicCleanup(ExpressionTree & tree)3031 bool applyAlgebraicCleanup(ExpressionTree &tree)
3032 {
3033 bool changed = false;
3034
3035 // Prune extra terms introduced by the algebraic analysis. These need to run in a later pass to prevent cycles.
3036 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3037 {
3038 // x + 0 = x x - 0 = x
3039 if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && isConstant(*node.right, 0.0f)) {
3040 replaceNode(node, *node.left);
3041 changed = true;
3042 }
3043
3044 // x * 1 = x x / 1 = x
3045 if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, 1.0f)) {
3046 replaceNode(node, *node.left);
3047 changed = true;
3048 }
3049
3050 // x ** 1 = x
3051 if (node.op == ExprOpType::POW && isConstant(*node.right, 1.0f)) {
3052 replaceNode(node, *node.left);
3053 changed = true;
3054 }
3055 });
3056
3057 return changed;
3058 }
3059
3060
applyStrengthReduction(ExpressionTree & tree)3061 bool applyStrengthReduction(ExpressionTree &tree)
3062 {
3063 bool changed = false;
3064
3065 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3066 {
3067 if (node.op == ExprOpType::MUX)
3068 return;
3069
3070 // 0 - x = -x
3071 if (node.op == ExprOpType::SUB && isConstant(*node.left, 0.0f)) {
3072 ExpressionTreeNode *tmp = node.right;
3073 replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
3074 node.setLeft(tmp);
3075 changed = true;
3076 }
3077
3078 // x * -1 = -x x / -1 = -x
3079 if (isOpCode(node, { ExprOpType::MUL, ExprOpType::DIV }) && isConstant(*node.right, -1.0f)) {
3080 ExpressionTreeNode *tmp = node.left;
3081 replaceNode(node, ExpressionTreeNode{ { ExprOpType::NEG } });
3082 node.setLeft(tmp);
3083 changed = true;
3084 }
3085
3086 // a + -b = a - b a - -b = a + b
3087 if (isOpCode(node, { ExprOpType::ADD, ExprOpType::SUB }) && node.right->op.type == ExprOpType::NEG) {
3088 node.op = node.op == ExprOpType::ADD ? ExprOpType::SUB : ExprOpType::ADD;
3089 replaceNode(*node.right, *node.right->left);
3090 changed = true;
3091 }
3092
3093 // -a + b = b - a
3094 if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::NEG) {
3095 node.op = ExprOpType::SUB;
3096 replaceNode(*node.left, *node.left->left);
3097 std::swap(node.left, node.right);
3098 }
3099
3100 // -(a - b) = b - a
3101 if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::SUB) {
3102 replaceNode(node, *node.left);
3103 std::swap(node.left, node.right);
3104 changed = true;
3105 }
3106
3107 // x * 2 = x + x
3108 if (node.op == ExprOpType::MUL && isConstant(*node.right, 2.0f) && (!node.parent || node.parent->op != ExprOpType::ADD)) {
3109 ExpressionTreeNode *replacement = tree.clone(node.left);
3110 node.op = ExprOpType::ADD;
3111 replaceNode(*node.right, *replacement);
3112 changed = true;
3113 }
3114
3115 // x / y = x * (1 / y)
3116 if (node.op == ExprOpType::DIV && isConstant(*node.right)) {
3117 node.op = ExprOpType::MUL;
3118 node.right->op.imm.f = 1.0f / node.right->op.imm.f;
3119 changed = true;
3120 }
3121
3122 // (1 / x) * y = y / x
3123 if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV && isConstant(*node.left->left, 1.0f)) {
3124 node.op = ExprOpType::DIV;
3125 replaceNode(*node.left, *node.left->right);
3126 std::swap(node.left, node.right);
3127 changed = true;
3128 }
3129
3130 // x * (1 / y) = x / y
3131 if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV && isConstant(*node.right->left, 1.0f)) {
3132 node.op = ExprOpType::DIV;
3133 replaceNode(*node.right, *node.right->right);
3134 changed = true;
3135 }
3136
3137 // (a / b) * c = (a * c) / b
3138 if (node.op == ExprOpType::MUL && node.left->op == ExprOpType::DIV) {
3139 node.op = ExprOpType::DIV;
3140 node.left->op = ExprOpType::MUL;
3141 swapNodeContents(*node.left->right, *node.right);
3142 changed = true;
3143 }
3144
3145 // a * (b / c) = (a * b) / c
3146 if (node.op == ExprOpType::MUL && node.right->op == ExprOpType::DIV) {
3147 node.op = ExprOpType::DIV;
3148 node.right->op = ExprOpType::MUL;
3149 std::swap(node.left, node.right); // (b * c) / a
3150 swapNodeContents(*node.left->left, *node.left->right); // (c * b) / a
3151 swapNodeContents(*node.left->left, *node.right); // (a * b) / c
3152 changed = true;
3153 }
3154
3155 // a / (b / c) = (a * c) / b
3156 if (node.op == ExprOpType::DIV && node.right->op == ExprOpType::DIV) {
3157 node.right->op = ExprOpType::MUL; // a / (b * c)
3158 std::swap(node.left, node.right); // (b * c) / a
3159 swapNodeContents(*node.left->left, *node.right); // (a * c) / b
3160 changed = true;
3161 }
3162
3163 // (a / b) / c = a / (b * c)
3164 if (node.op == ExprOpType::DIV && node.left->op == ExprOpType::DIV) {
3165 node.left->op = ExprOpType::MUL; // (a * b) / c
3166 std::swap(node.left, node.right); // c / (a * b)
3167 swapNodeContents(*node.left, *node.right->left); // a / (c * b)
3168 swapNodeContents(*node.right->left, *node.right->right); // a / (b * c)
3169 changed = true;
3170 }
3171
3172 // x ** (n / 2) = sqrt(x ** n)
3173 if (node.op == ExprOpType::POW && isConstant(*node.right) && !isInteger(node.right->op.imm.f) && isInteger(node.right->op.imm.f * 2.0f)) {
3174 ExpressionTreeNode *dup = tree.clone(&node);
3175 replaceNode(node, ExpressionTreeNode{ ExprOpType::SQRT });
3176 node.setLeft(dup);
3177 node.left->right->op.imm.f *= 2.0f;
3178 changed = true;
3179 }
3180
3181 // x ** -N = 1 / (x ** N)
3182 if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f < 0) {
3183 ExpressionTreeNode *dup = tree.clone(&node);
3184 replaceNode(node, ExpressionTreeNode{ ExprOpType::DIV });
3185 node.setLeft(tree.makeNode({ ExprOpType::CONSTANT, 1.0f }));
3186 node.setRight(dup);
3187 node.right->right->op.imm.f = -node.right->right->op.imm.f;
3188 changed = true;
3189 }
3190
3191 // x ** N = x * x * x * ...
3192 if (node.op == ExprOpType::POW && isConstant(*node.right) && isInteger(node.right->op.imm.f) && node.right->op.imm.f > 0) {
3193 ExpressionTreeNode *replacement = emitIntegerPow(tree, *node.left, static_cast<int>(node.right->op.imm.f));
3194 replaceNode(node, *replacement);
3195 changed = true;
3196 }
3197 });
3198
3199 return changed;
3200 }
3201
applyOpFusion(ExpressionTree & tree)3202 bool applyOpFusion(ExpressionTree &tree)
3203 {
3204 std::unordered_map<int, size_t> refCount;
3205 bool changed = false;
3206
3207 applyValueNumbering(tree);
3208
3209 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3210 {
3211 if (node.op == ExprOpType::MUX)
3212 return;
3213
3214 refCount[node.valueNum]++;
3215 });
3216
3217 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3218 {
3219 if (node.op == ExprOpType::MUX)
3220 return;
3221
3222 auto canElide = [&](ExpressionTreeNode &candidate)
3223 {
3224 return refCount[node.valueNum] > 1 || refCount[candidate.valueNum] <= 1;
3225 };
3226
3227 // a + (b * c) (b * c) + a a - (b * c) (b * c) - a
3228 if (node.op == ExprOpType::ADD && node.right->op == ExprOpType::MUL && canElide(*node.right)) {
3229 node.right->op = ExprOpType::MUX;
3230 node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
3231 changed = true;
3232 }
3233 if (node.op == ExprOpType::ADD && node.left->op == ExprOpType::MUL && canElide(*node.left)) {
3234 std::swap(node.left, node.right);
3235 node.right->op = ExprOpType::MUX;
3236 node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMADD) };
3237 changed = true;
3238 }
3239 if (node.op == ExprOpType::SUB && node.right->op == ExprOpType::MUL && canElide(*node.right)) {
3240 node.right->op = ExprOpType::MUX;
3241 node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FNMADD) };
3242 changed = true;
3243 }
3244 if (node.op == ExprOpType::SUB && node.left->op == ExprOpType::MUL && canElide(*node.left)) {
3245 std::swap(node.left, node.right);
3246 node.right->op = ExprOpType::MUX;
3247 node.op = { ExprOpType::FMA, static_cast<unsigned>(FMAType::FMSUB) };
3248 changed = true;
3249 }
3250
3251 // (a + b) * c = (a * c) + b * c
3252 if (node.op == ExprOpType::MUL && isOpCode(*node.left, { ExprOpType::ADD, ExprOpType::SUB }) &&
3253 isConstant(*node.right) && isConstant(*node.left->right) && canElide(*node.left))
3254 {
3255 std::swap(node.op, node.left->op);
3256 swapNodeContents(*node.right, *node.left->right);
3257 node.right->op.imm.f *= node.left->right->op.imm.f;
3258 changed = true;
3259 }
3260
3261 // Negative FMA.
3262 if (node.op == ExprOpType::NEG && node.left->op == ExprOpType::FMA && canElide(*node.left)) {
3263 replaceNode(node, *node.left);
3264
3265 switch (static_cast<FMAType>(node.op.imm.u)) {
3266 case FMAType::FMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FNMSUB); break;
3267 case FMAType::FMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FNMADD); break;
3268 case FMAType::FNMADD: node.op.imm.u = static_cast<unsigned>(FMAType::FMSUB); break;
3269 case FMAType::FNMSUB: node.op.imm.u = static_cast<unsigned>(FMAType::FMADD); break;
3270 }
3271
3272 changed = true;
3273 }
3274 });
3275
3276 return changed;
3277 }
3278
renameRegisters(std::vector<ExprInstruction> & code)3279 void renameRegisters(std::vector<ExprInstruction> &code)
3280 {
3281 std::unordered_map<int, int> table;
3282 std::set<int> freeList;
3283
3284 for (size_t i = 0; i < code.size(); ++i) {
3285 ExprInstruction &insn = code[i];
3286 int origRegs[4] = { insn.dst, insn.src1, insn.src2, insn.src3 };
3287 int renamed[4] = { insn.dst, insn.src1, insn.src2, insn.src3 };
3288
3289 for (int n = 1; n < 4; ++n) {
3290 if (origRegs[n] < 0)
3291 continue;
3292
3293 auto it = table.find(origRegs[n]);
3294 if (it != table.end())
3295 renamed[n] = it->second;
3296
3297 bool dead = true;
3298
3299 for (size_t j = i + 1; j < code.size(); ++j) {
3300 const ExprInstruction &insn2 = code[j];
3301 if (insn2.src1 == origRegs[n] || insn2.src2 == origRegs[n] || insn2.src3 == origRegs[n]) {
3302 dead = false;
3303 break;
3304 }
3305 }
3306
3307 if (dead)
3308 freeList.insert(renamed[n]);
3309 }
3310
3311 if (origRegs[0] >= 0 && !freeList.empty()) {
3312 renamed[0] = *freeList.begin();
3313 table[origRegs[0]] = renamed[0];
3314 freeList.erase(freeList.begin());
3315 freeList.insert(origRegs[0]);
3316 }
3317
3318 insn.dst = renamed[0];
3319 insn.src1 = renamed[1];
3320 insn.src2 = renamed[2];
3321 insn.src3 = renamed[3];
3322 }
3323 }
3324
compile(ExpressionTree & tree,const VSFormat * format)3325 std::vector<ExprInstruction> compile(ExpressionTree &tree, const VSFormat *format)
3326 {
3327 std::vector<ExprInstruction> code;
3328 std::unordered_set<int> found;
3329
3330 if (!tree.getRoot())
3331 return code;
3332
3333 while (applyLocalOptimizations(tree) || applyAlgebraicOptimizations(tree) || applyComparisonOptimizations(tree)) {
3334 // ...
3335 }
3336
3337 while (applyAlgebraicCleanup(tree) || applyStrengthReduction(tree) || applyOpFusion(tree)) {
3338 // ...
3339 }
3340
3341 applyValueNumbering(tree);
3342
3343 tree.getRoot()->postorder([&](ExpressionTreeNode &node)
3344 {
3345 if (node.op.type == ExprOpType::MUX)
3346 return;
3347 if (found.find(node.valueNum) != found.end())
3348 return;
3349
3350 ExprInstruction opcode(node.op);
3351 opcode.dst = node.valueNum;
3352
3353 if (node.left) {
3354 assert(node.left->valueNum >= 0);
3355 opcode.src1 = node.left->valueNum;
3356 }
3357 if (node.right) {
3358 if (node.right->op.type == ExprOpType::MUX) {
3359 assert(node.right->left->valueNum >= 0);
3360 assert(node.right->right->valueNum >= 0);
3361 opcode.src2 = node.right->left->valueNum;
3362 opcode.src3 = node.right->right->valueNum;
3363 } else {
3364 assert(node.right->valueNum >= 0);
3365 opcode.src2 = node.right->valueNum;
3366 }
3367 }
3368
3369 code.push_back(opcode);
3370 found.insert(node.valueNum);
3371 });
3372
3373 ExprInstruction store(ExprOpType::MEM_STORE_U8);
3374
3375 if (format->sampleType == stInteger && format->bytesPerSample == 1)
3376 store.op.type = ExprOpType::MEM_STORE_U8;
3377 else if (format->sampleType == stInteger && format->bytesPerSample == 2)
3378 store.op.type = ExprOpType::MEM_STORE_U16;
3379 else if (format->sampleType == stFloat && format->bytesPerSample == 2)
3380 store.op.type = ExprOpType::MEM_STORE_F16;
3381 else if (format->sampleType == stFloat && format->bytesPerSample == 4)
3382 store.op.type = ExprOpType::MEM_STORE_F32;
3383
3384 if (store.op.type == ExprOpType::MEM_STORE_U16)
3385 store.op.imm.u = format->bitsPerSample;
3386
3387 store.src1 = code.back().dst;
3388 code.push_back(store);
3389
3390 renameRegisters(code);
3391 return code;
3392 }
3393
exprInit(VSMap * in,VSMap * out,void ** instanceData,VSNode * node,VSCore * core,const VSAPI * vsapi)3394 static void VS_CC exprInit(VSMap *in, VSMap *out, void **instanceData, VSNode *node, VSCore *core, const VSAPI *vsapi) {
3395 ExprData *d = static_cast<ExprData *>(*instanceData);
3396 vsapi->setVideoInfo(&d->vi, 1, node);
3397 }
3398
exprGetFrame(int n,int activationReason,void ** instanceData,void ** frameData,VSFrameContext * frameCtx,VSCore * core,const VSAPI * vsapi)3399 static const VSFrameRef *VS_CC exprGetFrame(int n, int activationReason, void **instanceData, void **frameData, VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi) {
3400 ExprData *d = static_cast<ExprData *>(*instanceData);
3401 int numInputs = d->numInputs;
3402
3403 if (activationReason == arInitial) {
3404 for (int i = 0; i < numInputs; i++)
3405 vsapi->requestFrameFilter(n, d->node[i], frameCtx);
3406 } else if (activationReason == arAllFramesReady) {
3407 const VSFrameRef *src[MAX_EXPR_INPUTS] = {};
3408 for (int i = 0; i < numInputs; i++)
3409 src[i] = vsapi->getFrameFilter(n, d->node[i], frameCtx);
3410
3411 const VSFormat *fi = d->vi.format;
3412 int height = vsapi->getFrameHeight(src[0], 0);
3413 int width = vsapi->getFrameWidth(src[0], 0);
3414 int planes[3] = { 0, 1, 2 };
3415 const VSFrameRef *srcf[3] = { d->plane[0] != poCopy ? nullptr : src[0], d->plane[1] != poCopy ? nullptr : src[0], d->plane[2] != poCopy ? nullptr : src[0] };
3416 VSFrameRef *dst = vsapi->newVideoFrame2(fi, width, height, srcf, planes, src[0], core);
3417
3418 const uint8_t *srcp[MAX_EXPR_INPUTS] = {};
3419 int src_stride[MAX_EXPR_INPUTS] = {};
3420 alignas(32) intptr_t ptroffsets[((MAX_EXPR_INPUTS + 1) + 7) & ~7] = { d->vi.format->bytesPerSample * 8 };
3421
3422 for (int plane = 0; plane < d->vi.format->numPlanes; plane++) {
3423 if (d->plane[plane] != poProcess)
3424 continue;
3425
3426 for (int i = 0; i < numInputs; i++) {
3427 if (d->node[i]) {
3428 srcp[i] = vsapi->getReadPtr(src[i], plane);
3429 src_stride[i] = vsapi->getStride(src[i], plane);
3430 ptroffsets[i + 1] = vsapi->getFrameFormat(src[i])->bytesPerSample * 8;
3431 }
3432 }
3433
3434 uint8_t *dstp = vsapi->getWritePtr(dst, plane);
3435 int dst_stride = vsapi->getStride(dst, plane);
3436 int h = vsapi->getFrameHeight(dst, plane);
3437 int w = vsapi->getFrameWidth(dst, plane);
3438
3439 if (d->proc[plane]) {
3440 ExprData::ProcessLineProc proc = d->proc[plane];
3441 int niterations = (w + 7) / 8;
3442
3443 for (int i = 0; i < numInputs; i++) {
3444 if (d->node[i])
3445 ptroffsets[i + 1] = vsapi->getFrameFormat(src[i])->bytesPerSample * 8;
3446 }
3447
3448 for (int y = 0; y < h; y++) {
3449 alignas(32) uint8_t *rwptrs[((MAX_EXPR_INPUTS + 1) + 7) & ~7] = { dstp + dst_stride * y };
3450 for (int i = 0; i < numInputs; i++) {
3451 rwptrs[i + 1] = const_cast<uint8_t *>(srcp[i] + src_stride[i] * y);
3452 }
3453 proc(rwptrs, ptroffsets, niterations);
3454 }
3455 } else {
3456 ExprInterpreter interpreter(d->bytecode[plane].data(), d->bytecode[plane].size());
3457
3458 for (int y = 0; y < h; y++) {
3459 for (int x = 0; x < w; x++) {
3460 interpreter.eval(srcp, dstp, x);
3461 }
3462
3463 for (int i = 0; i < numInputs; i++) {
3464 srcp[i] += src_stride[i];
3465 }
3466 dstp += dst_stride;
3467 }
3468 }
3469 }
3470
3471 for (int i = 0; i < MAX_EXPR_INPUTS; i++) {
3472 vsapi->freeFrame(src[i]);
3473 }
3474 return dst;
3475 }
3476
3477 return nullptr;
3478 }
3479
exprFree(void * instanceData,VSCore * core,const VSAPI * vsapi)3480 static void VS_CC exprFree(void *instanceData, VSCore *core, const VSAPI *vsapi) {
3481 ExprData *d = static_cast<ExprData *>(instanceData);
3482 for (int i = 0; i < MAX_EXPR_INPUTS; i++)
3483 vsapi->freeNode(d->node[i]);
3484 delete d;
3485 }
3486
exprCreate(const VSMap * in,VSMap * out,void * userData,VSCore * core,const VSAPI * vsapi)3487 static void VS_CC exprCreate(const VSMap *in, VSMap *out, void *userData, VSCore *core, const VSAPI *vsapi) {
3488 std::unique_ptr<ExprData> d(new ExprData);
3489 int err;
3490
3491 #ifdef VS_TARGET_CPU_X86
3492 const CPUFeatures &f = *getCPUFeatures();
3493 # define EXPR_F16C_TEST (f.f16c)
3494 #else
3495 # define EXPR_F16C_TEST (false)
3496 #endif
3497
3498 try {
3499 d->numInputs = vsapi->propNumElements(in, "clips");
3500 if (d->numInputs > 26)
3501 throw std::runtime_error("More than 26 input clips provided");
3502
3503 for (int i = 0; i < d->numInputs; i++) {
3504 d->node[i] = vsapi->propGetNode(in, "clips", i, &err);
3505 }
3506
3507 const VSVideoInfo *vi[MAX_EXPR_INPUTS] = {};
3508 for (int i = 0; i < d->numInputs; i++) {
3509 if (d->node[i])
3510 vi[i] = vsapi->getVideoInfo(d->node[i]);
3511 }
3512
3513 for (int i = 0; i < d->numInputs; i++) {
3514 if (!isConstantFormat(vi[i]))
3515 throw std::runtime_error("Only clips with constant format and dimensions allowed");
3516 if (vi[0]->format->numPlanes != vi[i]->format->numPlanes
3517 || vi[0]->format->subSamplingW != vi[i]->format->subSamplingW
3518 || vi[0]->format->subSamplingH != vi[i]->format->subSamplingH
3519 || vi[0]->width != vi[i]->width
3520 || vi[0]->height != vi[i]->height)
3521 {
3522 throw std::runtime_error("All inputs must have the same number of planes and the same dimensions, subsampling included");
3523 }
3524
3525 if (EXPR_F16C_TEST) {
3526 if ((vi[i]->format->bitsPerSample > 16 && vi[i]->format->sampleType == stInteger)
3527 || (vi[i]->format->bitsPerSample != 16 && vi[i]->format->bitsPerSample != 32 && vi[i]->format->sampleType == stFloat))
3528 throw std::runtime_error("Input clips must be 8-16 bit integer or 16/32 bit float format");
3529 } else {
3530 if ((vi[i]->format->bitsPerSample > 16 && vi[i]->format->sampleType == stInteger)
3531 || (vi[i]->format->bitsPerSample != 32 && vi[i]->format->sampleType == stFloat))
3532 throw std::runtime_error("Input clips must be 8-16 bit integer or 32 bit float format");
3533 }
3534 }
3535
3536 d->vi = *vi[0];
3537 int format = int64ToIntS(vsapi->propGetInt(in, "format", 0, &err));
3538 if (!err) {
3539 const VSFormat *f = vsapi->getFormatPreset(format, core);
3540 if (f) {
3541 if (d->vi.format->colorFamily == cmCompat)
3542 throw std::runtime_error("No compat formats allowed");
3543 if (d->vi.format->numPlanes != f->numPlanes)
3544 throw std::runtime_error("The number of planes in the inputs and output must match");
3545 d->vi.format = vsapi->registerFormat(d->vi.format->colorFamily, f->sampleType, f->bitsPerSample, d->vi.format->subSamplingW, d->vi.format->subSamplingH, core);
3546 }
3547 }
3548
3549 int nexpr = vsapi->propNumElements(in, "expr");
3550 if (nexpr > d->vi.format->numPlanes)
3551 throw std::runtime_error("More expressions given than there are planes");
3552
3553 std::string expr[3];
3554 for (int i = 0; i < nexpr; i++) {
3555 expr[i] = vsapi->propGetData(in, "expr", i, nullptr);
3556 }
3557 for (int i = nexpr; i < 3; ++i) {
3558 expr[i] = expr[nexpr - 1];
3559 }
3560
3561 for (int i = 0; i < d->vi.format->numPlanes; i++) {
3562 if (!expr[i].empty()) {
3563 d->plane[i] = poProcess;
3564 } else {
3565 if (d->vi.format->bitsPerSample == vi[0]->format->bitsPerSample && d->vi.format->sampleType == vi[0]->format->sampleType)
3566 d->plane[i] = poCopy;
3567 else
3568 d->plane[i] = poUndefined;
3569 }
3570
3571 if (d->plane[i] != poProcess)
3572 continue;
3573
3574 auto tree = parseExpr(expr[i], vi, d->numInputs);
3575 d->bytecode[i] = compile(tree, d->vi.format);
3576
3577 int cpulevel = vs_get_cpulevel(core);
3578 if (cpulevel > VS_CPU_LEVEL_NONE) {
3579 #ifdef VS_TARGET_CPU_X86
3580 std::unique_ptr<ExprCompiler> compiler = make_compiler(d->numInputs, cpulevel);
3581 for (auto op : d->bytecode[i]) {
3582 compiler->addInstruction(op);
3583 }
3584
3585 std::tie(d->proc[i], d->procSize[i]) = compiler->getCode();
3586 #endif
3587 }
3588 }
3589 #ifdef VS_TARGET_OS_WINDOWS
3590 FlushInstructionCache(GetCurrentProcess(), nullptr, 0);
3591 #endif
3592 } catch (std::runtime_error &e) {
3593 for (int i = 0; i < MAX_EXPR_INPUTS; i++) {
3594 vsapi->freeNode(d->node[i]);
3595 }
3596 vsapi->setError(out, (std::string{ "Expr: " } + e.what()).c_str());
3597 return;
3598 }
3599
3600 vsapi->createFilter(in, out, "Expr", exprInit, exprGetFrame, exprFree, fmParallel, 0, d.release(), core);
3601 }
3602
3603 } // namespace
3604
3605
3606 //////////////////////////////////////////
3607 // Init
3608
exprInitialize(VSConfigPlugin configFunc,VSRegisterFunction registerFunc,VSPlugin * plugin)3609 void VS_CC exprInitialize(VSConfigPlugin configFunc, VSRegisterFunction registerFunc, VSPlugin *plugin) {
3610 //configFunc("com.vapoursynth.expr", "expr", "VapourSynth Expr Filter", VAPOURSYNTH_API_VERSION, 1, plugin);
3611 registerFunc("Expr", "clips:clip[];expr:data[];format:int:opt;", exprCreate, nullptr, plugin);
3612 }
3613