1 /*
2  * This file is part of libplacebo.
3  *
4  * libplacebo is free software; you can redistribute it and/or
5  * modify it under the terms of the GNU Lesser General Public
6  * License as published by the Free Software Foundation; either
7  * version 2.1 of the License, or (at your option) any later version.
8  *
9  * libplacebo is distributed in the hope that it will be useful,
10  * but WITHOUT ANY WARRANTY; without even the implied warranty of
11  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  * GNU Lesser General Public License for more details.
13  *
14  * You should have received a copy of the GNU Lesser General Public
15  * License along with libplacebo. If not, see <http://www.gnu.org/licenses/>.
16  */
17 
18 #include <math.h>
19 #include "gpu.h"
20 #include "shaders.h"
21 
// Build the shader `sh` from a caller-supplied custom shader description.
//
// If `params->compute` is set, compute mode is requested first (a group
// dimension of 0 falls back to 16 and marks the group size as flexible).
// The variables, descriptors, vertex attributes and constants listed in
// `params` are then copied (data and names duplicated under SH_TMP(sh)) and
// attached to the shader, followed by the optional prelude, header,
// description and body text.
//
// Returns false if `sh_try_compute` or `sh_require` fails, true otherwise.
bool pl_shader_custom(pl_shader sh, const struct pl_custom_shader *params)
{
    if (params->compute) {
        bool flexible = !params->compute_group_size[0] ||
                        !params->compute_group_size[1];
        int group_w = PL_DEF(params->compute_group_size[0], 16);
        int group_h = PL_DEF(params->compute_group_size[1], 16);
        if (!sh_try_compute(sh, group_w, group_h, flexible, params->compute_shmem))
            return false;
    }

    if (!sh_require(sh, params->input, params->output_w, params->output_h))
        return false;

    sh->res.output = params->output;

    // Everything below is appended to the shader's arrays by hand rather than
    // through `sh_var` / `sh_desc` etc., so that the caller's names are kept
    // verbatim instead of being replaced by freshly generated ones

    // Shader variables
    for (int i = 0; i < params->num_variables; i++) {
        struct pl_shader_var var = params->variables[i];
        size_t var_size = pl_var_host_layout(0, &var.var).size;
        var.data = pl_memdup(SH_TMP(sh), var.data, var_size);
        var.var.name = pl_strdup0(SH_TMP(sh), pl_str0(var.var.name));
        PL_ARRAY_APPEND(sh, sh->vars, var);
    }

    // Descriptors (the buffer variable list is only copied when non-empty)
    for (int i = 0; i < params->num_descriptors; i++) {
        struct pl_shader_desc desc = params->descriptors[i];
        size_t vars_size = sizeof(desc.buffer_vars[0]) * desc.num_buffer_vars;
        if (vars_size)
            desc.buffer_vars = pl_memdup(SH_TMP(sh), desc.buffer_vars, vars_size);
        desc.desc.name = pl_strdup0(SH_TMP(sh), pl_str0(desc.desc.name));
        PL_ARRAY_APPEND(sh, sh->descs, desc);
    }

    // Vertex attributes: every data slot holds one element of the attribute's
    // format
    for (int i = 0; i < params->num_vertex_attribs; i++) {
        struct pl_shader_va attrib = params->vertex_attribs[i];
        size_t elem_size = attrib.attr.fmt->texel_size;
        for (int k = 0; k < PL_ARRAY_SIZE(attrib.data); k++)
            attrib.data[k] = pl_memdup(SH_TMP(sh), attrib.data[k], elem_size);
        attrib.attr.name = pl_strdup0(SH_TMP(sh), pl_str0(attrib.attr.name));
        PL_ARRAY_APPEND(sh, sh->vas, attrib);
    }

    // Constants
    for (int i = 0; i < params->num_constants; i++) {
        struct pl_shader_const constant = params->constants[i];
        constant.data = pl_memdup(SH_TMP(sh), constant.data,
                                  pl_var_type_size(constant.type));
        constant.name = pl_strdup0(SH_TMP(sh), pl_str0(constant.name));
        PL_ARRAY_APPEND(sh, sh->consts, constant);
    }

    if (params->prelude)
        GLSLP("// pl_shader_custom prelude: \n%s\n", params->prelude);
    if (params->header)
        GLSLH("// pl_shader_custom header: \n%s\n", params->header);

    if (params->description)
        sh_describe(sh, pl_strdup0(SH_TMP(sh), pl_str0(params->description)));

    if (!params->body)
        return true;

    // When the signature changes to a color output, the body needs a `color`
    // variable to write into, so declare one in front of the body's scope
    const char *color_decl = "";
    if (params->output != params->input) {
        if (params->output == PL_SHADER_SIG_COLOR) {
            color_decl = "vec4 color = vec4(0.0);";
        } else if (params->output == PL_SHADER_SIG_SAMPLER) {
            pl_unreachable();
        }
    }

    GLSL("// pl_shader_custom \n"
         "%s                  \n"
         "{                   \n"
         "%s                  \n"
         "}                   \n",
         color_decl, params->body);

    return true;
}
105 
// Hard-coded size limits, mainly for convenience (to avoid dynamic memory)

// Capacity of `custom_shader_hook.hook_tex` (max. //!HOOK lines per pass)
#define SHADER_MAX_HOOKS 16
// Capacity of `custom_shader_hook.bind_tex` (max. //!BIND lines per pass)
#define SHADER_MAX_BINDS 16
// Maximum number of elements in one RPN size expression; also bounds the
// evaluation stack in `pl_eval_szexpr`
#define MAX_SZEXP_SIZE 32
110 
// Operators available in an RPN size expression. The character noted next to
// each entry is the one `parse_rpn_szexpr` maps to it.
enum szexp_op {
    SZEXP_OP_ADD, // '+'
    SZEXP_OP_SUB, // '-'
    SZEXP_OP_MUL, // '*'
    SZEXP_OP_DIV, // '/'
    SZEXP_OP_NOT, // '!' (the only monadic operator)
    SZEXP_OP_GT,  // '>'
    SZEXP_OP_LT,  // '<'
};
120 
// Kind of a single RPN expression element; selects which member of
// `szexp.val` is meaningful
enum szexp_tag {
    SZEXP_END = 0, // End of an RPN expression
    SZEXP_CONST, // Push a constant value onto the stack
    SZEXP_VAR_W, // Get the width/height of a named texture (variable)
    SZEXP_VAR_H,
    SZEXP_OP2, // Pop two elements and push the result of a dyadic operation
    SZEXP_OP1, // Pop one element and push the result of a monadic operation
};
129 
// One element of an RPN size expression: a tag plus the matching payload.
// An all-zero element has tag SZEXP_END, so zero-initialized array tails act
// as terminators.
struct szexp {
    enum szexp_tag tag;
    union {
        float cval; // value to push (SZEXP_CONST)
        pl_str varname; // texture name to look up (SZEXP_VAR_W / SZEXP_VAR_H)
        enum szexp_op op; // operator to apply (SZEXP_OP1 / SZEXP_OP2)
    } val;
};
138 
// Parsed representation of one shader pass, as filled in by `parse_hook`
// from the pass's `//!` header lines and the body text that follows them
struct custom_shader_hook {
    // Variable/literal names of textures
    pl_str pass_desc; // //!DESC (defaults to "unknown user shader")
    pl_str hook_tex[SHADER_MAX_HOOKS]; // one entry per //!HOOK line
    pl_str bind_tex[SHADER_MAX_BINDS]; // one entry per //!BIND line
    pl_str save_tex; // //!SAVE; left empty for the special name "HOOKED"

    // Shader body itself + metadata
    pl_str pass_body;
    float offset[2]; // //!OFFSET <x> <y>
    bool offset_align; // //!OFFSET ALIGN
    int comps; // //!COMPONENTS

    // Special expressions governing the output size and execution conditions
    struct szexp width[MAX_SZEXP_SIZE]; // //!WIDTH (default: HOOKED.w)
    struct szexp height[MAX_SZEXP_SIZE]; // //!HEIGHT (default: HOOKED.h)
    struct szexp cond[MAX_SZEXP_SIZE]; // //!WHEN (default: constant 1.0)

    // Special metadata for compute shaders
    bool is_compute; // set when a //!COMPUTE line was parsed
    int block_w, block_h;       // Block size (each block corresponds to one WG)
    int threads_w, threads_h;   // How many threads form a WG
};
162 
// Parse a space-separated RPN size expression from `line` into `out`.
//
// Recognized tokens:
//   NAME.w / NAME.width   -> SZEXP_VAR_W with varname NAME
//   NAME.h / NAME.height  -> SZEXP_VAR_H with varname NAME
//   + - * / > <           -> SZEXP_OP2 (only the first character is checked)
//   !                     -> SZEXP_OP1 (only the first character is checked)
//   anything starting with a digit -> SZEXP_CONST (parsed as float)
//
// The variable names stored in `out` reference the memory of `line`; nothing
// is copied.
//
// Returns false if the expression has more than MAX_SZEXP_SIZE tokens or
// contains an unrecognized token; `out` may be partially overwritten in that
// case. If `line` contains no tokens at all, `out` is left untouched.
static bool parse_rpn_szexpr(pl_str line, struct szexp out[MAX_SZEXP_SIZE])
{
    int pos = 0;

    while (line.len > 0) {
        pl_str word = pl_str_split_char(line, ' ', &line);
        if (word.len == 0)
            continue; // skip runs of consecutive spaces

        if (pos >= MAX_SZEXP_SIZE)
            return false;

        struct szexp *exp = &out[pos++];

        if (pl_str_eatend0(&word, ".w") || pl_str_eatend0(&word, ".width")) {
            exp->tag = SZEXP_VAR_W;
            exp->val.varname = word;
            continue;
        }

        if (pl_str_eatend0(&word, ".h") || pl_str_eatend0(&word, ".height")) {
            exp->tag = SZEXP_VAR_H;
            exp->val.varname = word;
            continue;
        }

        switch (word.buf[0]) {
        case '+': exp->tag = SZEXP_OP2; exp->val.op = SZEXP_OP_ADD; continue;
        case '-': exp->tag = SZEXP_OP2; exp->val.op = SZEXP_OP_SUB; continue;
        case '*': exp->tag = SZEXP_OP2; exp->val.op = SZEXP_OP_MUL; continue;
        case '/': exp->tag = SZEXP_OP2; exp->val.op = SZEXP_OP_DIV; continue;
        case '!': exp->tag = SZEXP_OP1; exp->val.op = SZEXP_OP_NOT; continue;
        case '>': exp->tag = SZEXP_OP2; exp->val.op = SZEXP_OP_GT;  continue;
        case '<': exp->tag = SZEXP_OP2; exp->val.op = SZEXP_OP_LT;  continue;
        }

        if (word.buf[0] >= '0' && word.buf[0] <= '9') {
            exp->tag = SZEXP_CONST;
            if (!pl_str_parse_float(word, &exp->val.cval))
                return false;
            continue;
        }

        // Some sort of illegal expression
        return false;
    }

    // `out` may already hold an expression (e.g. when the same directive
    // appears twice). Terminate explicitly so that leftover elements of a
    // previous, longer expression do not get evaluated as part of this one.
    // A completely full array needs no terminator, since evaluation stops
    // after MAX_SZEXP_SIZE elements.
    if (pos > 0 && pos < MAX_SZEXP_SIZE)
        out[pos] = (struct szexp) { .tag = SZEXP_END };

    return true;
}
212 
213 // Evaluate a `szexp`, given a lookup function for named textures
214 // Returns whether successful. 'result' is left untouched on failure
pl_eval_szexpr(pl_log log,void * priv,bool (* lookup)(void * priv,pl_str var,float size[2]),const struct szexp expr[MAX_SZEXP_SIZE],float * result)215 static bool pl_eval_szexpr(pl_log log, void *priv,
216                            bool (*lookup)(void *priv, pl_str var, float size[2]),
217                            const struct szexp expr[MAX_SZEXP_SIZE],
218                            float *result)
219 {
220     float stack[MAX_SZEXP_SIZE] = {0};
221     int idx = 0; // points to next element to push
222 
223     for (int i = 0; i < MAX_SZEXP_SIZE; i++) {
224         switch (expr[i].tag) {
225         case SZEXP_END:
226             goto done;
227 
228         case SZEXP_CONST:
229             // Since our SZEXPs are bound by MAX_SZEXP_SIZE, it should be
230             // impossible to overflow the stack
231             assert(idx < MAX_SZEXP_SIZE);
232             stack[idx++] = expr[i].val.cval;
233             continue;
234 
235         case SZEXP_OP1:
236             if (idx < 1) {
237                 pl_warn(log, "Stack underflow in RPN expression!");
238                 return false;
239             }
240 
241             switch (expr[i].val.op) {
242             case SZEXP_OP_NOT: stack[idx-1] = !stack[idx-1]; break;
243             default: pl_unreachable();
244             }
245             continue;
246 
247         case SZEXP_OP2:
248             if (idx < 2) {
249                 pl_warn(log, "Stack underflow in RPN expression!");
250                 return false;
251             }
252 
253             // Pop the operands in reverse order
254             float op2 = stack[--idx];
255             float op1 = stack[--idx];
256             float res = 0.0;
257             switch (expr[i].val.op) {
258             case SZEXP_OP_ADD: res = op1 + op2; break;
259             case SZEXP_OP_SUB: res = op1 - op2; break;
260             case SZEXP_OP_MUL: res = op1 * op2; break;
261             case SZEXP_OP_DIV: res = op1 / op2; break;
262             case SZEXP_OP_GT:  res = op1 > op2; break;
263             case SZEXP_OP_LT:  res = op1 < op2; break;
264             case SZEXP_OP_NOT: pl_unreachable();
265             }
266 
267             if (!isfinite(res)) {
268                 pl_warn(log, "Illegal operation in RPN expression!");
269                 return false;
270             }
271 
272             stack[idx++] = res;
273             continue;
274 
275         case SZEXP_VAR_W:
276         case SZEXP_VAR_H: {
277             pl_str name = expr[i].val.varname;
278             float size[2];
279 
280             if (!lookup(priv, name, size)) {
281                 pl_warn(log, "Variable '%.*s' not found in RPN expression!",
282                         PL_STR_FMT(name));
283                 return false;
284             }
285 
286             stack[idx++] = (expr[i].tag == SZEXP_VAR_W) ? size[0] : size[1];
287             continue;
288             }
289         }
290     }
291 
292 done:
293     // Return the single stack element
294     if (idx != 1) {
295         pl_warn(log, "Malformed stack after RPN expression!");
296         return false;
297     }
298 
299     *result = stack[0];
300     return true;
301 }
302 
// Split `*body` at the "//!" marker using `pl_str_split_str0`: the part in
// front of the marker is returned and `*body` is advanced to the remainder.
// If the remainder is non-empty, it is widened again by the three marker
// bytes so that it starts with "//!" (and can be parsed as a header line by
// the next section parser).
static inline pl_str split_magic(pl_str *body)
{
    pl_str head = pl_str_split_str0(*body, "//!", body);
    if (!body->len)
        return head;

    // Step back over the separator so that it stays part of the remainder
    body->buf -= 3;
    body->len += 3;
    return head;
}
314 
// Parse one shader pass from the start of `*body` into `*out`.
//
// All leading lines starting with "//!" are consumed as header directives
// (HOOK, BIND, SAVE, DESC, OFFSET, WIDTH, HEIGHT, WHEN, COMPONENTS, COMPUTE).
// Directives are matched by prefix. The text that follows, up to the next
// "//!" marker, becomes `out->pass_body`, and `*body` is advanced past it.
//
// `*out` is fully re-initialized first: description "unknown user shader",
// output size equal to the HOOKED texture's size, and an always-true
// condition. String fields of `*out` reference the memory of `*body`.
//
// Returns false (after logging an error) on a malformed or unknown
// directive, or when the HOOK/BIND limits are exceeded. A pass without any
// HOOK directive only triggers a warning.
static bool parse_hook(pl_log log, pl_str *body, struct custom_shader_hook *out)
{
    *out = (struct custom_shader_hook){
        .pass_desc = pl_str0("unknown user shader"),
        .width = {{ SZEXP_VAR_W, { .varname = pl_str0("HOOKED") }}},
        .height = {{ SZEXP_VAR_H, { .varname = pl_str0("HOOKED") }}},
        .cond = {{ SZEXP_CONST, { .cval = 1.0 }}},
    };

    int hook_idx = 0;
    int bind_idx = 0;

    // Parse all headers
    while (true) {
        pl_str rest;
        pl_str line = pl_str_strip(pl_str_getline(*body, &rest));

        // Check for the presence of the magic line beginning
        if (!pl_str_eatstart0(&line, "//!"))
            break;

        // Only consume the line once we know it is a header line
        *body = rest;

        // Parse the supported commands
        if (pl_str_eatstart0(&line, "HOOK")) {
            if (hook_idx == SHADER_MAX_HOOKS) {
                pl_err(log, "Passes may only hook up to %d textures!",
                       SHADER_MAX_HOOKS);
                return false;
            }
            out->hook_tex[hook_idx++] = pl_str_strip(line);
            continue;
        }

        if (pl_str_eatstart0(&line, "BIND")) {
            if (bind_idx == SHADER_MAX_BINDS) {
                pl_err(log, "Passes may only bind up to %d textures!",
                       SHADER_MAX_BINDS);
                return false;
            }
            out->bind_tex[bind_idx++] = pl_str_strip(line);
            continue;
        }

        if (pl_str_eatstart0(&line, "SAVE")) {
            pl_str save_tex = pl_str_strip(line);
            if (pl_str_equals0(save_tex, "HOOKED")) {
                // This is a special name that means "overwrite existing"
                // texture, which we just signal by not having any `save_tex`
                // name set.
                out->save_tex = (pl_str) {0};
            } else {
                out->save_tex = save_tex;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "DESC")) {
            out->pass_desc = pl_str_strip(line);
            continue;
        }

        if (pl_str_eatstart0(&line, "OFFSET")) {
            line = pl_str_strip(line);
            if (pl_str_equals0(line, "ALIGN")) {
                out->offset_align = true;
            } else {
                // Exactly two floats, with nothing left over afterwards
                if (!pl_str_parse_float(pl_str_split_char(line, ' ', &line), &out->offset[0]) ||
                    !pl_str_parse_float(pl_str_split_char(line, ' ', &line), &out->offset[1]) ||
                    line.len)
                {
                    pl_err(log, "Error while parsing OFFSET!");
                    return false;
                }
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "WIDTH")) {
            if (!parse_rpn_szexpr(line, out->width)) {
                pl_err(log, "Error while parsing WIDTH!");
                return false;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "HEIGHT")) {
            if (!parse_rpn_szexpr(line, out->height)) {
                pl_err(log, "Error while parsing HEIGHT!");
                return false;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "WHEN")) {
            if (!parse_rpn_szexpr(line, out->cond)) {
                pl_err(log, "Error while parsing WHEN!");
                return false;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "COMPONENTS")) {
            if (!pl_str_parse_int(pl_str_strip(line), &out->comps)) {
                pl_err(log, "Error parsing COMPONENTS: '%.*s'", PL_STR_FMT(line));
                return false;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "COMPUTE")) {
            // Syntax: COMPUTE <block_w> <block_h> [<threads_w> <threads_h>]
            line = pl_str_strip(line);
            bool ok = pl_str_parse_int(pl_str_split_char(line, ' ', &line), &out->block_w) &&
                      pl_str_parse_int(pl_str_split_char(line, ' ', &line), &out->block_h);

            line = pl_str_strip(line);
            if (ok && line.len) {
                ok = pl_str_parse_int(pl_str_split_char(line, ' ', &line), &out->threads_w) &&
                     pl_str_parse_int(pl_str_split_char(line, ' ', &line), &out->threads_h) &&
                     !line.len;
            } else {
                // Without an explicit thread count, one thread per block texel
                out->threads_w = out->block_w;
                out->threads_h = out->block_h;
            }

            if (!ok) {
                pl_err(log, "Error while parsing COMPUTE!");
                return false;
            }

            // The sizes come straight from the (untrusted) shader text. Zero
            // or negative values can never describe a valid work group, so
            // reject them here instead of handing them on to later stages.
            if (out->block_w < 1 || out->block_h < 1 ||
                out->threads_w < 1 || out->threads_h < 1)
            {
                pl_err(log, "COMPUTE block and thread sizes must be positive!");
                return false;
            }

            out->is_compute = true;
            continue;
        }

        // Unknown command type
        pl_err(log, "Unrecognized command '%.*s'!", PL_STR_FMT(line));
        return false;
    }

    // The rest of the file up until the next magic line beginning (if any)
    // shall be the shader body
    out->pass_body = split_magic(body);

    // Sanity checking
    if (hook_idx == 0)
        pl_warn(log, "Pass has no hooked textures (will be ignored)!");

    return true;
}
464 
// Parse one user texture section from the start of `*body` and create the
// corresponding GPU texture.
//
// All leading "//!" lines are consumed as directives (TEXTURE, SIZE, FORMAT,
// FILTER, BORDER, STORAGE). The text that follows, up to the next "//!"
// marker, is decoded as hexadecimal data and used as the texture's initial
// contents. `*body` is advanced past the section.
//
// On success, `*out` describes the texture (name duplicated under `alloc`,
// default name "USER_TEX") and `out->binding.object` holds the texture
// created via `pl_tex_create`. Returns false after logging an error if a
// directive is malformed, no usable FORMAT was given, the data length does
// not match the texture size, or texture creation fails.
static bool parse_tex(pl_gpu gpu, void *alloc, pl_str *body,
                      struct pl_shader_desc *out)
{
    *out = (struct pl_shader_desc) {
        .desc = {
            .name = "USER_TEX",
            .type = PL_DESC_SAMPLED_TEX,
        },
    };

    struct pl_tex_params params = {
        .w = 1, .h = 1, .d = 0,
        .sampleable = true,
    };

    while (true) {
        pl_str rest;
        pl_str line = pl_str_strip(pl_str_getline(*body, &rest));

        if (!pl_str_eatstart0(&line, "//!"))
            break;

        // Only consume the line once we know it is a header line
        *body = rest;

        if (pl_str_eatstart0(&line, "TEXTURE")) {
            out->desc.name = pl_strdup0(alloc, pl_str_strip(line));
            continue;
        }

        if (pl_str_eatstart0(&line, "SIZE")) {
            line = pl_str_strip(line);
            int dims = 0;
            int dim[4]; // extra space to catch invalid extra entries
            while (line.len && dims < PL_ARRAY_SIZE(dim)) {
                if (!pl_str_parse_int(pl_str_split_char(line, ' ', &line), &dim[dims++])) {
                    PL_ERR(gpu, "Error while parsing SIZE!");
                    return false;
                }
            }

            // The applicable limit depends on the texture's dimensionality
            uint32_t lim = dims == 1 ? gpu->limits.max_tex_1d_dim
                         : dims == 2 ? gpu->limits.max_tex_2d_dim
                         : dims == 3 ? gpu->limits.max_tex_3d_dim
                         : 0;

            // Sanity check against GPU size limits
            switch (dims) {
            case 3:
                params.d = dim[2];
                if (params.d < 1 || params.d > lim) {
                    PL_ERR(gpu, "SIZE %d exceeds GPU's texture size limits (%u)!",
                           params.d, (unsigned) lim);
                    return false;
                }
                // fall through
            case 2:
                params.h = dim[1];
                if (params.h < 1 || params.h > lim) {
                    PL_ERR(gpu, "SIZE %d exceeds GPU's texture size limits (%u)!",
                           params.h, (unsigned) lim);
                    return false;
                }
                // fall through
            case 1:
                params.w = dim[0];
                if (params.w < 1 || params.w > lim) {
                    PL_ERR(gpu, "SIZE %d exceeds GPU's texture size limits (%u)!",
                           params.w, (unsigned) lim);
                    return false;
                }
                break;

            default:
                PL_ERR(gpu, "Invalid number of texture dimensions!");
                return false;
            }

            // Clear out the superfluous components
            if (dims < 3)
                params.d = 0;
            if (dims < 2)
                params.h = 0;
            continue;
        }

        if (pl_str_eatstart0(&line, "FORMAT ")) {
            line = pl_str_strip(line);
            params.format = NULL;
            for (int n = 0; n < gpu->num_formats; n++) {
                pl_fmt fmt = gpu->formats[n];
                if (pl_str_equals0(line, fmt->name)) {
                    params.format = fmt;
                    break;
                }
            }

            if (!params.format || params.format->opaque) {
                PL_ERR(gpu, "Unrecognized/unavailable FORMAT name: '%.*s'!",
                       PL_STR_FMT(line));
                return false;
            }

            if (!(params.format->caps & PL_FMT_CAP_SAMPLEABLE)) {
                PL_ERR(gpu, "Chosen FORMAT '%.*s' is not sampleable!",
                       PL_STR_FMT(line));
                return false;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "FILTER")) {
            line = pl_str_strip(line);
            if (pl_str_equals0(line, "LINEAR")) {
                out->binding.sample_mode = PL_TEX_SAMPLE_LINEAR;
            } else if (pl_str_equals0(line, "NEAREST")) {
                out->binding.sample_mode = PL_TEX_SAMPLE_NEAREST;
            } else {
                PL_ERR(gpu, "Unrecognized FILTER: '%.*s'!", PL_STR_FMT(line));
                return false;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "BORDER")) {
            line = pl_str_strip(line);
            if (pl_str_equals0(line, "CLAMP")) {
                out->binding.address_mode = PL_TEX_ADDRESS_CLAMP;
            } else if (pl_str_equals0(line, "REPEAT")) {
                out->binding.address_mode = PL_TEX_ADDRESS_REPEAT;
            } else if (pl_str_equals0(line, "MIRROR")) {
                out->binding.address_mode = PL_TEX_ADDRESS_MIRROR;
            } else {
                PL_ERR(gpu, "Unrecognized BORDER: '%.*s'!", PL_STR_FMT(line));
                return false;
            }
            continue;
        }

        if (pl_str_eatstart0(&line, "STORAGE")) {
            params.storable = true;
            out->desc.type = PL_DESC_STORAGE_IMG;
            out->desc.access = PL_DESC_ACCESS_READWRITE;
            out->memory = PL_MEMORY_COHERENT;
            continue;
        }

        PL_ERR(gpu, "Unrecognized command '%.*s'!", PL_STR_FMT(line));
        return false;
    }

    if (!params.format) {
        PL_ERR(gpu, "No FORMAT specified!");
        return false;
    }

    int caps = params.format->caps;
    if (out->binding.sample_mode == PL_TEX_SAMPLE_LINEAR && !(caps & PL_FMT_CAP_LINEAR)) {
        PL_ERR(gpu, "The specified texture format cannot be linear filtered!");
        return false;
    }

    // Decode the rest of the section (up to the next //! marker) as raw hex
    // data for the texture
    pl_str tex, hexdata = split_magic(body);
    if (!pl_str_decode_hex(NULL, pl_str_strip(hexdata), &tex)) {
        PL_ERR(gpu, "Error while parsing TEXTURE body: must be a valid "
                    "hexadecimal sequence!");
        return false;
    }

    // Compute the expected payload size in `size_t`: each dimension is only
    // bounded by the GPU limits, so the product of up to three of them can
    // exceed INT_MAX (signed overflow would be undefined behavior). Unused
    // dimensions are 0 and count as 1 here.
    size_t tex_w = params.w;
    size_t tex_h = PL_DEF(params.h, 1);
    size_t tex_d = PL_DEF(params.d, 1);
    size_t expected_len = tex_w * tex_h * tex_d * params.format->texel_size;
    if (tex.len == 0 && params.storable) {
        // In this case, it's okay that the texture has no initial data
        pl_free_ptr(&tex.buf);
    } else if (tex.len != expected_len) {
        PL_ERR(gpu, "Shader TEXTURE size mismatch: got %zu bytes, expected %zu!",
               tex.len, expected_len);
        pl_free(tex.buf);
        return false;
    }

    // The decoded data is only needed for the upload during creation
    params.initial_data = tex.buf;
    out->binding.object = pl_tex_create(gpu, &params);
    pl_free(tex.buf);

    if (!out->binding.object) {
        PL_ERR(gpu, "Failed creating custom texture!");
        return false;
    }

    return true;
}
658 
// Parse one user buffer section from the start of `*body` and create the
// corresponding GPU buffer.
//
// All leading "//!" lines are consumed as directives (BUFFER, STORAGE, VAR).
// The text that follows, up to the next "//!" marker, is decoded as
// hexadecimal data and used as the buffer's initial contents. `*body` is
// advanced past the section.
//
// On success, `*out` describes the buffer (uniform by default, storage if
// STORAGE was given; default name "USER_BUF") and `out->binding.object`
// holds the buffer created via `pl_buf_create`. Returns false after logging
// an error if a directive is malformed, the variables cannot be placed, the
// data length does not match the buffer size, or buffer creation fails.
static bool parse_buf(pl_gpu gpu, void *alloc, pl_str *body,
                      struct pl_shader_desc *out)
{
    *out = (struct pl_shader_desc) {
        .desc = {
            .name = "USER_BUF",
            .type = PL_DESC_BUF_UNIFORM,
        },
    };

    // Scratch context, child of `alloc`. It is released explicitly on
    // success; the failure paths leave it to be cleaned up with `alloc`.
    void *scratch = pl_tmp(alloc);

    // Variables are collected first and only placed into the buffer once all
    // headers have been seen (e.g. because the buffer type may still change)
    PL_ARRAY(struct pl_var) pending = {0};

    for (;;) {
        pl_str next;
        pl_str line = pl_str_strip(pl_str_getline(*body, &next));

        if (!pl_str_eatstart0(&line, "//!"))
            break;

        // Only consume the line once we know it is a header line
        *body = next;

        if (pl_str_eatstart0(&line, "BUFFER")) {
            out->desc.name = pl_strdup0(alloc, pl_str_strip(line));

        } else if (pl_str_eatstart0(&line, "STORAGE")) {
            out->desc.type = PL_DESC_BUF_STORAGE;
            out->desc.access = PL_DESC_ACCESS_READWRITE;
            out->memory = PL_MEMORY_COHERENT;

        } else if (pl_str_eatstart0(&line, "VAR")) {
            // Syntax: VAR <glsl type> <name>[<array size>]
            pl_str glsl_type = pl_str_split_char(pl_str_strip(line), ' ', &line);

            // Search the table of known GLSL types; it ends with an entry
            // that has no name
            const struct pl_named_var *known = pl_var_glsl_types;
            while (known->glsl_name && !pl_str_equals0(glsl_type, known->glsl_name))
                known++;

            struct pl_var var = {0};
            if (known->glsl_name)
                var = known->var;

            if (!var.type) {
                // No type found
                PL_ERR(gpu, "Unrecognized GLSL type '%.*s'!", PL_STR_FMT(glsl_type));
                return false;
            }

            pl_str ident = pl_str_split_char(line, '[', &line);
            if (line.len > 0) {
                // Whatever follows the '[' is the array dimension
                pl_str dim = pl_str_split_char(line, ']', NULL);
                if (!pl_str_parse_int(dim, &var.dim_a)) {
                    PL_ERR(gpu, "Failed parsing array dimension from [%.*s!",
                           PL_STR_FMT(line));
                    return false;
                }

                if (var.dim_a < 1) {
                    PL_ERR(gpu, "Invalid array dimension %d!", var.dim_a);
                    return false;
                }
            }

            var.name = pl_strdup0(alloc, pl_str_strip(ident));
            PL_ARRAY_APPEND(scratch, pending, var);

        } else {
            PL_ERR(gpu, "Unrecognized command '%.*s'!", PL_STR_FMT(line));
            return false;
        }
    }

    // Try placing all of the buffer variables
    for (int v = 0; v < pending.num; v++) {
        if (!sh_buf_desc_append(alloc, gpu, out, NULL, pending.elem[v])) {
            PL_ERR(gpu, "Custom buffer exceeds GPU limitations!");
            return false;
        }
    }

    // Everything up to the next //! marker is the buffer's initial contents,
    // encoded as hexadecimal text
    pl_str hex = split_magic(body);
    pl_str payload;
    if (!pl_str_decode_hex(scratch, pl_str_strip(hex), &payload)) {
        PL_ERR(gpu, "Error while parsing BUFFER body: must be a valid "
                    "hexadecimal sequence!");
        return false;
    }

    // Only storage buffers may omit their initial data; in every other case
    // the data has to fill the buffer exactly
    size_t total_size = sh_buf_desc_size(out);
    bool uninitialized = payload.len == 0 &&
                         out->desc.type == PL_DESC_BUF_STORAGE;
    if (!uninitialized && payload.len != total_size) {
        PL_ERR(gpu, "Shader BUFFER size mismatch: got %zu bytes, expected %zu!",
               payload.len, total_size);
        return false;
    }

    struct pl_buf_params buf_params = {
        .size = total_size,
        .uniform = out->desc.type == PL_DESC_BUF_UNIFORM,
        .storable = out->desc.type == PL_DESC_BUF_STORAGE,
        .initial_data = payload.len ? payload.buf : NULL,
    };

    out->binding.object = pl_buf_create(gpu, &buf_params);
    if (!out->binding.object) {
        PL_ERR(gpu, "Failed creating custom buffer!");
        return false;
    }

    pl_free(scratch);
    return true;
}
776 
// Translate a stage name as written in a user shader's HOOK/BIND lines into
// the corresponding `pl_hook_stage` value. Returns 0 for unknown names.
static enum pl_hook_stage mp_stage_to_pl(pl_str stage)
{
    static const struct {
        const char *name;
        enum pl_hook_stage value;
    } known_stages[] = {
        { "RGB",            PL_HOOK_RGB_INPUT },
        { "LUMA",           PL_HOOK_LUMA_INPUT },
        { "CHROMA",         PL_HOOK_CHROMA_INPUT },
        { "ALPHA",          PL_HOOK_ALPHA_INPUT },
        { "XYZ",            PL_HOOK_XYZ_INPUT },

        { "CHROMA_SCALED",  PL_HOOK_CHROMA_SCALED },
        { "ALPHA_SCALED",   PL_HOOK_ALPHA_SCALED },

        { "NATIVE",         PL_HOOK_NATIVE },
        { "MAINPRESUB",     PL_HOOK_RGB },
        { "MAIN",           PL_HOOK_RGB }, // Note: same stage as MAINPRESUB!

        { "LINEAR",         PL_HOOK_LINEAR },
        { "SIGMOID",        PL_HOOK_SIGMOID },
        { "PREKERNEL",      PL_HOOK_PRE_KERNEL },
        { "POSTKERNEL",     PL_HOOK_POST_KERNEL },

        { "SCALED",         PL_HOOK_SCALED },
        { "OUTPUT",         PL_HOOK_OUTPUT },
    };

    // Names are compared in full, so e.g. "CHROMA" never matches
    // "CHROMA_SCALED"
    for (int i = 0; i < PL_ARRAY_SIZE(known_stages); i++) {
        if (pl_str_equals0(stage, known_stages[i].name))
            return known_stages[i].value;
    }

    return 0;
}
818 
pl_stage_to_mp(enum pl_hook_stage stage)819 static pl_str pl_stage_to_mp(enum pl_hook_stage stage)
820 {
821     switch (stage) {
822     case PL_HOOK_RGB_INPUT:     return pl_str0("RGB");
823     case PL_HOOK_LUMA_INPUT:    return pl_str0("LUMA");
824     case PL_HOOK_CHROMA_INPUT:  return pl_str0("CHROMA");
825     case PL_HOOK_ALPHA_INPUT:   return pl_str0("ALPHA");
826     case PL_HOOK_XYZ_INPUT:     return pl_str0("XYZ");
827 
828     case PL_HOOK_CHROMA_SCALED: return pl_str0("CHROMA_SCALED");
829     case PL_HOOK_ALPHA_SCALED:  return pl_str0("ALPHA_SCALED");
830 
831     case PL_HOOK_NATIVE:        return pl_str0("NATIVE");
832     case PL_HOOK_RGB:           return pl_str0("MAINPRESUB");
833 
834     case PL_HOOK_LINEAR:        return pl_str0("LINEAR");
835     case PL_HOOK_SIGMOID:       return pl_str0("SIGMOID");
836     case PL_HOOK_PRE_OVERLAY:   return pl_str0("PREOVERLAY"); // Note: doesn't exist!
837     case PL_HOOK_PRE_KERNEL:    return pl_str0("PREKERNEL");
838     case PL_HOOK_POST_KERNEL:   return pl_str0("POSTKERNEL");
839 
840     case PL_HOOK_SCALED:        return pl_str0("SCALED");
841     case PL_HOOK_OUTPUT:        return pl_str0("OUTPUT");
842     };
843 
844     pl_unreachable();
845 }
846 
// A parsed shader pass together with the stage(s) it is attached to
struct hook_pass {
    // NOTE(review): presumably a bitmask of all stages this pass runs at
    // (derived from the pass's HOOK lines) - confirm against the code that
    // fills it in
    enum pl_hook_stage exec_stages;
    struct custom_shader_hook hook; // parsed pass description
};
851 
// A texture remembered under a name, so that later passes can bind it or
// refer to its size in size expressions
struct pass_tex {
    pl_str name; // name the texture is looked up by
    pl_tex tex;  // the texture itself

    // Metadata
    struct pl_rect2df rect;       // also exposed to shaders as `<name>_off` (x0/y0)
    struct pl_color_repr repr;    // used to derive the `<name>_mul` scale factor
    struct pl_color_space color;
    int comps;                    // component count recorded when saving
};
862 
// Private state attached to a `pl_hook` created by `pl_mpv_user_shader_parse`
struct hook_priv {
    pl_log log;
    pl_gpu gpu;
    void *alloc; // allocation parent used when growing `pass_textures`

    // All passes parsed from the shader text, in file order
    PL_ARRAY(struct hook_pass) hook_passes;

    // Fixed (for shader-local resources)
    PL_ARRAY(struct pl_shader_desc) descriptors;

    // Dynamic per pass
    enum pl_hook_stage save_stages; // stages whose input texture gets saved
    PL_ARRAY(struct pass_tex) pass_textures; // emptied again by `hook_reset`

    // State for PRNG/frame count
    int frame_count; // incremented once per executed pass, exposed as `frame`
    uint64_t prng_state[4]; // advanced by `prng_step`, exposed as `random`
};
881 
hook_reset(void * priv)882 static void hook_reset(void *priv)
883 {
884     struct hook_priv *p = priv;
885     p->pass_textures.num = 0;
886 }
887 
// Context handed to `lookup_tex` (as its `priv` argument) while evaluating
// size expressions
struct szexp_ctx {
    struct hook_priv *priv;              // source of the saved pass textures
    const struct pl_hook_params *params; // parameters of the current hook call
};
892 
// Size expression lookup callback: resolves the texture name `var` to a
// width/height pair written into `size`. `priv` must point to a
// `struct szexp_ctx`. Returns false if the name is not known.
static bool lookup_tex(void *priv, pl_str var, float size[2])
{
    struct szexp_ctx *ctx = priv;
    const struct pl_hook_params *hp = ctx->params;

    // The texture that triggered the hook
    if (pl_str_equals0(var, "HOOKED")) {
        pl_assert(hp->tex);
        size[0] = hp->tex->params.w;
        size[1] = hp->tex->params.h;
        return true;
    }

    // Names that map onto the source / destination rectangles rather than
    // onto an actual saved texture
    if (pl_str_equals0(var, "NATIVE_CROPPED")) {
        size[0] = pl_rect_w(hp->src_rect);
        size[1] = pl_rect_h(hp->src_rect);
        return true;
    }

    if (pl_str_equals0(var, "OUTPUT")) {
        size[0] = pl_rect_w(hp->dst_rect);
        size[1] = pl_rect_h(hp->dst_rect);
        return true;
    }

    // MAIN is stored under the name MAINPRESUB
    if (pl_str_equals0(var, "MAIN"))
        var = pl_str0("MAINPRESUB");

    // Everything else has to be a previously saved pass texture
    struct hook_priv *state = ctx->priv;
    for (int n = 0; n < state->pass_textures.num; n++) {
        const struct pass_tex *saved = &state->pass_textures.elem[n];
        if (!pl_str_equals(var, saved->name))
            continue;

        size[0] = saved->tex->params.w;
        size[1] = saved->tex->params.h;
        return true;
    }

    return false;
}
932 
// Advance the 256-bit generator state `s` by one step and return a double in
// [0, 1). The update sequence matches the xoshiro256+ reference step; the
// output is derived from the state as it was *before* this call.
static double prng_step(uint64_t s[4])
{
    const uint64_t sum = s[0] + s[3];
    const uint64_t shifted = s[1] << 17;

    // Order matters: each line consumes the result of the previous ones
    s[2] ^= s[0];
    s[3] ^= s[1];
    s[1] ^= s[2];
    s[0] ^= s[3];
    s[2] ^= shifted;

    // Rotate s[3] left by 45 bits
    const uint64_t upper = s[3] << 45;
    const uint64_t lower = s[3] >> (64 - 45);
    s[3] = upper | lower;

    // Keep the top 53 bits and scale them by 2^-53
    const uint64_t mantissa = sum >> 11;
    return mantissa * 0x1.0p-53;
}
947 
// Bind the saved texture `ptex` to the shader `sh` and emit the family of
// `<name>_*` preprocessor macros (raw, pos, map, size, pt, off, mul, rot, tex,
// texOff) that user shaders use to access it. `rect` is forwarded to
// `sh_bind`. Returns false if the texture could not be bound.
static bool bind_pass_tex(pl_shader sh, pl_str name,
                          const struct pass_tex *ptex,
                          const struct pl_rect2df *rect)
{
    ident_t tex_pos, tex_size, tex_pt;

    // Compatibility with mpv texture binding semantics
    ident_t tex_id = sh_bind(sh, ptex->tex, PL_TEX_ADDRESS_CLAMP,
                             PL_TEX_SAMPLE_LINEAR, "hook_tex", rect,
                             &tex_pos, &tex_size, &tex_pt);
    if (!tex_id)
        return false;

    GLSLH("#define %.*s_raw %s \n", PL_STR_FMT(name), tex_id);
    GLSLH("#define %.*s_pos %s \n", PL_STR_FMT(name), tex_pos);
    GLSLH("#define %.*s_map %s_map \n", PL_STR_FMT(name), tex_pos);
    GLSLH("#define %.*s_size %s \n", PL_STR_FMT(name), tex_size);
    GLSLH("#define %.*s_pt %s \n", PL_STR_FMT(name), tex_pt);

    // Top-left corner of the saved rect, passed in as a shader variable
    float origin[2] = { ptex->rect.x0, ptex->rect.y0 };
    ident_t origin_id = sh_var(sh, (struct pl_shader_var) {
        .var = pl_var_vec2("offset"),
        .data = origin,
    });
    GLSLH("#define %.*s_off %s \n", PL_STR_FMT(name), origin_id);

    // Normalize a local copy of the representation, so that `ptex` itself
    // stays untouched; the returned factor becomes the `_mul` constant
    struct pl_color_repr norm_repr = ptex->repr;
    ident_t mul = SH_FLOAT(pl_color_repr_normalize(&norm_repr));
    GLSLH("#define %.*s_mul %s \n", PL_STR_FMT(name), mul);

    // Compatibility with mpv
    GLSLH("#define %.*s_rot mat2(1.0, 0.0, 0.0, 1.0) \n", PL_STR_FMT(name));

    // Sampling function boilerplate
    GLSLH("#define %.*s_tex(pos) (%s * vec4(%s(%s, pos))) \n",
          PL_STR_FMT(name), mul, sh_tex_fn(sh, ptex->tex->params), tex_id);
    GLSLH("#define %.*s_texOff(off) (%.*s_tex(%s + %s * vec2(off))) \n",
          PL_STR_FMT(name), PL_STR_FMT(name), tex_pos, tex_pt);

    return true;
}
988 
save_pass_tex(struct hook_priv * p,struct pass_tex ptex)989 static void save_pass_tex(struct hook_priv *p, struct pass_tex ptex)
990 {
991 
992     for (int i = 0; i < p->pass_textures.num; i++) {
993         if (!pl_str_equals(p->pass_textures.elem[i].name, ptex.name))
994             continue;
995 
996         p->pass_textures.elem[i] = ptex;
997         return;
998     }
999 
1000     // No texture with this name yet, append new one
1001     PL_ARRAY_APPEND(p->alloc, p->pass_textures, ptex);
1002 }
1003 
hook_hook(void * priv,const struct pl_hook_params * params)1004 static struct pl_hook_res hook_hook(void *priv, const struct pl_hook_params *params)
1005 {
1006     struct hook_priv *p = priv;
1007     pl_str stage = pl_stage_to_mp(params->stage);
1008     struct pl_hook_res res = {0};
1009 
1010     // Save the input texture if needed
1011     if (p->save_stages & params->stage) {
1012         pl_assert(params->tex);
1013         struct pass_tex ptex = {
1014             .name = stage,
1015             .tex = params->tex,
1016             .rect = params->rect,
1017             .repr = params->repr,
1018             .color = params->color,
1019             .comps = params->components,
1020         };
1021 
1022         PL_TRACE(p, "Saving input texture '%.*s' for binding",
1023                  PL_STR_FMT(ptex.name));
1024         save_pass_tex(p, ptex);
1025     }
1026 
1027     pl_shader sh = NULL;
1028     struct szexp_ctx scope = {
1029         .priv = p,
1030         .params = params,
1031     };
1032 
1033     for (int n = 0; n < p->hook_passes.num; n++) {
1034         const struct hook_pass *pass = &p->hook_passes.elem[n];
1035         if (!(pass->exec_stages & params->stage))
1036             continue;
1037 
1038         const struct custom_shader_hook *hook = &pass->hook;
1039         PL_TRACE(p, "Executing hook pass %d on stage '%.*s': %.*s",
1040                  n, PL_STR_FMT(stage), PL_STR_FMT(hook->pass_desc));
1041 
1042         // Test for execution condition
1043         float run = 0;
1044         if (!pl_eval_szexpr(p->log, &scope, lookup_tex, hook->cond, &run))
1045             goto error;
1046 
1047         if (!run) {
1048             PL_TRACE(p, "Skipping hook due to condition");
1049             continue;
1050         }
1051 
1052         float out_size[2] = {0};
1053         if (!pl_eval_szexpr(p->log, &scope, lookup_tex, hook->width,  &out_size[0]) ||
1054             !pl_eval_szexpr(p->log, &scope, lookup_tex, hook->height, &out_size[1]))
1055         {
1056             goto error;
1057         }
1058 
1059         int out_w = roundf(out_size[0]),
1060             out_h = roundf(out_size[1]);
1061 
1062         // Generate a new texture to store the render result
1063         pl_tex fbo;
1064         fbo = params->get_tex(params->priv, out_w, out_h);
1065         if (!fbo) {
1066             PL_ERR(p, "Failed dispatching hook: `get_tex` callback failed?");
1067             goto error;
1068         }
1069 
1070         // Generate a new shader object
1071         sh = pl_dispatch_begin(params->dispatch);
1072         if (!sh_require(sh, PL_SHADER_SIG_NONE, out_w, out_h))
1073             goto error;
1074 
1075         if (hook->is_compute) {
1076             if (!sh_try_compute(sh, hook->threads_w, hook->threads_h, false, 0) ||
1077                 !fbo->params.storable)
1078             {
1079                 PL_ERR(p, "Failed dispatching COMPUTE shader");
1080                 goto error;
1081             }
1082         }
1083 
1084         // Bind all necessary input textures
1085         for (int i = 0; i < PL_ARRAY_SIZE(hook->bind_tex); i++) {
1086             pl_str texname = hook->bind_tex[i];
1087             if (!texname.len)
1088                 break;
1089 
1090             // Convenience alias, to allow writing shaders that are oblivious
1091             // of the exact stage they hooked. This simply translates to
1092             // whatever stage actually fired the hook.
1093             if (pl_str_equals0(texname, "HOOKED")) {
1094                 GLSLH("#define HOOKED_raw %.*s_raw \n", PL_STR_FMT(stage));
1095                 GLSLH("#define HOOKED_pos %.*s_pos \n", PL_STR_FMT(stage));
1096                 GLSLH("#define HOOKED_size %.*s_size \n", PL_STR_FMT(stage));
1097                 GLSLH("#define HOOKED_rot %.*s_rot \n", PL_STR_FMT(stage));
1098                 GLSLH("#define HOOKED_off %.*s_off \n", PL_STR_FMT(stage));
1099                 GLSLH("#define HOOKED_pt %.*s_pt \n", PL_STR_FMT(stage));
1100                 GLSLH("#define HOOKED_map %.*s_map \n", PL_STR_FMT(stage));
1101                 GLSLH("#define HOOKED_mul %.*s_mul \n", PL_STR_FMT(stage));
1102                 GLSLH("#define HOOKED_tex %.*s_tex \n", PL_STR_FMT(stage));
1103                 GLSLH("#define HOOKED_texOff %.*s_texOff \n", PL_STR_FMT(stage));
1104 
1105                 // Continue with binding this, under the new name
1106                 texname = stage;
1107             }
1108 
1109             // Compatibility alias, because MAIN and MAINPRESUB mean the same
1110             // thing to libplacebo, but user shaders are still written as
1111             // though they can be different concepts.
1112             if (pl_str_equals0(texname, "MAIN")) {
1113                 GLSLH("#define MAIN_raw MAINPRESUB_raw \n");
1114                 GLSLH("#define MAIN_pos MAINPRESUB_pos \n");
1115                 GLSLH("#define MAIN_size MAINPRESUB_size \n");
1116                 GLSLH("#define MAIN_rot MAINPRESUB_rot \n");
1117                 GLSLH("#define MAIN_off MAINPRESUB_off \n");
1118                 GLSLH("#define MAIN_pt MAINPRESUB_pt \n");
1119                 GLSLH("#define MAIN_map MAINPRESUB_map \n");
1120                 GLSLH("#define MAIN_mul MAINPRESUB_mul \n");
1121                 GLSLH("#define MAIN_tex MAINPRESUB_tex \n");
1122                 GLSLH("#define MAIN_texOff MAINPRESUB_texOff \n");
1123 
1124                 texname = pl_str0("MAINPRESUB");
1125             }
1126 
1127             for (int j = 0; j < p->descriptors.num; j++) {
1128                 if (pl_str_equals0(texname, p->descriptors.elem[j].desc.name)) {
1129                     // Directly bind this, no need to bother with all the
1130                     // `bind_pass_tex` boilerplate
1131                     ident_t id = sh_desc(sh, p->descriptors.elem[j]);
1132                     GLSLH("#define %.*s %s \n", PL_STR_FMT(texname), id);
1133 
1134                     if (p->descriptors.elem[j].desc.type == PL_DESC_SAMPLED_TEX) {
1135                         pl_tex tex = p->descriptors.elem[j].binding.object;
1136                         GLSLH("#define %.*s_tex(pos) (%s(%s, pos)) \n",
1137                               PL_STR_FMT(texname), sh_tex_fn(sh, tex->params), id);
1138                     }
1139                     goto next_bind;
1140                 }
1141             }
1142 
1143             for (int j = 0; j < p->pass_textures.num; j++) {
1144                 if (pl_str_equals(texname, p->pass_textures.elem[j].name)) {
1145                     // Note: We bind the whole texture, rather than
1146                     // params->rect, because user shaders in general are not
1147                     // designed to handle cropped input textures.
1148                     const struct pass_tex *ptex = &p->pass_textures.elem[j];
1149                     struct pl_rect2df rect = {
1150                         0, 0, ptex->tex->params.w, ptex->tex->params.h,
1151                     };
1152 
1153                     if (hook->offset_align && pl_str_equals(texname, stage)) {
1154                         float sx = pl_rect_w(params->rect) / pl_rect_w(params->src_rect),
1155                               sy = pl_rect_h(params->rect) / pl_rect_h(params->src_rect),
1156                               ox = params->rect.x0 - sx * params->src_rect.x0,
1157                               oy = params->rect.y0 - sy * params->src_rect.y0;
1158 
1159                         PL_TRACE(p, "Aligning plane with ref: %f %f", ox, oy);
1160                         pl_rect2df_offset(&rect, ox, oy);
1161                     }
1162 
1163                     if (!bind_pass_tex(sh, texname, &p->pass_textures.elem[j], &rect))
1164                         goto error;
1165                     goto next_bind;
1166                 }
1167             }
1168 
1169             // If none of the above matched, this is a bogus/unknown texture name
1170             PL_ERR(p, "Tried binding unknown texture '%.*s'!", PL_STR_FMT(texname));
1171             goto error;
1172 
1173     next_bind: ; // outer 'continue'
1174         }
1175 
1176         // Set up the input variables
1177         p->frame_count++;
1178         GLSLH("#define frame %s \n", sh_var(sh, (struct pl_shader_var) {
1179             .var = pl_var_int("frame"),
1180             .data = &p->frame_count,
1181             .dynamic = true,
1182         }));
1183 
1184         float random = prng_step(p->prng_state);
1185         GLSLH("#define random %s \n", sh_var(sh, (struct pl_shader_var) {
1186             .var = pl_var_float("random"),
1187             .data = &random,
1188             .dynamic = true,
1189         }));
1190 
1191         float src_size[2] = { pl_rect_w(params->src_rect), pl_rect_h(params->src_rect) };
1192         GLSLH("#define input_size %s \n", sh_var(sh, (struct pl_shader_var) {
1193             .var = pl_var_vec2("input_size"),
1194             .data = src_size,
1195         }));
1196 
1197         float dst_size[2] = { pl_rect_w(params->dst_rect), pl_rect_h(params->dst_rect) };
1198         GLSLH("#define target_size %s \n", sh_var(sh, (struct pl_shader_var) {
1199             .var = pl_var_vec2("target_size"),
1200             .data = dst_size,
1201         }));
1202 
1203         float tex_off[2] = { params->src_rect.x0, params->src_rect.y0 };
1204         GLSLH("#define tex_offset %s \n", sh_var(sh, (struct pl_shader_var) {
1205             .var = pl_var_vec2("tex_offset"),
1206             .data = tex_off,
1207         }));
1208 
1209         // Load and run the user shader itself
1210         sh_append_str(sh, SH_BUF_HEADER, hook->pass_body);
1211         sh_describe(sh, pl_strdup0(SH_TMP(sh), hook->pass_desc));
1212 
1213         bool ok;
1214         if (hook->is_compute) {
1215             GLSLP("#define out_image %s \n", sh_desc(sh, (struct pl_shader_desc) {
1216                 .binding.object = fbo,
1217                 .desc = {
1218                     .name = "out_image",
1219                     .type = PL_DESC_STORAGE_IMG,
1220                     .access = PL_DESC_ACCESS_WRITEONLY,
1221                 },
1222             }));
1223 
1224             sh->res.output = PL_SHADER_SIG_NONE;
1225 
1226             GLSL("hook(); \n");
1227             ok = pl_dispatch_compute(params->dispatch, &(struct pl_dispatch_compute_params) {
1228                 .shader = &sh,
1229                 .dispatch_size = {
1230                     // Round up as many blocks as are needed to cover the image
1231                     (out_w + hook->block_w - 1) / hook->block_w,
1232                     (out_h + hook->block_h - 1) / hook->block_h,
1233                     1,
1234                 },
1235                 .width  = out_w,
1236                 .height = out_h,
1237             });
1238         } else {
1239             GLSL("vec4 color = hook(); \n");
1240             ok = pl_dispatch_finish(params->dispatch, &(struct pl_dispatch_params) {
1241                 .shader = &sh,
1242                 .target = fbo,
1243             });
1244         }
1245 
1246         if (!ok)
1247             goto error;
1248 
1249         float sx = (float) out_w / params->tex->params.w,
1250               sy = (float) out_h / params->tex->params.h,
1251               x0 = sx * params->rect.x0 + hook->offset[0],
1252               y0 = sy * params->rect.y0 + hook->offset[1];
1253 
1254         struct pl_rect2df new_rect = {
1255             x0,
1256             y0,
1257             x0 + sx * pl_rect_w(params->rect),
1258             y0 + sy * pl_rect_h(params->rect),
1259         };
1260 
1261         if (hook->offset_align) {
1262             float rx = pl_rect_w(new_rect) / pl_rect_w(params->src_rect),
1263                   ry = pl_rect_h(new_rect) / pl_rect_h(params->src_rect),
1264                   ox = rx * params->src_rect.x0 - sx * params->rect.x0,
1265                   oy = ry * params->src_rect.y0 - sy * params->rect.y0;
1266 
1267             pl_rect2df_offset(&new_rect, ox, oy);
1268         }
1269 
1270         // Save the result of this shader invocation
1271         struct pass_tex ptex = {
1272             .name = hook->save_tex.len ? hook->save_tex : stage,
1273             .tex = fbo,
1274             .repr = params->repr,
1275             .color = params->color,
1276             .comps  = PL_DEF(hook->comps, params->components),
1277             .rect = new_rect,
1278         };
1279 
1280         // It's assumed that users will correctly normalize the input
1281         pl_color_repr_normalize(&ptex.repr);
1282 
1283         PL_TRACE(p, "Saving output texture '%.*s' from hook execution on '%.*s'",
1284                  PL_STR_FMT(ptex.name), PL_STR_FMT(stage));
1285 
1286         save_pass_tex(p, ptex);
1287 
1288         // Update the result object, unless we saved to a different name
1289         if (!hook->save_tex.len) {
1290             res = (struct pl_hook_res) {
1291                 .output = PL_HOOK_SIG_TEX,
1292                 .tex = fbo,
1293                 .repr = ptex.repr,
1294                 .color = ptex.color,
1295                 .components = PL_DEF(hook->comps, params->components),
1296                 .rect = new_rect,
1297             };
1298         }
1299     }
1300 
1301     return res;
1302 
1303 error:
1304     return (struct pl_hook_res) { .failed = true };
1305 }
1306 
// Parse the text of an mpv-style user shader into a `pl_hook`.
//
// `shader_text` / `shader_len` describe the shader source; it is copied, so
// the caller's buffer does not need to outlive this call. Everything before
// the first `//!` header is skipped. `//!TEXTURE` and `//!BUFFER` blocks are
// registered as shader-local descriptors, every other block is parsed as a
// hook pass.
//
// Returns the new hook, or NULL if `shader_len` is zero, the text contains no
// headers, or any block fails to parse. On failure, resources that were
// already created for earlier blocks are released again. A successfully
// returned hook is released with `pl_mpv_user_shader_destroy`.
const struct pl_hook *pl_mpv_user_shader_parse(pl_gpu gpu,
                                               const char *shader_text,
                                               size_t shader_len)
{
    if (!shader_len)
        return NULL;

    struct pl_hook *hook = pl_alloc_obj(NULL, hook, struct hook_priv);
    struct hook_priv *p = PL_PRIV(hook);

    *hook = (struct pl_hook) {
        .input = PL_HOOK_SIG_TEX,
        .priv = p,
        .reset = hook_reset,
        .hook = hook_hook,
    };

    *p = (struct hook_priv) {
        .log = gpu->log,
        .gpu = gpu,
        .alloc = hook,
        .prng_state = {
            // Determined by fair die roll
            0xb76d71f9443c228allu, 0x93a02092fc4807e8llu,
            0x06d81748f838bd07llu, 0x9381ee129dddce6cllu,
        },
    };

    // Work on a private copy that is owned by (and freed with) the hook
    pl_str shader = { (char *) shader_text, shader_len };
    shader = pl_strdup(hook, shader);

    // Skip all garbage (e.g. comments) before the first header
    int pos = pl_str_find(shader, pl_str0("//!"));
    if (pos < 0) {
        PL_ERR(gpu, "Shader appears to contain no headers?");
        goto error;
    }
    shader = pl_str_drop(shader, pos);

    // Loop over the file
    while (shader.len > 0)
    {
        // Peek at the first header to dispatch the right type
        if (pl_str_startswith0(shader, "//!TEXTURE")) {
            struct pl_shader_desc sd;
            if (!parse_tex(gpu, hook, &shader, &sd))
                goto error;

            PL_INFO(gpu, "Registering named texture '%s'", sd.desc.name);
            PL_ARRAY_APPEND(hook, p->descriptors, sd);
            continue;
        }

        if (pl_str_startswith0(shader, "//!BUFFER")) {
            struct pl_shader_desc sd;
            if (!parse_buf(gpu, hook, &shader, &sd))
                goto error;

            PL_INFO(gpu, "Registering named buffer '%s'", sd.desc.name);
            PL_ARRAY_APPEND(hook, p->descriptors, sd);
            continue;
        }

        struct custom_shader_hook h;
        if (!parse_hook(gpu->log, &shader, &h))
            goto error;

        struct hook_pass pass = {
            .exec_stages = 0,
            .hook = h,
        };

        // The pass runs on every stage it hooks, and every stage it binds
        // needs to have its texture saved. HOOKED stands for whichever stage
        // fires the pass, so it pulls in all of the pass's exec stages.
        for (int i = 0; i < PL_ARRAY_SIZE(h.hook_tex); i++)
            pass.exec_stages |= mp_stage_to_pl(h.hook_tex[i]);
        for (int i = 0; i < PL_ARRAY_SIZE(h.bind_tex); i++) {
            p->save_stages |= mp_stage_to_pl(h.bind_tex[i]);
            if (pl_str_equals0(h.bind_tex[i], "HOOKED"))
                p->save_stages |= pass.exec_stages;
        }

        // As an extra precaution, this avoids errors when trying to run
        // conditions against planes that were never hooked. As a sole
        // exception, OUTPUT is special because it's hard-coded to return the
        // dst_rect even before it was hooked. (This is an apparently
        // undocumented mpv quirk, but shaders rely on it in practice)
        enum pl_hook_stage rpn_stages = 0;
        for (int i = 0; i < PL_ARRAY_SIZE(h.width); i++) {
            if (h.width[i].tag == SZEXP_VAR_W || h.width[i].tag == SZEXP_VAR_H)
                rpn_stages |= mp_stage_to_pl(h.width[i].val.varname);
        }
        for (int i = 0; i < PL_ARRAY_SIZE(h.height); i++) {
            if (h.height[i].tag == SZEXP_VAR_W || h.height[i].tag == SZEXP_VAR_H)
                rpn_stages |= mp_stage_to_pl(h.height[i].val.varname);
        }
        for (int i = 0; i < PL_ARRAY_SIZE(h.cond); i++) {
            if (h.cond[i].tag == SZEXP_VAR_W || h.cond[i].tag == SZEXP_VAR_H)
                rpn_stages |= mp_stage_to_pl(h.cond[i].val.varname);
        }

        p->save_stages |= rpn_stages & ~PL_HOOK_OUTPUT;

        PL_INFO(gpu, "Registering hook pass: %.*s", PL_STR_FMT(h.pass_desc));
        PL_ARRAY_APPEND(hook, p->hook_passes, pass);
    }

    // We need to hook on both the exec and save stages, so that we can keep
    // track of any textures we might need
    hook->stages |= p->save_stages;
    for (int i = 0; i < p->hook_passes.num; i++)
        hook->stages |= p->hook_passes.elem[i].exec_stages;

    return hook;

error:
    // Go through the regular destructor rather than a bare `pl_free`, so that
    // textures and buffers registered by earlier //!TEXTURE and //!BUFFER
    // blocks are destroyed as well instead of being leaked
    {
        const struct pl_hook *partial = hook;
        pl_mpv_user_shader_destroy(&partial);
    }
    return NULL;
}
1424 
pl_mpv_user_shader_destroy(const struct pl_hook ** hookp)1425 void pl_mpv_user_shader_destroy(const struct pl_hook **hookp)
1426 {
1427     const struct pl_hook *hook = *hookp;
1428     if (!hook)
1429         return;
1430 
1431     struct hook_priv *p = PL_PRIV(hook);
1432     for (int i = 0; i < p->descriptors.num; i++) {
1433         switch (p->descriptors.elem[i].desc.type) {
1434             case PL_DESC_BUF_UNIFORM:
1435             case PL_DESC_BUF_STORAGE:
1436             case PL_DESC_BUF_TEXEL_UNIFORM:
1437             case PL_DESC_BUF_TEXEL_STORAGE: {
1438                 pl_buf buf = p->descriptors.elem[i].binding.object;
1439                 pl_buf_destroy(p->gpu, &buf);
1440                 break;
1441             }
1442 
1443             case PL_DESC_SAMPLED_TEX:
1444             case PL_DESC_STORAGE_IMG: {
1445                 pl_tex tex = p->descriptors.elem[i].binding.object;
1446                 pl_tex_destroy(p->gpu, &tex);
1447                 break;
1448 
1449             case PL_DESC_INVALID:
1450             case PL_DESC_TYPE_COUNT:
1451                 pl_unreachable();
1452             }
1453         }
1454     }
1455 
1456     pl_free((void *) hook);
1457 }
1458