1 // waifu2x implemented with ncnn library
2 
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <clocale>
#include <queue>
#include <vector>
8 
9 #if _WIN32
10 // image decoder and encoder with wic
11 #include "wic_image.h"
12 #else // _WIN32
13 // image decoder and encoder with stb
14 #define STB_IMAGE_IMPLEMENTATION
15 #define STBI_NO_PSD
16 #define STBI_NO_TGA
17 #define STBI_NO_GIF
18 #define STBI_NO_HDR
19 #define STBI_NO_PIC
20 #define STBI_NO_STDIO
21 #include "stb_image.h"
22 #define STB_IMAGE_WRITE_IMPLEMENTATION
23 #include "stb_image_write.h"
24 #endif // _WIN32
25 #include "webp_image.h"
26 
27 #if _WIN32
28 #include <wchar.h>
// Minimal wide-character stand-in for POSIX getopt(), used on Windows only.
// optarg points at the value of the option parsed last (NULL when that option
// takes no value); optind is the index of the next argv element to examine.
static wchar_t* optarg = NULL;
static int optind = 1;

// Parse the next "-x" style option from argv.
//
// optstring lists the accepted option letters; a letter followed by ':' takes
// a value, which is read from the following argv element.
//
// Returns the option letter, L'?' for an unknown option or a missing value,
// or (wchar_t)-1 once argv is exhausted or argv[optind] does not start with
// '-'. Only the first character after '-' is inspected, so grouped flags
// ("-vx") and attached values ("-s2") are not recognised.
static wchar_t getopt(int argc, wchar_t* const argv[], const wchar_t* optstring)
{
    if (optind >= argc || argv[optind][0] != L'-')
        return -1;

    wchar_t opt = argv[optind][1];

    // A bare "-" yields L'\0', which wcschr() matches against the terminator
    // of optstring, so p[1] would read past the end of the string. ':' is the
    // value marker and never a valid option letter. Reject both up front.
    if (opt == L'\0' || opt == L':')
        return L'?';

    const wchar_t* p = wcschr(optstring, opt);
    if (p == NULL)
        return L'?';

    optarg = NULL;

    if (p[1] == L':')
    {
        // the option takes a value: consume the next argv element
        optind++;
        if (optind >= argc)
            return L'?';

        optarg = argv[optind];
    }

    optind++;

    return opt;
}
56 
parse_optarg_int_array(const wchar_t * optarg)57 static std::vector<int> parse_optarg_int_array(const wchar_t* optarg)
58 {
59     std::vector<int> array;
60     array.push_back(_wtoi(optarg));
61 
62     const wchar_t* p = wcschr(optarg, L',');
63     while (p)
64     {
65         p++;
66         array.push_back(_wtoi(p));
67         p = wcschr(p, L',');
68     }
69 
70     return array;
71 }
72 #else // _WIN32
73 #include <unistd.h> // getopt()
74 
// Split a comma-separated option value such as "0,1,2" into integers.
// Every comma-delimited field is converted with atoi(), so the result always
// holds one more element than the string has commas.
static std::vector<int> parse_optarg_int_array(const char* optarg)
{
    std::vector<int> values;

    for (const char* field = optarg; field != NULL; )
    {
        values.push_back(atoi(field));

        // step to the character after the next comma, or stop
        const char* comma = strchr(field, ',');
        field = comma ? comma + 1 : NULL;
    }

    return values;
}
90 #endif // _WIN32
91 
92 // ncnn
93 #include "cpu.h"
94 #include "gpu.h"
95 #include "platform.h"
96 
97 #include "waifu2x.h"
98 
99 #include "filesystem_utils.h"
100 
// Print the command-line help text to stdout.
static void print_usage()
{
    fprintf(stdout, "Usage: waifu2x-ncnn-vulkan -i infile -o outfile [options]...\n\n");
    fprintf(stdout, "  -h                   show this help\n");
    fprintf(stdout, "  -v                   verbose output\n");
    fprintf(stdout, "  -i input-path        input image path (jpg/png/webp) or directory\n");
    fprintf(stdout, "  -o output-path       output image path (jpg/png/webp) or directory\n");
    fprintf(stdout, "  -n noise-level       denoise level (-1/0/1/2/3, default=0)\n");
    fprintf(stdout, "  -s scale             upscale ratio (1/2/4/8/16/32, default=2)\n");
    fprintf(stdout, "  -t tile-size         tile size (>=32/0=auto, default=0) can be 0,0,0 for multi-gpu\n");
    // state the default the program really uses (see the initial value of `model` in main)
    fprintf(stdout, "  -m model-path        waifu2x model path (default=/usr/local/share/waifu2x-ncnn-vulkan/models-cunet)\n");
    fprintf(stdout, "  -g gpu-id            gpu device to use (-1=cpu, default=auto) can be 0,1,2 for multi-gpu\n");
    fprintf(stdout, "  -j load:proc:save    thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu\n");
    fprintf(stdout, "  -x                   enable tta mode\n");
    fprintf(stdout, "  -f format            output image format (jpg/png/webp, default=ext/png)\n");
}
117 
118 class Task
119 {
120 public:
121     int id;
122     int webp;
123     int scale;
124 
125     path_t inpath;
126     path_t outpath;
127 
128     ncnn::Mat inimage;
129     ncnn::Mat outimage;
130 };
131 
132 class TaskQueue
133 {
134 public:
TaskQueue()135     TaskQueue()
136     {
137     }
138 
put(const Task & v)139     void put(const Task& v)
140     {
141         lock.lock();
142 
143         while (tasks.size() >= 8) // FIXME hardcode queue length
144         {
145             condition.wait(lock);
146         }
147 
148         tasks.push(v);
149 
150         lock.unlock();
151 
152         condition.signal();
153     }
154 
get(Task & v)155     void get(Task& v)
156     {
157         lock.lock();
158 
159         while (tasks.size() == 0)
160         {
161             condition.wait(lock);
162         }
163 
164         v = tasks.front();
165         tasks.pop();
166 
167         lock.unlock();
168 
169         condition.signal();
170     }
171 
172 private:
173     ncnn::Mutex lock;
174     ncnn::ConditionVariable condition;
175     std::queue<Task> tasks;
176 };
177 
// Decoded images waiting for a proc thread: filled by load(), drained by proc().
// main() also queues the end-of-stream markers here.
TaskQueue toproc;
// Processed images waiting to be encoded: filled by proc(), drained by save().
// main() also queues the end-of-stream markers here.
TaskQueue tosave;
180 
181 class LoadThreadParams
182 {
183 public:
184     int scale;
185     int jobs_load;
186 
187     // session data
188     std::vector<path_t> input_files;
189     std::vector<path_t> output_files;
190 };
191 
load(void * args)192 void* load(void* args)
193 {
194     const LoadThreadParams* ltp = (const LoadThreadParams*)args;
195     const int count = ltp->input_files.size();
196     const int scale = ltp->scale;
197 
198     #pragma omp parallel for schedule(static,1) num_threads(ltp->jobs_load)
199     for (int i=0; i<count; i++)
200     {
201         const path_t& imagepath = ltp->input_files[i];
202 
203         int webp = 0;
204 
205         unsigned char* pixeldata = 0;
206         int w;
207         int h;
208         int c;
209 
210 #if _WIN32
211         FILE* fp = _wfopen(imagepath.c_str(), L"rb");
212 #else
213         FILE* fp = fopen(imagepath.c_str(), "rb");
214 #endif
215         if (fp)
216         {
217             // read whole file
218             unsigned char* filedata = 0;
219             int length = 0;
220             {
221                 fseek(fp, 0, SEEK_END);
222                 length = ftell(fp);
223                 rewind(fp);
224                 filedata = (unsigned char*)malloc(length);
225                 if (filedata)
226                 {
227                     fread(filedata, 1, length, fp);
228                 }
229                 fclose(fp);
230             }
231 
232             if (filedata)
233             {
234                 pixeldata = webp_load(filedata, length, &w, &h, &c);
235                 if (pixeldata)
236                 {
237                     webp = 1;
238                 }
239                 else
240                 {
241                     // not webp, try jpg png etc.
242 #if _WIN32
243                     pixeldata = wic_decode_image(imagepath.c_str(), &w, &h, &c);
244 #else // _WIN32
245                     pixeldata = stbi_load_from_memory(filedata, length, &w, &h, &c, 0);
246                     if (pixeldata)
247                     {
248                         // stb_image auto channel
249                         if (c == 1)
250                         {
251                             // grayscale -> rgb
252                             stbi_image_free(pixeldata);
253                             pixeldata = stbi_load_from_memory(filedata, length, &w, &h, &c, 3);
254                             c = 3;
255                         }
256                         else if (c == 2)
257                         {
258                             // grayscale + alpha -> rgba
259                             stbi_image_free(pixeldata);
260                             pixeldata = stbi_load_from_memory(filedata, length, &w, &h, &c, 4);
261                             c = 4;
262                         }
263                     }
264 #endif // _WIN32
265                 }
266 
267                 free(filedata);
268             }
269         }
270         if (pixeldata)
271         {
272             Task v;
273             v.id = i;
274             v.webp = webp;
275             v.scale = scale;
276             v.inpath = imagepath;
277             v.outpath = ltp->output_files[i];
278 
279             v.inimage = ncnn::Mat(w, h, (void*)pixeldata, (size_t)c, c);
280 
281             path_t ext = get_file_extension(v.outpath);
282             if (c == 4 && (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG")))
283             {
284                 path_t output_filename2 = ltp->output_files[i] + PATHSTR(".png");
285                 v.outpath = output_filename2;
286 #if _WIN32
287                 fwprintf(stderr, L"image %ls has alpha channel ! %ls will output %ls\n", imagepath.c_str(), imagepath.c_str(), output_filename2.c_str());
288 #else // _WIN32
289                 fprintf(stderr, "image %s has alpha channel ! %s will output %s\n", imagepath.c_str(), imagepath.c_str(), output_filename2.c_str());
290 #endif // _WIN32
291             }
292 
293             toproc.put(v);
294         }
295         else
296         {
297 #if _WIN32
298             fwprintf(stderr, L"decode image %ls failed\n", imagepath.c_str());
299 #else // _WIN32
300             fprintf(stderr, "decode image %s failed\n", imagepath.c_str());
301 #endif // _WIN32
302         }
303     }
304 
305     return 0;
306 }
307 
308 class ProcThreadParams
309 {
310 public:
311     const Waifu2x* waifu2x;
312 };
313 
proc(void * args)314 void* proc(void* args)
315 {
316     const ProcThreadParams* ptp = (const ProcThreadParams*)args;
317     const Waifu2x* waifu2x = ptp->waifu2x;
318 
319     for (;;)
320     {
321         Task v;
322 
323         toproc.get(v);
324 
325         if (v.id == -233)
326             break;
327 
328         const int scale = v.scale;
329         if (scale == 1)
330         {
331             v.outimage = ncnn::Mat(v.inimage.w, v.inimage.h, (size_t)v.inimage.elemsize, (int)v.inimage.elemsize);
332             waifu2x->process(v.inimage, v.outimage);
333 
334             tosave.put(v);
335             continue;
336         }
337 
338         int scale_run_count = 0;
339         if (scale == 2)
340         {
341             scale_run_count = 1;
342         }
343         if (scale == 4)
344         {
345             scale_run_count = 2;
346         }
347         if (scale == 8)
348         {
349             scale_run_count = 3;
350         }
351         if (scale == 16)
352         {
353             scale_run_count = 4;
354         }
355         if (scale == 32)
356         {
357             scale_run_count = 5;
358         }
359 
360         v.outimage = ncnn::Mat(v.inimage.w * 2, v.inimage.h * 2, (size_t)v.inimage.elemsize, (int)v.inimage.elemsize);
361         waifu2x->process(v.inimage, v.outimage);
362 
363         for (int i = 1; i < scale_run_count; i++)
364         {
365             ncnn::Mat tmp = v.outimage;
366             v.outimage = ncnn::Mat(tmp.w * 2, tmp.h * 2, (size_t)v.inimage.elemsize, (int)v.inimage.elemsize);
367             waifu2x->process(tmp, v.outimage);
368         }
369 
370         tosave.put(v);
371     }
372 
373     return 0;
374 }
375 
// Arguments handed to the save() thread entry point.
struct SaveThreadParams
{
    int verbose; // non-zero: report every successfully written image on stdout
};
381 
save(void * args)382 void* save(void* args)
383 {
384     const SaveThreadParams* stp = (const SaveThreadParams*)args;
385     const int verbose = stp->verbose;
386 
387     for (;;)
388     {
389         Task v;
390 
391         tosave.get(v);
392 
393         if (v.id == -233)
394             break;
395 
396         // free input pixel data
397         {
398             unsigned char* pixeldata = (unsigned char*)v.inimage.data;
399             if (v.webp == 1)
400             {
401                 free(pixeldata);
402             }
403             else
404             {
405 #if _WIN32
406                 free(pixeldata);
407 #else
408                 stbi_image_free(pixeldata);
409 #endif
410             }
411         }
412 
413         int success = 0;
414 
415         path_t ext = get_file_extension(v.outpath);
416 
417         if (ext == PATHSTR("webp") || ext == PATHSTR("WEBP"))
418         {
419             success = webp_save(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, (const unsigned char*)v.outimage.data);
420         }
421         else if (ext == PATHSTR("png") || ext == PATHSTR("PNG"))
422         {
423 #if _WIN32
424             success = wic_encode_image(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data);
425 #else
426             success = stbi_write_png(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data, 0);
427 #endif
428         }
429         else if (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG"))
430         {
431 #if _WIN32
432             success = wic_encode_jpeg_image(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data);
433 #else
434             success = stbi_write_jpg(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data, 100);
435 #endif
436         }
437         if (success)
438         {
439             if (verbose)
440             {
441 #if _WIN32
442                 fwprintf(stdout, L"%ls -> %ls done\n", v.inpath.c_str(), v.outpath.c_str());
443 #else
444                 fprintf(stdout, "%s -> %s done\n", v.inpath.c_str(), v.outpath.c_str());
445 #endif
446             }
447         }
448         else
449         {
450 #if _WIN32
451             fwprintf(stderr, L"encode image %ls failed\n", v.outpath.c_str());
452 #else
453             fprintf(stderr, "encode image %s failed\n", v.outpath.c_str());
454 #endif
455         }
456     }
457 
458     return 0;
459 }
460 
461 
462 #if _WIN32
wmain(int argc,wchar_t ** argv)463 int wmain(int argc, wchar_t** argv)
464 #else
465 int main(int argc, char** argv)
466 #endif
467 {
468     path_t inputpath;
469     path_t outputpath;
470     int noise = 0;
471     int scale = 2;
472     std::vector<int> tilesize;
473     path_t model = PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-cunet");
474     std::vector<int> gpuid;
475     int jobs_load = 1;
476     std::vector<int> jobs_proc;
477     int jobs_save = 2;
478     int verbose = 0;
479     int tta_mode = 0;
480     path_t format = PATHSTR("png");
481 
482 #if _WIN32
483     setlocale(LC_ALL, "");
484     wchar_t opt;
485     while ((opt = getopt(argc, argv, L"i:o:n:s:t:m:g:j:f:vxh")) != (wchar_t)-1)
486     {
487         switch (opt)
488         {
489         case L'i':
490             inputpath = optarg;
491             break;
492         case L'o':
493             outputpath = optarg;
494             break;
495         case L'n':
496             noise = _wtoi(optarg);
497             break;
498         case L's':
499             scale = _wtoi(optarg);
500             break;
501         case L't':
502             tilesize = parse_optarg_int_array(optarg);
503             break;
504         case L'm':
505             model = optarg;
506             break;
507         case L'g':
508             gpuid = parse_optarg_int_array(optarg);
509             break;
510         case L'j':
511             swscanf(optarg, L"%d:%*[^:]:%d", &jobs_load, &jobs_save);
512             jobs_proc = parse_optarg_int_array(wcschr(optarg, L':') + 1);
513             break;
514         case L'f':
515             format = optarg;
516             break;
517         case L'v':
518             verbose = 1;
519             break;
520         case L'x':
521             tta_mode = 1;
522             break;
523         case L'h':
524         default:
525             print_usage();
526             return -1;
527         }
528     }
529 #else // _WIN32
530     int opt;
531     while ((opt = getopt(argc, argv, "i:o:n:s:t:m:g:j:f:vxh")) != -1)
532     {
533         switch (opt)
534         {
535         case 'i':
536             inputpath = optarg;
537             break;
538         case 'o':
539             outputpath = optarg;
540             break;
541         case 'n':
542             noise = atoi(optarg);
543             break;
544         case 's':
545             scale = atoi(optarg);
546             break;
547         case 't':
548             tilesize = parse_optarg_int_array(optarg);
549             break;
550         case 'm':
551             model = optarg;
552             break;
553         case 'g':
554             gpuid = parse_optarg_int_array(optarg);
555             break;
556         case 'j':
557             sscanf(optarg, "%d:%*[^:]:%d", &jobs_load, &jobs_save);
558             jobs_proc = parse_optarg_int_array(strchr(optarg, ':') + 1);
559             break;
560         case 'f':
561             format = optarg;
562             break;
563         case 'v':
564             verbose = 1;
565             break;
566         case 'x':
567             tta_mode = 1;
568             break;
569         case 'h':
570         default:
571             print_usage();
572             return -1;
573         }
574     }
575 #endif // _WIN32
576 
577     if (inputpath.empty() || outputpath.empty())
578     {
579         print_usage();
580         return -1;
581     }
582 
583     if (noise < -1 || noise > 3)
584     {
585         fprintf(stderr, "invalid noise argument\n");
586         return -1;
587     }
588 
589     if (!(scale == 1 || scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32))
590     {
591         fprintf(stderr, "invalid scale argument\n");
592         return -1;
593     }
594 
595     if (tilesize.size() != (gpuid.empty() ? 1 : gpuid.size()) && !tilesize.empty())
596     {
597         fprintf(stderr, "invalid tilesize argument\n");
598         return -1;
599     }
600 
601     for (int i=0; i<(int)tilesize.size(); i++)
602     {
603         if (tilesize[i] != 0 && tilesize[i] < 32)
604         {
605             fprintf(stderr, "invalid tilesize argument\n");
606             return -1;
607         }
608     }
609 
610     if (jobs_load < 1 || jobs_save < 1)
611     {
612         fprintf(stderr, "invalid thread count argument\n");
613         return -1;
614     }
615 
616     if (jobs_proc.size() != (gpuid.empty() ? 1 : gpuid.size()) && !jobs_proc.empty())
617     {
618         fprintf(stderr, "invalid jobs_proc thread count argument\n");
619         return -1;
620     }
621 
622     for (int i=0; i<(int)jobs_proc.size(); i++)
623     {
624         if (jobs_proc[i] < 1)
625         {
626             fprintf(stderr, "invalid jobs_proc thread count argument\n");
627             return -1;
628         }
629     }
630 
631     if (!path_is_directory(outputpath))
632     {
633         // guess format from outputpath no matter what format argument specified
634         path_t ext = get_file_extension(outputpath);
635 
636         if (ext == PATHSTR("png") || ext == PATHSTR("PNG"))
637         {
638             format = PATHSTR("png");
639         }
640         else if (ext == PATHSTR("webp") || ext == PATHSTR("WEBP"))
641         {
642             format = PATHSTR("webp");
643         }
644         else if (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG"))
645         {
646             format = PATHSTR("jpg");
647         }
648         else
649         {
650             fprintf(stderr, "invalid outputpath extension type\n");
651             return -1;
652         }
653     }
654 
655     if (format != PATHSTR("png") && format != PATHSTR("webp") && format != PATHSTR("jpg"))
656     {
657         fprintf(stderr, "invalid format argument\n");
658         return -1;
659     }
660 
661     // collect input and output filepath
662     std::vector<path_t> input_files;
663     std::vector<path_t> output_files;
664     {
665         if (path_is_directory(inputpath) && path_is_directory(outputpath))
666         {
667             std::vector<path_t> filenames;
668             int lr = list_directory(inputpath, filenames);
669             if (lr != 0)
670                 return -1;
671 
672             const int count = filenames.size();
673             input_files.resize(count);
674             output_files.resize(count);
675 
676             path_t last_filename;
677             path_t last_filename_noext;
678             for (int i=0; i<count; i++)
679             {
680                 path_t filename = filenames[i];
681                 path_t filename_noext = get_file_name_without_extension(filename);
682                 path_t output_filename = filename_noext + PATHSTR('.') + format;
683 
684                 // filename list is sorted, check if output image path conflicts
685                 if (filename_noext == last_filename_noext)
686                 {
687                     path_t output_filename2 = filename + PATHSTR('.') + format;
688 #if _WIN32
689                     fwprintf(stderr, L"both %ls and %ls output %ls ! %ls will output %ls\n", filename.c_str(), last_filename.c_str(), output_filename.c_str(), filename.c_str(), output_filename2.c_str());
690 #else
691                     fprintf(stderr, "both %s and %s output %s ! %s will output %s\n", filename.c_str(), last_filename.c_str(), output_filename.c_str(), filename.c_str(), output_filename2.c_str());
692 #endif
693                     output_filename = output_filename2;
694                 }
695                 else
696                 {
697                     last_filename = filename;
698                     last_filename_noext = filename_noext;
699                 }
700 
701                 input_files[i] = inputpath + PATHSTR('/') + filename;
702                 output_files[i] = outputpath + PATHSTR('/') + output_filename;
703             }
704         }
705         else if (!path_is_directory(inputpath) && !path_is_directory(outputpath))
706         {
707             input_files.push_back(inputpath);
708             output_files.push_back(outputpath);
709         }
710         else
711         {
712             fprintf(stderr, "inputpath and outputpath must be either file or directory at the same time\n");
713             return -1;
714         }
715     }
716 
717     int prepadding = 0;
718 
719     if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-cunet")) != path_t::npos)
720     {
721         if (noise == -1)
722         {
723             prepadding = 18;
724         }
725         else if (scale == 1)
726         {
727             prepadding = 28;
728         }
729         else if (scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32)
730         {
731             prepadding = 18;
732         }
733     }
734     else if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_anime_style_art_rgb")) != path_t::npos)
735     {
736         prepadding = 7;
737     }
738     else if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_photo")) != path_t::npos)
739     {
740         prepadding = 7;
741     }
742     else
743     {
744         fprintf(stderr, "unknown model dir type\n");
745         return -1;
746     }
747 
748 #if _WIN32
749     wchar_t parampath[256];
750     wchar_t modelpath[256];
751     if (noise == -1)
752     {
753         swprintf(parampath, 256, L"%s/scale2.0x_model.param", model.c_str());
754         swprintf(modelpath, 256, L"%s/scale2.0x_model.bin", model.c_str());
755     }
756     else if (scale == 1)
757     {
758         swprintf(parampath, 256, L"%s/noise%d_model.param", model.c_str(), noise);
759         swprintf(modelpath, 256, L"%s/noise%d_model.bin", model.c_str(), noise);
760     }
761     else if (scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32)
762     {
763         swprintf(parampath, 256, L"%s/noise%d_scale2.0x_model.param", model.c_str(), noise);
764         swprintf(modelpath, 256, L"%s/noise%d_scale2.0x_model.bin", model.c_str(), noise);
765     }
766 #else
767     char parampath[256];
768     char modelpath[256];
769     if (noise == -1)
770     {
771         sprintf(parampath, "%s/scale2.0x_model.param", model.c_str());
772         sprintf(modelpath, "%s/scale2.0x_model.bin", model.c_str());
773     }
774     else if (scale == 1)
775     {
776         sprintf(parampath, "%s/noise%d_model.param", model.c_str(), noise);
777         sprintf(modelpath, "%s/noise%d_model.bin", model.c_str(), noise);
778     }
779     else if (scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32)
780     {
781         sprintf(parampath, "%s/noise%d_scale2.0x_model.param", model.c_str(), noise);
782         sprintf(modelpath, "%s/noise%d_scale2.0x_model.bin", model.c_str(), noise);
783     }
784 #endif
785 
786     path_t paramfullpath = sanitize_filepath(parampath);
787     path_t modelfullpath = sanitize_filepath(modelpath);
788 
789 #if _WIN32
790     CoInitializeEx(NULL, COINIT_MULTITHREADED);
791 #endif
792 
793     ncnn::create_gpu_instance();
794 
795     if (gpuid.empty())
796     {
797         gpuid.push_back(ncnn::get_default_gpu_index());
798     }
799 
800     const int use_gpu_count = (int)gpuid.size();
801 
802     if (jobs_proc.empty())
803     {
804         jobs_proc.resize(use_gpu_count, 2);
805     }
806 
807     if (tilesize.empty())
808     {
809         tilesize.resize(use_gpu_count, 0);
810     }
811 
812     int cpu_count = std::max(1, ncnn::get_cpu_count());
813     jobs_load = std::min(jobs_load, cpu_count);
814     jobs_save = std::min(jobs_save, cpu_count);
815 
816     int gpu_count = ncnn::get_gpu_count();
817     for (int i=0; i<use_gpu_count; i++)
818     {
819         if (gpuid[i] < -1 || gpuid[i] >= gpu_count)
820         {
821             fprintf(stderr, "invalid gpu device\n");
822 
823             ncnn::destroy_gpu_instance();
824             return -1;
825         }
826     }
827 
828     int total_jobs_proc = 0;
829     for (int i=0; i<use_gpu_count; i++)
830     {
831         if (gpuid[i] == -1)
832         {
833             jobs_proc[i] = std::min(jobs_proc[i], cpu_count);
834             total_jobs_proc += 1;
835         }
836         else
837         {
838             total_jobs_proc += jobs_proc[i];
839         }
840     }
841 
842     for (int i=0; i<use_gpu_count; i++)
843     {
844         if (tilesize[i] != 0)
845             continue;
846 
847         if (gpuid[i] == -1)
848         {
849             // cpu only
850             tilesize[i] = 4000;
851             continue;
852         }
853 
854         uint32_t heap_budget = ncnn::get_gpu_device(gpuid[i])->get_heap_budget();
855 
856         // more fine-grained tilesize policy here
857         if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-cunet")) != path_t::npos)
858         {
859             if (heap_budget > 2600)
860                 tilesize[i] = 400;
861             else if (heap_budget > 740)
862                 tilesize[i] = 200;
863             else if (heap_budget > 250)
864                 tilesize[i] = 100;
865             else
866                 tilesize[i] = 32;
867         }
868         else if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_anime_style_art_rgb")) != path_t::npos
869             || model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_photo")) != path_t::npos)
870         {
871             if (heap_budget > 1900)
872                 tilesize[i] = 400;
873             else if (heap_budget > 550)
874                 tilesize[i] = 200;
875             else if (heap_budget > 190)
876                 tilesize[i] = 100;
877             else
878                 tilesize[i] = 32;
879         }
880     }
881 
882     {
883         std::vector<Waifu2x*> waifu2x(use_gpu_count);
884 
885         for (int i=0; i<use_gpu_count; i++)
886         {
887             int num_threads = gpuid[i] == -1 ? jobs_proc[i] : 1;
888 
889             waifu2x[i] = new Waifu2x(gpuid[i], tta_mode, num_threads);
890 
891             waifu2x[i]->load(paramfullpath, modelfullpath);
892 
893             waifu2x[i]->noise = noise;
894             waifu2x[i]->scale = (scale >= 2) ? 2 : scale;
895             waifu2x[i]->tilesize = tilesize[i];
896             waifu2x[i]->prepadding = prepadding;
897         }
898 
899         // main routine
900         {
901             // load image
902             LoadThreadParams ltp;
903             ltp.scale = scale;
904             ltp.jobs_load = jobs_load;
905             ltp.input_files = input_files;
906             ltp.output_files = output_files;
907 
908             ncnn::Thread load_thread(load, (void*)&ltp);
909 
910             // waifu2x proc
911             std::vector<ProcThreadParams> ptp(use_gpu_count);
912             for (int i=0; i<use_gpu_count; i++)
913             {
914                 ptp[i].waifu2x = waifu2x[i];
915             }
916 
917             std::vector<ncnn::Thread*> proc_threads(total_jobs_proc);
918             {
919                 int total_jobs_proc_id = 0;
920                 for (int i=0; i<use_gpu_count; i++)
921                 {
922                     if (gpuid[i] == -1)
923                     {
924                         proc_threads[total_jobs_proc_id++] = new ncnn::Thread(proc, (void*)&ptp[i]);
925                     }
926                     else
927                     {
928                         for (int j=0; j<jobs_proc[i]; j++)
929                         {
930                             proc_threads[total_jobs_proc_id++] = new ncnn::Thread(proc, (void*)&ptp[i]);
931                         }
932                     }
933                 }
934             }
935 
936             // save image
937             SaveThreadParams stp;
938             stp.verbose = verbose;
939 
940             std::vector<ncnn::Thread*> save_threads(jobs_save);
941             for (int i=0; i<jobs_save; i++)
942             {
943                 save_threads[i] = new ncnn::Thread(save, (void*)&stp);
944             }
945 
946             // end
947             load_thread.join();
948 
949             Task end;
950             end.id = -233;
951 
952             for (int i=0; i<total_jobs_proc; i++)
953             {
954                 toproc.put(end);
955             }
956 
957             for (int i=0; i<total_jobs_proc; i++)
958             {
959                 proc_threads[i]->join();
960                 delete proc_threads[i];
961             }
962 
963             for (int i=0; i<jobs_save; i++)
964             {
965                 tosave.put(end);
966             }
967 
968             for (int i=0; i<jobs_save; i++)
969             {
970                 save_threads[i]->join();
971                 delete save_threads[i];
972             }
973         }
974 
975         for (int i=0; i<use_gpu_count; i++)
976         {
977             delete waifu2x[i];
978         }
979         waifu2x.clear();
980     }
981 
982     ncnn::destroy_gpu_instance();
983 
984     return 0;
985 }
986