1 // Copyright 2017 Citra Emulator Project
2 // Licensed under GPLv2 or any later version
3 // Refer to the license.txt file included.
4 
#include <exception>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include <fmt/format.h>
#include <nihstro/shader_bytecode.h>
#include "common/assert.h"
#include "common/common_types.h"
#include "video_core/renderer_opengl/gl_shader_decompiler.h"
16 
17 namespace OpenGL::ShaderDecompiler {
18 
19 using nihstro::Instruction;
20 using nihstro::OpCode;
21 using nihstro::RegisterType;
22 using nihstro::SourceRegister;
23 using nihstro::SwizzlePattern;
24 
25 constexpr u32 PROGRAM_END = Pica::Shader::MAX_PROGRAM_CODE_LENGTH;
26 
/**
 * Thrown when a shader cannot be translated to GLSL, e.g. for an unhandled instruction, a
 * program that does not always reach END, or a recursive subroutine.
 *
 * std::runtime_error is declared in <stdexcept> (not <exception>), so that header is required.
 */
class DecompileFail : public std::runtime_error {
public:
    // Reuse std::runtime_error's constructors; the message is reported through what().
    using std::runtime_error::runtime_error;
};
31 
/// Describes how the code between an entry point and a return point can finish.
/// Undetermined must stay the zero value: a value-initialized ExitMethod starts out as it.
enum class ExitMethod {
    /// Internal value. Only occurs while analyzing a JMP loop (or a range still being scanned).
    Undetermined = 0,
    /// All code paths reach the return point.
    AlwaysReturn = 1,
    /// Code paths reach either the return point or an END instruction, depending on conditions.
    Conditional = 2,
    /// All code paths reach an END instruction.
    AlwaysEnd = 3,
};
39 
40 /// A subroutine is a range of code refereced by a CALL, IF or LOOP instruction.
41 struct Subroutine {
42     /// Generates a name suitable for GLSL source code.
GetNameOpenGL::ShaderDecompiler::Subroutine43     std::string GetName() const {
44         return "sub_" + std::to_string(begin) + "_" + std::to_string(end);
45     }
46 
47     u32 begin;              ///< Entry point of the subroutine.
48     u32 end;                ///< Return point of the subroutine.
49     ExitMethod exit_method; ///< Exit method of the subroutine.
50     std::set<u32> labels;   ///< Addresses refereced by JMP instructions.
51 
operator <OpenGL::ShaderDecompiler::Subroutine52     bool operator<(const Subroutine& rhs) const {
53         return std::tie(begin, end) < std::tie(rhs.begin, rhs.end);
54     }
55 };
56 
57 /// Analyzes shader code and produces a set of subroutines.
58 class ControlFlowAnalyzer {
59 public:
ControlFlowAnalyzer(const Pica::Shader::ProgramCode & program_code,u32 main_offset)60     ControlFlowAnalyzer(const Pica::Shader::ProgramCode& program_code, u32 main_offset)
61         : program_code(program_code) {
62 
63         // Recursively finds all subroutines.
64         const Subroutine& program_main = AddSubroutine(main_offset, PROGRAM_END);
65         if (program_main.exit_method != ExitMethod::AlwaysEnd)
66             throw DecompileFail("Program does not always end");
67     }
68 
MoveSubroutines()69     std::set<Subroutine> MoveSubroutines() {
70         return std::move(subroutines);
71     }
72 
73 private:
74     const Pica::Shader::ProgramCode& program_code;
75     std::set<Subroutine> subroutines;
76     std::map<std::pair<u32, u32>, ExitMethod> exit_method_map;
77 
78     /// Adds and analyzes a new subroutine if it is not added yet.
AddSubroutine(u32 begin,u32 end)79     const Subroutine& AddSubroutine(u32 begin, u32 end) {
80         auto iter = subroutines.find(Subroutine{begin, end});
81         if (iter != subroutines.end())
82             return *iter;
83 
84         Subroutine subroutine{begin, end};
85         subroutine.exit_method = Scan(begin, end, subroutine.labels);
86         if (subroutine.exit_method == ExitMethod::Undetermined)
87             throw DecompileFail("Recursive function detected");
88         return *subroutines.insert(std::move(subroutine)).first;
89     }
90 
91     /// Merges exit method of two parallel branches.
ParallelExit(ExitMethod a,ExitMethod b)92     static ExitMethod ParallelExit(ExitMethod a, ExitMethod b) {
93         if (a == ExitMethod::Undetermined) {
94             return b;
95         }
96         if (b == ExitMethod::Undetermined) {
97             return a;
98         }
99         if (a == b) {
100             return a;
101         }
102         return ExitMethod::Conditional;
103     }
104 
105     /// Cascades exit method of two blocks of code.
SeriesExit(ExitMethod a,ExitMethod b)106     static ExitMethod SeriesExit(ExitMethod a, ExitMethod b) {
107         // This should be handled before evaluating b.
108         DEBUG_ASSERT(a != ExitMethod::AlwaysEnd);
109 
110         if (a == ExitMethod::Undetermined) {
111             return ExitMethod::Undetermined;
112         }
113 
114         if (a == ExitMethod::AlwaysReturn) {
115             return b;
116         }
117 
118         if (b == ExitMethod::Undetermined || b == ExitMethod::AlwaysEnd) {
119             return ExitMethod::AlwaysEnd;
120         }
121 
122         return ExitMethod::Conditional;
123     }
124 
125     /// Scans a range of code for labels and determines the exit method.
Scan(u32 begin,u32 end,std::set<u32> & labels)126     ExitMethod Scan(u32 begin, u32 end, std::set<u32>& labels) {
127         auto [iter, inserted] =
128             exit_method_map.emplace(std::make_pair(begin, end), ExitMethod::Undetermined);
129         ExitMethod& exit_method = iter->second;
130         if (!inserted)
131             return exit_method;
132 
133         for (u32 offset = begin; offset != end && offset != PROGRAM_END; ++offset) {
134             const Instruction instr = {program_code[offset]};
135             switch (instr.opcode.Value()) {
136             case OpCode::Id::END: {
137                 return exit_method = ExitMethod::AlwaysEnd;
138             }
139             case OpCode::Id::JMPC:
140             case OpCode::Id::JMPU: {
141                 labels.insert(instr.flow_control.dest_offset);
142                 ExitMethod no_jmp = Scan(offset + 1, end, labels);
143                 ExitMethod jmp = Scan(instr.flow_control.dest_offset, end, labels);
144                 return exit_method = ParallelExit(no_jmp, jmp);
145             }
146             case OpCode::Id::CALL: {
147                 auto& call = AddSubroutine(instr.flow_control.dest_offset,
148                                            instr.flow_control.dest_offset +
149                                                instr.flow_control.num_instructions);
150                 if (call.exit_method == ExitMethod::AlwaysEnd)
151                     return exit_method = ExitMethod::AlwaysEnd;
152                 ExitMethod after_call = Scan(offset + 1, end, labels);
153                 return exit_method = SeriesExit(call.exit_method, after_call);
154             }
155             case OpCode::Id::LOOP: {
156                 auto& loop = AddSubroutine(offset + 1, instr.flow_control.dest_offset + 1);
157                 if (loop.exit_method == ExitMethod::AlwaysEnd)
158                     return exit_method = ExitMethod::AlwaysEnd;
159                 ExitMethod after_loop = Scan(instr.flow_control.dest_offset + 1, end, labels);
160                 return exit_method = SeriesExit(loop.exit_method, after_loop);
161             }
162             case OpCode::Id::CALLC:
163             case OpCode::Id::CALLU: {
164                 auto& call = AddSubroutine(instr.flow_control.dest_offset,
165                                            instr.flow_control.dest_offset +
166                                                instr.flow_control.num_instructions);
167                 ExitMethod after_call = Scan(offset + 1, end, labels);
168                 return exit_method = SeriesExit(
169                            ParallelExit(call.exit_method, ExitMethod::AlwaysReturn), after_call);
170             }
171             case OpCode::Id::IFU:
172             case OpCode::Id::IFC: {
173                 auto& if_sub = AddSubroutine(offset + 1, instr.flow_control.dest_offset);
174                 ExitMethod else_method;
175                 if (instr.flow_control.num_instructions != 0) {
176                     auto& else_sub = AddSubroutine(instr.flow_control.dest_offset,
177                                                    instr.flow_control.dest_offset +
178                                                        instr.flow_control.num_instructions);
179                     else_method = else_sub.exit_method;
180                 } else {
181                     else_method = ExitMethod::AlwaysReturn;
182                 }
183 
184                 ExitMethod both = ParallelExit(if_sub.exit_method, else_method);
185                 if (both == ExitMethod::AlwaysEnd)
186                     return exit_method = ExitMethod::AlwaysEnd;
187                 ExitMethod after_call =
188                     Scan(instr.flow_control.dest_offset + instr.flow_control.num_instructions, end,
189                          labels);
190                 return exit_method = SeriesExit(both, after_call);
191             }
192             default:
193                 break;
194             }
195         }
196         return exit_method = ExitMethod::AlwaysReturn;
197     }
198 };
199 
200 class ShaderWriter {
201 public:
202     // Forwards all arguments directly to libfmt.
203     // Note that all formatting requirements for fmt must be
204     // obeyed when using this function. (e.g. {{ must be used
205     // printing the character '{' is desirable. Ditto for }} and '}',
206     // etc).
207     template <typename... Args>
AddLine(std::string_view text,Args &&...args)208     void AddLine(std::string_view text, Args&&... args) {
209         AddExpression(fmt::format(text, std::forward<Args>(args)...));
210         AddNewLine();
211     }
212 
AddNewLine()213     void AddNewLine() {
214         DEBUG_ASSERT(scope >= 0);
215         shader_source += '\n';
216     }
217 
MoveResult()218     std::string MoveResult() {
219         return std::move(shader_source);
220     }
221 
222     int scope = 0;
223 
224 private:
AddExpression(std::string_view text)225     void AddExpression(std::string_view text) {
226         if (!text.empty()) {
227             shader_source.append(static_cast<std::size_t>(scope) * 4, ' ');
228         }
229         shader_source += text;
230     }
231 
232     std::string shader_source;
233 };
234 
235 /// An adaptor for getting swizzle pattern string from nihstro interfaces.
236 template <SwizzlePattern::Selector (SwizzlePattern::*getter)(int) const>
GetSelectorSrc(const SwizzlePattern & pattern)237 std::string GetSelectorSrc(const SwizzlePattern& pattern) {
238     std::string out;
239     for (int i = 0; i < 4; ++i) {
240         switch ((pattern.*getter)(i)) {
241         case SwizzlePattern::Selector::x:
242             out += 'x';
243             break;
244         case SwizzlePattern::Selector::y:
245             out += 'y';
246             break;
247         case SwizzlePattern::Selector::z:
248             out += 'z';
249             break;
250         case SwizzlePattern::Selector::w:
251             out += 'w';
252             break;
253         default:
254             UNREACHABLE();
255             return "";
256         }
257     }
258     return out;
259 }
260 
261 constexpr auto GetSelectorSrc1 = GetSelectorSrc<&SwizzlePattern::GetSelectorSrc1>;
262 constexpr auto GetSelectorSrc2 = GetSelectorSrc<&SwizzlePattern::GetSelectorSrc2>;
263 constexpr auto GetSelectorSrc3 = GetSelectorSrc<&SwizzlePattern::GetSelectorSrc3>;
264 
265 class GLSLGenerator {
266 public:
GLSLGenerator(const std::set<Subroutine> & subroutines,const Pica::Shader::ProgramCode & program_code,const Pica::Shader::SwizzleData & swizzle_data,u32 main_offset,const RegGetter & inputreg_getter,const RegGetter & outputreg_getter,bool sanitize_mul)267     GLSLGenerator(const std::set<Subroutine>& subroutines,
268                   const Pica::Shader::ProgramCode& program_code,
269                   const Pica::Shader::SwizzleData& swizzle_data, u32 main_offset,
270                   const RegGetter& inputreg_getter, const RegGetter& outputreg_getter,
271                   bool sanitize_mul)
272         : subroutines(subroutines), program_code(program_code), swizzle_data(swizzle_data),
273           main_offset(main_offset), inputreg_getter(inputreg_getter),
274           outputreg_getter(outputreg_getter), sanitize_mul(sanitize_mul) {
275 
276         Generate();
277     }
278 
MoveShaderCode()279     std::string MoveShaderCode() {
280         return shader.MoveResult();
281     }
282 
283 private:
284     /// Gets the Subroutine object corresponding to the specified address.
GetSubroutine(u32 begin,u32 end) const285     const Subroutine& GetSubroutine(u32 begin, u32 end) const {
286         auto iter = subroutines.find(Subroutine{begin, end});
287         ASSERT(iter != subroutines.end());
288         return *iter;
289     }
290 
291     /// Generates condition evaluation code for the flow control instruction.
EvaluateCondition(Instruction::FlowControlType flow_control)292     static std::string EvaluateCondition(Instruction::FlowControlType flow_control) {
293         using Op = Instruction::FlowControlType::Op;
294 
295         const std::string_view result_x =
296             flow_control.refx.Value() ? "conditional_code.x" : "!conditional_code.x";
297         const std::string_view result_y =
298             flow_control.refy.Value() ? "conditional_code.y" : "!conditional_code.y";
299 
300         switch (flow_control.op) {
301         case Op::JustX:
302             return std::string(result_x);
303         case Op::JustY:
304             return std::string(result_y);
305         case Op::Or:
306         case Op::And: {
307             const std::string_view and_or = flow_control.op == Op::Or ? "any" : "all";
308             std::string bvec;
309             if (flow_control.refx.Value() && flow_control.refy.Value()) {
310                 bvec = "conditional_code";
311             } else if (!flow_control.refx.Value() && !flow_control.refy.Value()) {
312                 bvec = "not(conditional_code)";
313             } else {
314                 bvec = fmt::format("bvec2({}, {})", result_x, result_y);
315             }
316             return fmt::format("{}({})", and_or, bvec);
317         }
318         default:
319             UNREACHABLE();
320             return "";
321         }
322     }
323 
324     /// Generates code representing a source register.
GetSourceRegister(const SourceRegister & source_reg,u32 address_register_index) const325     std::string GetSourceRegister(const SourceRegister& source_reg,
326                                   u32 address_register_index) const {
327         const u32 index = static_cast<u32>(source_reg.GetIndex());
328 
329         switch (source_reg.GetRegisterType()) {
330         case RegisterType::Input:
331             return inputreg_getter(index);
332         case RegisterType::Temporary:
333             return fmt::format("reg_tmp{}", index);
334         case RegisterType::FloatUniform:
335             if (address_register_index != 0) {
336                 return fmt::format("uniforms.f[{} + address_registers.{}]", index,
337                                    "xyz"[address_register_index - 1]);
338             }
339             return fmt::format("uniforms.f[{}]", index);
340         default:
341             UNREACHABLE();
342             return "";
343         }
344     }
345 
346     /// Generates code representing a destination register.
GetDestRegister(const DestRegister & dest_reg) const347     std::string GetDestRegister(const DestRegister& dest_reg) const {
348         const u32 index = static_cast<u32>(dest_reg.GetIndex());
349 
350         switch (dest_reg.GetRegisterType()) {
351         case RegisterType::Output:
352             return outputreg_getter(index);
353         case RegisterType::Temporary:
354             return fmt::format("reg_tmp{}", index);
355         default:
356             UNREACHABLE();
357             return "";
358         }
359     }
360 
361     /// Generates code representing a bool uniform
GetUniformBool(u32 index) const362     std::string GetUniformBool(u32 index) const {
363         return fmt::format("uniforms.b[{}]", index);
364     }
365 
366     /**
367      * Adds code that calls a subroutine.
368      * @param subroutine the subroutine to call.
369      */
CallSubroutine(const Subroutine & subroutine)370     void CallSubroutine(const Subroutine& subroutine) {
371         if (subroutine.exit_method == ExitMethod::AlwaysEnd) {
372             shader.AddLine("{}();", subroutine.GetName());
373             shader.AddLine("return true;");
374         } else if (subroutine.exit_method == ExitMethod::Conditional) {
375             shader.AddLine("if ({}()) {{ return true; }}", subroutine.GetName());
376         } else {
377             shader.AddLine("{}();", subroutine.GetName());
378         }
379     }
380 
381     /**
382      * Writes code that does an assignment operation.
383      * @param swizzle the swizzle data of the current instruction.
384      * @param reg the destination register code.
385      * @param value the code representing the value to assign.
386      * @param dest_num_components number of components of the destination register.
387      * @param value_num_components number of components of the value to assign.
388      */
SetDest(const SwizzlePattern & swizzle,std::string_view reg,std::string_view value,u32 dest_num_components,u32 value_num_components)389     void SetDest(const SwizzlePattern& swizzle, std::string_view reg, std::string_view value,
390                  u32 dest_num_components, u32 value_num_components) {
391         u32 dest_mask_num_components = 0;
392         std::string dest_mask_swizzle = ".";
393 
394         for (u32 i = 0; i < dest_num_components; ++i) {
395             if (swizzle.DestComponentEnabled(static_cast<int>(i))) {
396                 dest_mask_swizzle += "xyzw"[i];
397                 ++dest_mask_num_components;
398             }
399         }
400 
401         if (reg.empty() || dest_mask_num_components == 0) {
402             return;
403         }
404         DEBUG_ASSERT(value_num_components >= dest_num_components || value_num_components == 1);
405 
406         const std::string dest =
407             fmt::format("{}{}", reg, dest_num_components != 1 ? dest_mask_swizzle : "");
408 
409         std::string src{value};
410         if (value_num_components == 1) {
411             if (dest_mask_num_components != 1) {
412                 src = fmt::format("vec{}({})", dest_mask_num_components, value);
413             }
414         } else if (value_num_components != dest_mask_num_components) {
415             src = fmt::format("({}){}", value, dest_mask_swizzle);
416         }
417 
418         shader.AddLine("{} = {};", dest, src);
419     }
420 
421     /**
422      * Compiles a single instruction from PICA to GLSL.
423      * @param offset the offset of the PICA shader instruction.
424      * @return the offset of the next instruction to execute. Usually it is the current offset + 1.
425      * If the current instruction is IF or LOOP, the next instruction is after the IF or LOOP block.
426      * If the current instruction always terminates the program, returns PROGRAM_END.
427      */
CompileInstr(u32 offset)428     u32 CompileInstr(u32 offset) {
429         const Instruction instr = {program_code[offset]};
430 
431         std::size_t swizzle_offset =
432             instr.opcode.Value().GetInfo().type == OpCode::Type::MultiplyAdd
433                 ? instr.mad.operand_desc_id
434                 : instr.common.operand_desc_id;
435         const SwizzlePattern swizzle = {swizzle_data[swizzle_offset]};
436 
437         shader.AddLine("// {}: {}", offset, instr.opcode.Value().GetInfo().name);
438 
439         switch (instr.opcode.Value().GetInfo().type) {
440         case OpCode::Type::Arithmetic: {
441             const bool is_inverted =
442                 (0 != (instr.opcode.Value().GetInfo().subtype & OpCode::Info::SrcInversed));
443 
444             std::string src1 = swizzle.negate_src1 ? "-" : "";
445             src1 += GetSourceRegister(instr.common.GetSrc1(is_inverted),
446                                       !is_inverted * instr.common.address_register_index);
447             src1 += "." + GetSelectorSrc1(swizzle);
448 
449             std::string src2 = swizzle.negate_src2 ? "-" : "";
450             src2 += GetSourceRegister(instr.common.GetSrc2(is_inverted),
451                                       is_inverted * instr.common.address_register_index);
452             src2 += "." + GetSelectorSrc2(swizzle);
453 
454             std::string dest_reg = GetDestRegister(instr.common.dest.Value());
455 
456             switch (instr.opcode.Value().EffectiveOpCode()) {
457             case OpCode::Id::ADD: {
458                 SetDest(swizzle, dest_reg, fmt::format("{} + {}", src1, src2), 4, 4);
459                 break;
460             }
461 
462             case OpCode::Id::MUL: {
463                 if (sanitize_mul) {
464                     SetDest(swizzle, dest_reg, fmt::format("sanitize_mul({}, {})", src1, src2), 4,
465                             4);
466                 } else {
467                     SetDest(swizzle, dest_reg, fmt::format("{} * {}", src1, src2), 4, 4);
468                 }
469                 break;
470             }
471 
472             case OpCode::Id::FLR: {
473                 SetDest(swizzle, dest_reg, fmt::format("floor({})", src1), 4, 4);
474                 break;
475             }
476 
477             case OpCode::Id::MAX: {
478                 SetDest(swizzle, dest_reg, fmt::format("max({}, {})", src1, src2), 4, 4);
479                 break;
480             }
481 
482             case OpCode::Id::MIN: {
483                 SetDest(swizzle, dest_reg, fmt::format("min({}, {})", src1, src2), 4, 4);
484                 break;
485             }
486 
487             case OpCode::Id::DP3:
488             case OpCode::Id::DP4:
489             case OpCode::Id::DPH:
490             case OpCode::Id::DPHI: {
491                 OpCode::Id opcode = instr.opcode.Value().EffectiveOpCode();
492                 std::string dot;
493                 if (opcode == OpCode::Id::DP3) {
494                     if (sanitize_mul) {
495                         dot = fmt::format("dot(vec3(sanitize_mul({}, {})), vec3(1.0))", src1, src2);
496                     } else {
497                         dot = fmt::format("dot(vec3({}), vec3({}))", src1, src2);
498                     }
499                 } else {
500                     if (sanitize_mul) {
501                         const std::string src1_ =
502                             (opcode == OpCode::Id::DPH || opcode == OpCode::Id::DPHI)
503                                 ? fmt::format("vec4({}.xyz, 1.0)", src1)
504                                 : std::move(src1);
505 
506                         dot = fmt::format("dot(sanitize_mul({}, {}), vec4(1.0))", src1_, src2);
507                     } else {
508                         dot = fmt::format("dot({}, {})", src1, src2);
509                     }
510                 }
511 
512                 SetDest(swizzle, dest_reg, dot, 4, 1);
513                 break;
514             }
515 
516             case OpCode::Id::RCP: {
517                 if (!sanitize_mul) {
518                     // When accurate multiplication is OFF, NaN are not really handled. This is a
519                     // workaround to cheaply avoid NaN. Fixes graphical issues in Ocarina of Time.
520                     shader.AddLine("if ({}.x != 0.0)", src1);
521                 }
522                 SetDest(swizzle, dest_reg, fmt::format("(1.0 / {}.x)", src1), 4, 1);
523                 break;
524             }
525 
526             case OpCode::Id::RSQ: {
527                 if (!sanitize_mul) {
528                     // When accurate multiplication is OFF, NaN are not really handled. This is a
529                     // workaround to cheaply avoid NaN. Fixes graphical issues in Ocarina of Time.
530                     shader.AddLine("if ({}.x > 0.0)", src1);
531                 }
532                 SetDest(swizzle, dest_reg, fmt::format("inversesqrt({}.x)", src1), 4, 1);
533                 break;
534             }
535 
536             case OpCode::Id::MOVA: {
537                 SetDest(swizzle, "address_registers", fmt::format("ivec2({})", src1), 2, 2);
538                 break;
539             }
540 
541             case OpCode::Id::MOV: {
542                 SetDest(swizzle, dest_reg, src1, 4, 4);
543                 break;
544             }
545 
546             case OpCode::Id::SGE:
547             case OpCode::Id::SGEI: {
548                 SetDest(swizzle, dest_reg,
549                         fmt::format("vec4(greaterThanEqual({}, {}))", src1, src2), 4, 4);
550                 break;
551             }
552 
553             case OpCode::Id::SLT:
554             case OpCode::Id::SLTI: {
555                 SetDest(swizzle, dest_reg, fmt::format("vec4(lessThan({}, {}))", src1, src2), 4, 4);
556                 break;
557             }
558 
559             case OpCode::Id::CMP: {
560                 using CompareOp = Instruction::Common::CompareOpType::Op;
561                 const std::map<CompareOp, std::pair<std::string_view, std::string_view>> cmp_ops{
562                     {CompareOp::Equal, {"==", "equal"}},
563                     {CompareOp::NotEqual, {"!=", "notEqual"}},
564                     {CompareOp::LessThan, {"<", "lessThan"}},
565                     {CompareOp::LessEqual, {"<=", "lessThanEqual"}},
566                     {CompareOp::GreaterThan, {">", "greaterThan"}},
567                     {CompareOp::GreaterEqual, {">=", "greaterThanEqual"}},
568                 };
569 
570                 const CompareOp op_x = instr.common.compare_op.x.Value();
571                 const CompareOp op_y = instr.common.compare_op.y.Value();
572 
573                 if (cmp_ops.find(op_x) == cmp_ops.end()) {
574                     LOG_ERROR(HW_GPU, "Unknown compare mode {:x}", op_x);
575                 } else if (cmp_ops.find(op_y) == cmp_ops.end()) {
576                     LOG_ERROR(HW_GPU, "Unknown compare mode {:x}", op_y);
577                 } else if (op_x != op_y) {
578                     shader.AddLine("conditional_code.x = {}.x {} {}.x;", src1,
579                                    cmp_ops.find(op_x)->second.first, src2);
580                     shader.AddLine("conditional_code.y = {}.y {} {}.y;", src1,
581                                    cmp_ops.find(op_y)->second.first, src2);
582                 } else {
583                     shader.AddLine("conditional_code = {}(vec2({}), vec2({}));",
584                                    cmp_ops.find(op_x)->second.second, src1, src2);
585                 }
586                 break;
587             }
588 
589             case OpCode::Id::EX2: {
590                 SetDest(swizzle, dest_reg, fmt::format("exp2({}.x)", src1), 4, 1);
591                 break;
592             }
593 
594             case OpCode::Id::LG2: {
595                 SetDest(swizzle, dest_reg, fmt::format("log2({}.x)", src1), 4, 1);
596                 break;
597             }
598 
599             default: {
600                 LOG_ERROR(HW_GPU, "Unhandled arithmetic instruction: 0x{:02x} ({}): 0x{:08x}",
601                           (int)instr.opcode.Value().EffectiveOpCode(),
602                           instr.opcode.Value().GetInfo().name, instr.hex);
603                 throw DecompileFail("Unhandled instruction");
604                 break;
605             }
606             }
607 
608             break;
609         }
610 
611         case OpCode::Type::MultiplyAdd: {
612             if ((instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MAD) ||
613                 (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MADI)) {
614                 bool is_inverted = (instr.opcode.Value().EffectiveOpCode() == OpCode::Id::MADI);
615 
616                 std::string src1 = swizzle.negate_src1 ? "-" : "";
617                 src1 += GetSourceRegister(instr.mad.GetSrc1(is_inverted), 0);
618                 src1 += "." + GetSelectorSrc1(swizzle);
619 
620                 std::string src2 = swizzle.negate_src2 ? "-" : "";
621                 src2 += GetSourceRegister(instr.mad.GetSrc2(is_inverted),
622                                           !is_inverted * instr.mad.address_register_index);
623                 src2 += "." + GetSelectorSrc2(swizzle);
624 
625                 std::string src3 = swizzle.negate_src3 ? "-" : "";
626                 src3 += GetSourceRegister(instr.mad.GetSrc3(is_inverted),
627                                           is_inverted * instr.mad.address_register_index);
628                 src3 += "." + GetSelectorSrc3(swizzle);
629 
630                 std::string dest_reg =
631                     (instr.mad.dest.Value() < 0x10)
632                         ? outputreg_getter(static_cast<u32>(instr.mad.dest.Value().GetIndex()))
633                         : (instr.mad.dest.Value() < 0x20)
634                               ? "reg_tmp" + std::to_string(instr.mad.dest.Value().GetIndex())
635                               : "";
636 
637                 if (sanitize_mul) {
638                     SetDest(swizzle, dest_reg,
639                             fmt::format("sanitize_mul({}, {}) + {}", src1, src2, src3), 4, 4);
640                 } else {
641                     SetDest(swizzle, dest_reg, fmt::format("{} * {} + {}", src1, src2, src3), 4, 4);
642                 }
643             } else {
644                 LOG_ERROR(HW_GPU, "Unhandled multiply-add instruction: 0x{:02x} ({}): 0x{:08x}",
645                           (int)instr.opcode.Value().EffectiveOpCode(),
646                           instr.opcode.Value().GetInfo().name, instr.hex);
647                 throw DecompileFail("Unhandled instruction");
648             }
649             break;
650         }
651 
652         default: {
653             switch (instr.opcode.Value()) {
654             case OpCode::Id::END: {
655                 shader.AddLine("return true;");
656                 offset = PROGRAM_END - 1;
657                 break;
658             }
659 
660             case OpCode::Id::JMPC:
661             case OpCode::Id::JMPU: {
662                 std::string condition;
663                 if (instr.opcode.Value() == OpCode::Id::JMPC) {
664                     condition = EvaluateCondition(instr.flow_control);
665                 } else {
666                     bool invert_test = instr.flow_control.num_instructions & 1;
667                     condition = (invert_test ? "!" : "") +
668                                 GetUniformBool(instr.flow_control.bool_uniform_id);
669                 }
670 
671                 shader.AddLine("if ({}) {{", condition);
672                 ++shader.scope;
673                 shader.AddLine("{{ jmp_to = {}u; break; }}",
674                                instr.flow_control.dest_offset.Value());
675 
676                 --shader.scope;
677                 shader.AddLine("}}");
678                 break;
679             }
680 
681             case OpCode::Id::CALL:
682             case OpCode::Id::CALLC:
683             case OpCode::Id::CALLU: {
684                 std::string condition;
685                 if (instr.opcode.Value() == OpCode::Id::CALLC) {
686                     condition = EvaluateCondition(instr.flow_control);
687                 } else if (instr.opcode.Value() == OpCode::Id::CALLU) {
688                     condition = GetUniformBool(instr.flow_control.bool_uniform_id);
689                 }
690 
691                 if (condition.empty()) {
692                     shader.AddLine("{{");
693                 } else {
694                     shader.AddLine("if ({}) {{", condition);
695                 }
696                 ++shader.scope;
697 
698                 auto& call_sub = GetSubroutine(instr.flow_control.dest_offset,
699                                                instr.flow_control.dest_offset +
700                                                    instr.flow_control.num_instructions);
701 
702                 CallSubroutine(call_sub);
703                 if (instr.opcode.Value() == OpCode::Id::CALL &&
704                     call_sub.exit_method == ExitMethod::AlwaysEnd) {
705                     offset = PROGRAM_END - 1;
706                 }
707 
708                 --shader.scope;
709                 shader.AddLine("}}");
710                 break;
711             }
712 
713             case OpCode::Id::NOP: {
714                 break;
715             }
716 
717             case OpCode::Id::IFC:
718             case OpCode::Id::IFU: {
719                 std::string condition;
720                 if (instr.opcode.Value() == OpCode::Id::IFC) {
721                     condition = EvaluateCondition(instr.flow_control);
722                 } else {
723                     condition = GetUniformBool(instr.flow_control.bool_uniform_id);
724                 }
725 
726                 const u32 if_offset = offset + 1;
727                 const u32 else_offset = instr.flow_control.dest_offset;
728                 const u32 endif_offset =
729                     instr.flow_control.dest_offset + instr.flow_control.num_instructions;
730 
731                 shader.AddLine("if ({}) {{", condition);
732                 ++shader.scope;
733 
734                 auto& if_sub = GetSubroutine(if_offset, else_offset);
735                 CallSubroutine(if_sub);
736                 offset = else_offset - 1;
737 
738                 if (instr.flow_control.num_instructions != 0) {
739                     --shader.scope;
740                     shader.AddLine("}} else {{");
741                     ++shader.scope;
742 
743                     auto& else_sub = GetSubroutine(else_offset, endif_offset);
744                     CallSubroutine(else_sub);
745                     offset = endif_offset - 1;
746 
747                     if (if_sub.exit_method == ExitMethod::AlwaysEnd &&
748                         else_sub.exit_method == ExitMethod::AlwaysEnd) {
749                         offset = PROGRAM_END - 1;
750                     }
751                 }
752 
753                 --shader.scope;
754                 shader.AddLine("}}");
755                 break;
756             }
757 
758             case OpCode::Id::LOOP: {
759                 const std::string int_uniform =
760                     fmt::format("uniforms.i[{}]", instr.flow_control.int_uniform_id.Value());
761 
762                 shader.AddLine("address_registers.z = int({}.y);", int_uniform);
763 
764                 const std::string loop_var = fmt::format("loop{}", offset);
765                 shader.AddLine(
766                     "for (uint {} = 0u; {} <= {}.x; address_registers.z += int({}.z), ++{}) {{",
767                     loop_var, loop_var, int_uniform, int_uniform, loop_var);
768                 ++shader.scope;
769 
770                 auto& loop_sub = GetSubroutine(offset + 1, instr.flow_control.dest_offset + 1);
771                 CallSubroutine(loop_sub);
772                 offset = instr.flow_control.dest_offset;
773 
774                 --shader.scope;
775                 shader.AddLine("}}");
776 
777                 if (loop_sub.exit_method == ExitMethod::AlwaysEnd) {
778                     offset = PROGRAM_END - 1;
779                 }
780 
781                 break;
782             }
783 
784             case OpCode::Id::EMIT:
785             case OpCode::Id::SETEMIT:
786                 LOG_ERROR(HW_GPU, "Geometry shader operation detected in vertex shader");
787                 break;
788 
789             default: {
790                 LOG_ERROR(HW_GPU, "Unhandled instruction: 0x{:02x} ({}): 0x{:08x}",
791                           (int)instr.opcode.Value().EffectiveOpCode(),
792                           instr.opcode.Value().GetInfo().name, instr.hex);
793                 throw DecompileFail("Unhandled instruction");
794                 break;
795             }
796             }
797 
798             break;
799         }
800         }
801         return offset + 1;
802     }
803 
804     /**
805      * Compiles a range of instructions from PICA to GLSL.
806      * @param begin the offset of the starting instruction.
807      * @param end the offset where the compilation should stop (exclusive).
808      * @return the offset of the next instruction to compile. PROGRAM_END if the program terminates.
809      */
CompileRange(u32 begin,u32 end)810     u32 CompileRange(u32 begin, u32 end) {
811         u32 program_counter;
812         for (program_counter = begin; program_counter < (begin > end ? PROGRAM_END : end);) {
813             program_counter = CompileInstr(program_counter);
814         }
815         return program_counter;
816     }
817 
Generate()818     void Generate() {
819         if (sanitize_mul) {
820 #ifdef ANDROID
821             // Use a cheaper sanitize_mul on Android, as mobile GPUs struggle here
822             // This seems to be sufficient at least for Ocarina of Time and Attack on Titan accurate
823             // multiplication bugs
824             shader.AddLine(
825                 "#define sanitize_mul(lhs, rhs) mix(lhs * rhs, vec4(0.0), isnan(lhs * rhs))");
826 #else
827             shader.AddLine("vec4 sanitize_mul(vec4 lhs, vec4 rhs) {{");
828             ++shader.scope;
829             shader.AddLine("vec4 product = lhs * rhs;");
830             shader.AddLine("return mix(product, mix(mix(vec4(0.0), product, isnan(rhs)), product, "
831                            "isnan(lhs)), isnan(product));");
832             --shader.scope;
833             shader.AddLine("}}\n");
834 #endif
835         }
836 
837         // Add declarations for registers
838         shader.AddLine("bvec2 conditional_code = bvec2(false);");
839         shader.AddLine("ivec3 address_registers = ivec3(0);");
840         for (int i = 0; i < 16; ++i) {
841             shader.AddLine("vec4 reg_tmp{} = vec4(0.0, 0.0, 0.0, 1.0);", i);
842         }
843         shader.AddNewLine();
844 
845         // Add declarations for all subroutines
846         for (const auto& subroutine : subroutines) {
847             shader.AddLine("bool {}();", subroutine.GetName());
848         }
849         shader.AddNewLine();
850 
851         // Add the main entry point
852         shader.AddLine("bool exec_shader() {{");
853         ++shader.scope;
854         CallSubroutine(GetSubroutine(main_offset, PROGRAM_END));
855         --shader.scope;
856         shader.AddLine("}}\n");
857 
858         // Add definitions for all subroutines
859         for (const auto& subroutine : subroutines) {
860             std::set<u32> labels = subroutine.labels;
861 
862             shader.AddLine("bool {}() {{", subroutine.GetName());
863             ++shader.scope;
864 
865             if (labels.empty()) {
866                 if (CompileRange(subroutine.begin, subroutine.end) != PROGRAM_END) {
867                     shader.AddLine("return false;");
868                 }
869             } else {
870                 labels.insert(subroutine.begin);
871                 shader.AddLine("uint jmp_to = {}u;", subroutine.begin);
872                 shader.AddLine("while (true) {{");
873                 ++shader.scope;
874 
875                 shader.AddLine("switch (jmp_to) {{");
876 
877                 for (auto label : labels) {
878                     shader.AddLine("case {}u: {{", label);
879                     ++shader.scope;
880 
881                     auto next_it = labels.lower_bound(label + 1);
882                     u32 next_label = next_it == labels.end() ? subroutine.end : *next_it;
883 
884                     u32 compile_end = CompileRange(label, next_label);
885                     if (compile_end > next_label && compile_end != PROGRAM_END) {
886                         // This happens only when there is a label inside a IF/LOOP block
887                         shader.AddLine("{{ jmp_to = {}u; break; }}", compile_end);
888                         labels.emplace(compile_end);
889                     }
890 
891                     --shader.scope;
892                     shader.AddLine("}}");
893                 }
894 
895                 shader.AddLine("default: return false;");
896                 shader.AddLine("}}");
897 
898                 --shader.scope;
899                 shader.AddLine("}}");
900 
901                 shader.AddLine("return false;");
902             }
903 
904             --shader.scope;
905             shader.AddLine("}}\n");
906 
907             DEBUG_ASSERT(shader.scope == 0);
908         }
909     }
910 
911 private:
912     const std::set<Subroutine>& subroutines;
913     const Pica::Shader::ProgramCode& program_code;
914     const Pica::Shader::SwizzleData& swizzle_data;
915     const u32 main_offset;
916     const RegGetter& inputreg_getter;
917     const RegGetter& outputreg_getter;
918     const bool sanitize_mul;
919 
920     ShaderWriter shader;
921 };
922 
GetCommonDeclarations()923 std::string GetCommonDeclarations() {
924     return R"(
925 struct pica_uniforms {
926     bool b[16];
927     uvec4 i[4];
928     vec4 f[96];
929 };
930 
931 bool exec_shader();
932 
933 )";
934 }
935 
DecompileProgram(const Pica::Shader::ProgramCode & program_code,const Pica::Shader::SwizzleData & swizzle_data,u32 main_offset,const RegGetter & inputreg_getter,const RegGetter & outputreg_getter,bool sanitize_mul)936 std::optional<ProgramResult> DecompileProgram(const Pica::Shader::ProgramCode& program_code,
937                                               const Pica::Shader::SwizzleData& swizzle_data,
938                                               u32 main_offset, const RegGetter& inputreg_getter,
939                                               const RegGetter& outputreg_getter,
940                                               bool sanitize_mul) {
941 
942     try {
943         auto subroutines = ControlFlowAnalyzer(program_code, main_offset).MoveSubroutines();
944         GLSLGenerator generator(subroutines, program_code, swizzle_data, main_offset,
945                                 inputreg_getter, outputreg_getter, sanitize_mul);
946         return {ProgramResult{generator.MoveShaderCode()}};
947     } catch (const DecompileFail& exception) {
948         LOG_INFO(HW_GPU, "Shader decompilation failed: {}", exception.what());
949         return std::nullopt;
950     }
951 }
952 
953 } // namespace OpenGL::ShaderDecompiler
954