1 // waifu2x implemented with ncnn library
2
3 #include <stdio.h>
4 #include <algorithm>
5 #include <queue>
6 #include <vector>
7 #include <clocale>
8
9 #if _WIN32
10 // image decoder and encoder with wic
11 #include "wic_image.h"
12 #else // _WIN32
13 // image decoder and encoder with stb
14 #define STB_IMAGE_IMPLEMENTATION
15 #define STBI_NO_PSD
16 #define STBI_NO_TGA
17 #define STBI_NO_GIF
18 #define STBI_NO_HDR
19 #define STBI_NO_PIC
20 #define STBI_NO_STDIO
21 #include "stb_image.h"
22 #define STB_IMAGE_WRITE_IMPLEMENTATION
23 #include "stb_image_write.h"
24 #endif // _WIN32
25 #include "webp_image.h"
26
27 #if _WIN32
28 #include <wchar.h>
// Minimal wide-character stand-in for POSIX getopt(), used on Windows.
// Only "-x" and "-x value" forms are understood: the option letter is taken
// from the second character of the argument and any following characters in
// that argument are ignored.
static wchar_t* optarg = NULL; // value of the most recently parsed option, or NULL
static int optind = 1;         // index of the next argv element to examine

// Parses the next option from argv.
//   argc, argv  the command line as received by wmain()
//   optstring   accepted option letters; a letter followed by ':' takes a value
// Returns the option letter, L'?' for an unknown option or a missing value,
// and (wchar_t)-1 once argv is exhausted or a non-option argument is reached.
static wchar_t getopt(int argc, wchar_t* const argv[], const wchar_t* optstring)
{
    // never leave a value from a previous call visible to the caller
    optarg = NULL;

    if (optind >= argc || argv[optind][0] != L'-')
        return -1;

    const wchar_t opt = argv[optind][1];

    // A bare "-" yields L'\0', which wcschr() would match against the string
    // terminator (making p[1] read past the end of optstring), and L':' is the
    // value marker inside optstring rather than an option letter.
    if (opt == L'\0' || opt == L':')
        return L'?';

    const wchar_t* p = wcschr(optstring, opt);
    if (p == NULL)
        return L'?';

    if (p[1] == L':')
    {
        // the option takes a value: it is the next argv element
        optind++;
        if (optind >= argc)
            return L'?';

        optarg = argv[optind];
    }

    optind++;

    return opt;
}
56
parse_optarg_int_array(const wchar_t * optarg)57 static std::vector<int> parse_optarg_int_array(const wchar_t* optarg)
58 {
59 std::vector<int> array;
60 array.push_back(_wtoi(optarg));
61
62 const wchar_t* p = wcschr(optarg, L',');
63 while (p)
64 {
65 p++;
66 array.push_back(_wtoi(p));
67 p = wcschr(p, L',');
68 }
69
70 return array;
71 }
72 #else // _WIN32
73 #include <unistd.h> // getopt()
74
// Splits a comma-separated option value such as "0,1,2" into integers.
// Every field, including an empty one, is converted with atoi(), so the
// result always holds one element more than the number of commas.
static std::vector<int> parse_optarg_int_array(const char* optarg)
{
    std::vector<int> values;

    const char* field = optarg;
    do
    {
        values.push_back(atoi(field));

        // step to the character after the next comma, if there is one
        const char* comma = strchr(field, ',');
        field = comma ? comma + 1 : NULL;
    } while (field);

    return values;
}
90 #endif // _WIN32
91
92 // ncnn
93 #include "cpu.h"
94 #include "gpu.h"
95 #include "platform.h"
96
97 #include "waifu2x.h"
98
99 #include "filesystem_utils.h"
100
// Writes the command-line help text to stdout, one entry per option.
static void print_usage()
{
    static const char* const usage_lines[] = {
        "Usage: waifu2x-ncnn-vulkan -i infile -o outfile [options]...\n\n",
        " -h show this help\n",
        " -v verbose output\n",
        " -i input-path input image path (jpg/png/webp) or directory\n",
        " -o output-path output image path (jpg/png/webp) or directory\n",
        " -n noise-level denoise level (-1/0/1/2/3, default=0)\n",
        " -s scale upscale ratio (1/2/4/8/16/32, default=2)\n",
        " -t tile-size tile size (>=32/0=auto, default=0) can be 0,0,0 for multi-gpu\n",
        " -m model-path waifu2x model path (default=models-cunet)\n",
        " -g gpu-id gpu device to use (-1=cpu, default=auto) can be 0,1,2 for multi-gpu\n",
        " -j load:proc:save thread count for load/proc/save (default=1:2:2) can be 1:2,2,2:2 for multi-gpu\n",
        " -x enable tta mode\n",
        " -f format output image format (jpg/png/webp, default=ext/png)\n",
    };

    const int line_count = (int)(sizeof(usage_lines) / sizeof(usage_lines[0]));
    for (int i = 0; i < line_count; i++)
    {
        // none of the lines contains a conversion specifier, so they are
        // written verbatim
        fputs(usage_lines[i], stdout);
    }
}
117
118 class Task
119 {
120 public:
121 int id;
122 int webp;
123 int scale;
124
125 path_t inpath;
126 path_t outpath;
127
128 ncnn::Mat inimage;
129 ncnn::Mat outimage;
130 };
131
132 class TaskQueue
133 {
134 public:
TaskQueue()135 TaskQueue()
136 {
137 }
138
put(const Task & v)139 void put(const Task& v)
140 {
141 lock.lock();
142
143 while (tasks.size() >= 8) // FIXME hardcode queue length
144 {
145 condition.wait(lock);
146 }
147
148 tasks.push(v);
149
150 lock.unlock();
151
152 condition.signal();
153 }
154
get(Task & v)155 void get(Task& v)
156 {
157 lock.lock();
158
159 while (tasks.size() == 0)
160 {
161 condition.wait(lock);
162 }
163
164 v = tasks.front();
165 tasks.pop();
166
167 lock.unlock();
168
169 condition.signal();
170 }
171
172 private:
173 ncnn::Mutex lock;
174 ncnn::ConditionVariable condition;
175 std::queue<Task> tasks;
176 };
177
178 TaskQueue toproc;
179 TaskQueue tosave;
180
181 class LoadThreadParams
182 {
183 public:
184 int scale;
185 int jobs_load;
186
187 // session data
188 std::vector<path_t> input_files;
189 std::vector<path_t> output_files;
190 };
191
load(void * args)192 void* load(void* args)
193 {
194 const LoadThreadParams* ltp = (const LoadThreadParams*)args;
195 const int count = ltp->input_files.size();
196 const int scale = ltp->scale;
197
198 #pragma omp parallel for schedule(static,1) num_threads(ltp->jobs_load)
199 for (int i=0; i<count; i++)
200 {
201 const path_t& imagepath = ltp->input_files[i];
202
203 int webp = 0;
204
205 unsigned char* pixeldata = 0;
206 int w;
207 int h;
208 int c;
209
210 #if _WIN32
211 FILE* fp = _wfopen(imagepath.c_str(), L"rb");
212 #else
213 FILE* fp = fopen(imagepath.c_str(), "rb");
214 #endif
215 if (fp)
216 {
217 // read whole file
218 unsigned char* filedata = 0;
219 int length = 0;
220 {
221 fseek(fp, 0, SEEK_END);
222 length = ftell(fp);
223 rewind(fp);
224 filedata = (unsigned char*)malloc(length);
225 if (filedata)
226 {
227 fread(filedata, 1, length, fp);
228 }
229 fclose(fp);
230 }
231
232 if (filedata)
233 {
234 pixeldata = webp_load(filedata, length, &w, &h, &c);
235 if (pixeldata)
236 {
237 webp = 1;
238 }
239 else
240 {
241 // not webp, try jpg png etc.
242 #if _WIN32
243 pixeldata = wic_decode_image(imagepath.c_str(), &w, &h, &c);
244 #else // _WIN32
245 pixeldata = stbi_load_from_memory(filedata, length, &w, &h, &c, 0);
246 if (pixeldata)
247 {
248 // stb_image auto channel
249 if (c == 1)
250 {
251 // grayscale -> rgb
252 stbi_image_free(pixeldata);
253 pixeldata = stbi_load_from_memory(filedata, length, &w, &h, &c, 3);
254 c = 3;
255 }
256 else if (c == 2)
257 {
258 // grayscale + alpha -> rgba
259 stbi_image_free(pixeldata);
260 pixeldata = stbi_load_from_memory(filedata, length, &w, &h, &c, 4);
261 c = 4;
262 }
263 }
264 #endif // _WIN32
265 }
266
267 free(filedata);
268 }
269 }
270 if (pixeldata)
271 {
272 Task v;
273 v.id = i;
274 v.webp = webp;
275 v.scale = scale;
276 v.inpath = imagepath;
277 v.outpath = ltp->output_files[i];
278
279 v.inimage = ncnn::Mat(w, h, (void*)pixeldata, (size_t)c, c);
280
281 path_t ext = get_file_extension(v.outpath);
282 if (c == 4 && (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG")))
283 {
284 path_t output_filename2 = ltp->output_files[i] + PATHSTR(".png");
285 v.outpath = output_filename2;
286 #if _WIN32
287 fwprintf(stderr, L"image %ls has alpha channel ! %ls will output %ls\n", imagepath.c_str(), imagepath.c_str(), output_filename2.c_str());
288 #else // _WIN32
289 fprintf(stderr, "image %s has alpha channel ! %s will output %s\n", imagepath.c_str(), imagepath.c_str(), output_filename2.c_str());
290 #endif // _WIN32
291 }
292
293 toproc.put(v);
294 }
295 else
296 {
297 #if _WIN32
298 fwprintf(stderr, L"decode image %ls failed\n", imagepath.c_str());
299 #else // _WIN32
300 fprintf(stderr, "decode image %s failed\n", imagepath.c_str());
301 #endif // _WIN32
302 }
303 }
304
305 return 0;
306 }
307
308 class ProcThreadParams
309 {
310 public:
311 const Waifu2x* waifu2x;
312 };
313
proc(void * args)314 void* proc(void* args)
315 {
316 const ProcThreadParams* ptp = (const ProcThreadParams*)args;
317 const Waifu2x* waifu2x = ptp->waifu2x;
318
319 for (;;)
320 {
321 Task v;
322
323 toproc.get(v);
324
325 if (v.id == -233)
326 break;
327
328 const int scale = v.scale;
329 if (scale == 1)
330 {
331 v.outimage = ncnn::Mat(v.inimage.w, v.inimage.h, (size_t)v.inimage.elemsize, (int)v.inimage.elemsize);
332 waifu2x->process(v.inimage, v.outimage);
333
334 tosave.put(v);
335 continue;
336 }
337
338 int scale_run_count = 0;
339 if (scale == 2)
340 {
341 scale_run_count = 1;
342 }
343 if (scale == 4)
344 {
345 scale_run_count = 2;
346 }
347 if (scale == 8)
348 {
349 scale_run_count = 3;
350 }
351 if (scale == 16)
352 {
353 scale_run_count = 4;
354 }
355 if (scale == 32)
356 {
357 scale_run_count = 5;
358 }
359
360 v.outimage = ncnn::Mat(v.inimage.w * 2, v.inimage.h * 2, (size_t)v.inimage.elemsize, (int)v.inimage.elemsize);
361 waifu2x->process(v.inimage, v.outimage);
362
363 for (int i = 1; i < scale_run_count; i++)
364 {
365 ncnn::Mat tmp = v.outimage;
366 v.outimage = ncnn::Mat(tmp.w * 2, tmp.h * 2, (size_t)v.inimage.elemsize, (int)v.inimage.elemsize);
367 waifu2x->process(tmp, v.outimage);
368 }
369
370 tosave.put(v);
371 }
372
373 return 0;
374 }
375
// Arguments handed to each save thread.
struct SaveThreadParams
{
    int verbose; // non-zero makes the thread report every image it has written
};
381
save(void * args)382 void* save(void* args)
383 {
384 const SaveThreadParams* stp = (const SaveThreadParams*)args;
385 const int verbose = stp->verbose;
386
387 for (;;)
388 {
389 Task v;
390
391 tosave.get(v);
392
393 if (v.id == -233)
394 break;
395
396 // free input pixel data
397 {
398 unsigned char* pixeldata = (unsigned char*)v.inimage.data;
399 if (v.webp == 1)
400 {
401 free(pixeldata);
402 }
403 else
404 {
405 #if _WIN32
406 free(pixeldata);
407 #else
408 stbi_image_free(pixeldata);
409 #endif
410 }
411 }
412
413 int success = 0;
414
415 path_t ext = get_file_extension(v.outpath);
416
417 if (ext == PATHSTR("webp") || ext == PATHSTR("WEBP"))
418 {
419 success = webp_save(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, (const unsigned char*)v.outimage.data);
420 }
421 else if (ext == PATHSTR("png") || ext == PATHSTR("PNG"))
422 {
423 #if _WIN32
424 success = wic_encode_image(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data);
425 #else
426 success = stbi_write_png(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data, 0);
427 #endif
428 }
429 else if (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG"))
430 {
431 #if _WIN32
432 success = wic_encode_jpeg_image(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data);
433 #else
434 success = stbi_write_jpg(v.outpath.c_str(), v.outimage.w, v.outimage.h, v.outimage.elempack, v.outimage.data, 100);
435 #endif
436 }
437 if (success)
438 {
439 if (verbose)
440 {
441 #if _WIN32
442 fwprintf(stdout, L"%ls -> %ls done\n", v.inpath.c_str(), v.outpath.c_str());
443 #else
444 fprintf(stdout, "%s -> %s done\n", v.inpath.c_str(), v.outpath.c_str());
445 #endif
446 }
447 }
448 else
449 {
450 #if _WIN32
451 fwprintf(stderr, L"encode image %ls failed\n", v.outpath.c_str());
452 #else
453 fprintf(stderr, "encode image %s failed\n", v.outpath.c_str());
454 #endif
455 }
456 }
457
458 return 0;
459 }
460
461
462 #if _WIN32
wmain(int argc,wchar_t ** argv)463 int wmain(int argc, wchar_t** argv)
464 #else
465 int main(int argc, char** argv)
466 #endif
467 {
468 path_t inputpath;
469 path_t outputpath;
470 int noise = 0;
471 int scale = 2;
472 std::vector<int> tilesize;
473 path_t model = PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-cunet");
474 std::vector<int> gpuid;
475 int jobs_load = 1;
476 std::vector<int> jobs_proc;
477 int jobs_save = 2;
478 int verbose = 0;
479 int tta_mode = 0;
480 path_t format = PATHSTR("png");
481
482 #if _WIN32
483 setlocale(LC_ALL, "");
484 wchar_t opt;
485 while ((opt = getopt(argc, argv, L"i:o:n:s:t:m:g:j:f:vxh")) != (wchar_t)-1)
486 {
487 switch (opt)
488 {
489 case L'i':
490 inputpath = optarg;
491 break;
492 case L'o':
493 outputpath = optarg;
494 break;
495 case L'n':
496 noise = _wtoi(optarg);
497 break;
498 case L's':
499 scale = _wtoi(optarg);
500 break;
501 case L't':
502 tilesize = parse_optarg_int_array(optarg);
503 break;
504 case L'm':
505 model = optarg;
506 break;
507 case L'g':
508 gpuid = parse_optarg_int_array(optarg);
509 break;
510 case L'j':
511 swscanf(optarg, L"%d:%*[^:]:%d", &jobs_load, &jobs_save);
512 jobs_proc = parse_optarg_int_array(wcschr(optarg, L':') + 1);
513 break;
514 case L'f':
515 format = optarg;
516 break;
517 case L'v':
518 verbose = 1;
519 break;
520 case L'x':
521 tta_mode = 1;
522 break;
523 case L'h':
524 default:
525 print_usage();
526 return -1;
527 }
528 }
529 #else // _WIN32
530 int opt;
531 while ((opt = getopt(argc, argv, "i:o:n:s:t:m:g:j:f:vxh")) != -1)
532 {
533 switch (opt)
534 {
535 case 'i':
536 inputpath = optarg;
537 break;
538 case 'o':
539 outputpath = optarg;
540 break;
541 case 'n':
542 noise = atoi(optarg);
543 break;
544 case 's':
545 scale = atoi(optarg);
546 break;
547 case 't':
548 tilesize = parse_optarg_int_array(optarg);
549 break;
550 case 'm':
551 model = optarg;
552 break;
553 case 'g':
554 gpuid = parse_optarg_int_array(optarg);
555 break;
556 case 'j':
557 sscanf(optarg, "%d:%*[^:]:%d", &jobs_load, &jobs_save);
558 jobs_proc = parse_optarg_int_array(strchr(optarg, ':') + 1);
559 break;
560 case 'f':
561 format = optarg;
562 break;
563 case 'v':
564 verbose = 1;
565 break;
566 case 'x':
567 tta_mode = 1;
568 break;
569 case 'h':
570 default:
571 print_usage();
572 return -1;
573 }
574 }
575 #endif // _WIN32
576
577 if (inputpath.empty() || outputpath.empty())
578 {
579 print_usage();
580 return -1;
581 }
582
583 if (noise < -1 || noise > 3)
584 {
585 fprintf(stderr, "invalid noise argument\n");
586 return -1;
587 }
588
589 if (!(scale == 1 || scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32))
590 {
591 fprintf(stderr, "invalid scale argument\n");
592 return -1;
593 }
594
595 if (tilesize.size() != (gpuid.empty() ? 1 : gpuid.size()) && !tilesize.empty())
596 {
597 fprintf(stderr, "invalid tilesize argument\n");
598 return -1;
599 }
600
601 for (int i=0; i<(int)tilesize.size(); i++)
602 {
603 if (tilesize[i] != 0 && tilesize[i] < 32)
604 {
605 fprintf(stderr, "invalid tilesize argument\n");
606 return -1;
607 }
608 }
609
610 if (jobs_load < 1 || jobs_save < 1)
611 {
612 fprintf(stderr, "invalid thread count argument\n");
613 return -1;
614 }
615
616 if (jobs_proc.size() != (gpuid.empty() ? 1 : gpuid.size()) && !jobs_proc.empty())
617 {
618 fprintf(stderr, "invalid jobs_proc thread count argument\n");
619 return -1;
620 }
621
622 for (int i=0; i<(int)jobs_proc.size(); i++)
623 {
624 if (jobs_proc[i] < 1)
625 {
626 fprintf(stderr, "invalid jobs_proc thread count argument\n");
627 return -1;
628 }
629 }
630
631 if (!path_is_directory(outputpath))
632 {
633 // guess format from outputpath no matter what format argument specified
634 path_t ext = get_file_extension(outputpath);
635
636 if (ext == PATHSTR("png") || ext == PATHSTR("PNG"))
637 {
638 format = PATHSTR("png");
639 }
640 else if (ext == PATHSTR("webp") || ext == PATHSTR("WEBP"))
641 {
642 format = PATHSTR("webp");
643 }
644 else if (ext == PATHSTR("jpg") || ext == PATHSTR("JPG") || ext == PATHSTR("jpeg") || ext == PATHSTR("JPEG"))
645 {
646 format = PATHSTR("jpg");
647 }
648 else
649 {
650 fprintf(stderr, "invalid outputpath extension type\n");
651 return -1;
652 }
653 }
654
655 if (format != PATHSTR("png") && format != PATHSTR("webp") && format != PATHSTR("jpg"))
656 {
657 fprintf(stderr, "invalid format argument\n");
658 return -1;
659 }
660
661 // collect input and output filepath
662 std::vector<path_t> input_files;
663 std::vector<path_t> output_files;
664 {
665 if (path_is_directory(inputpath) && path_is_directory(outputpath))
666 {
667 std::vector<path_t> filenames;
668 int lr = list_directory(inputpath, filenames);
669 if (lr != 0)
670 return -1;
671
672 const int count = filenames.size();
673 input_files.resize(count);
674 output_files.resize(count);
675
676 path_t last_filename;
677 path_t last_filename_noext;
678 for (int i=0; i<count; i++)
679 {
680 path_t filename = filenames[i];
681 path_t filename_noext = get_file_name_without_extension(filename);
682 path_t output_filename = filename_noext + PATHSTR('.') + format;
683
684 // filename list is sorted, check if output image path conflicts
685 if (filename_noext == last_filename_noext)
686 {
687 path_t output_filename2 = filename + PATHSTR('.') + format;
688 #if _WIN32
689 fwprintf(stderr, L"both %ls and %ls output %ls ! %ls will output %ls\n", filename.c_str(), last_filename.c_str(), output_filename.c_str(), filename.c_str(), output_filename2.c_str());
690 #else
691 fprintf(stderr, "both %s and %s output %s ! %s will output %s\n", filename.c_str(), last_filename.c_str(), output_filename.c_str(), filename.c_str(), output_filename2.c_str());
692 #endif
693 output_filename = output_filename2;
694 }
695 else
696 {
697 last_filename = filename;
698 last_filename_noext = filename_noext;
699 }
700
701 input_files[i] = inputpath + PATHSTR('/') + filename;
702 output_files[i] = outputpath + PATHSTR('/') + output_filename;
703 }
704 }
705 else if (!path_is_directory(inputpath) && !path_is_directory(outputpath))
706 {
707 input_files.push_back(inputpath);
708 output_files.push_back(outputpath);
709 }
710 else
711 {
712 fprintf(stderr, "inputpath and outputpath must be either file or directory at the same time\n");
713 return -1;
714 }
715 }
716
717 int prepadding = 0;
718
719 if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-cunet")) != path_t::npos)
720 {
721 if (noise == -1)
722 {
723 prepadding = 18;
724 }
725 else if (scale == 1)
726 {
727 prepadding = 28;
728 }
729 else if (scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32)
730 {
731 prepadding = 18;
732 }
733 }
734 else if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_anime_style_art_rgb")) != path_t::npos)
735 {
736 prepadding = 7;
737 }
738 else if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_photo")) != path_t::npos)
739 {
740 prepadding = 7;
741 }
742 else
743 {
744 fprintf(stderr, "unknown model dir type\n");
745 return -1;
746 }
747
748 #if _WIN32
749 wchar_t parampath[256];
750 wchar_t modelpath[256];
751 if (noise == -1)
752 {
753 swprintf(parampath, 256, L"%s/scale2.0x_model.param", model.c_str());
754 swprintf(modelpath, 256, L"%s/scale2.0x_model.bin", model.c_str());
755 }
756 else if (scale == 1)
757 {
758 swprintf(parampath, 256, L"%s/noise%d_model.param", model.c_str(), noise);
759 swprintf(modelpath, 256, L"%s/noise%d_model.bin", model.c_str(), noise);
760 }
761 else if (scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32)
762 {
763 swprintf(parampath, 256, L"%s/noise%d_scale2.0x_model.param", model.c_str(), noise);
764 swprintf(modelpath, 256, L"%s/noise%d_scale2.0x_model.bin", model.c_str(), noise);
765 }
766 #else
767 char parampath[256];
768 char modelpath[256];
769 if (noise == -1)
770 {
771 sprintf(parampath, "%s/scale2.0x_model.param", model.c_str());
772 sprintf(modelpath, "%s/scale2.0x_model.bin", model.c_str());
773 }
774 else if (scale == 1)
775 {
776 sprintf(parampath, "%s/noise%d_model.param", model.c_str(), noise);
777 sprintf(modelpath, "%s/noise%d_model.bin", model.c_str(), noise);
778 }
779 else if (scale == 2 || scale == 4 || scale == 8 || scale == 16 || scale == 32)
780 {
781 sprintf(parampath, "%s/noise%d_scale2.0x_model.param", model.c_str(), noise);
782 sprintf(modelpath, "%s/noise%d_scale2.0x_model.bin", model.c_str(), noise);
783 }
784 #endif
785
786 path_t paramfullpath = sanitize_filepath(parampath);
787 path_t modelfullpath = sanitize_filepath(modelpath);
788
789 #if _WIN32
790 CoInitializeEx(NULL, COINIT_MULTITHREADED);
791 #endif
792
793 ncnn::create_gpu_instance();
794
795 if (gpuid.empty())
796 {
797 gpuid.push_back(ncnn::get_default_gpu_index());
798 }
799
800 const int use_gpu_count = (int)gpuid.size();
801
802 if (jobs_proc.empty())
803 {
804 jobs_proc.resize(use_gpu_count, 2);
805 }
806
807 if (tilesize.empty())
808 {
809 tilesize.resize(use_gpu_count, 0);
810 }
811
812 int cpu_count = std::max(1, ncnn::get_cpu_count());
813 jobs_load = std::min(jobs_load, cpu_count);
814 jobs_save = std::min(jobs_save, cpu_count);
815
816 int gpu_count = ncnn::get_gpu_count();
817 for (int i=0; i<use_gpu_count; i++)
818 {
819 if (gpuid[i] < -1 || gpuid[i] >= gpu_count)
820 {
821 fprintf(stderr, "invalid gpu device\n");
822
823 ncnn::destroy_gpu_instance();
824 return -1;
825 }
826 }
827
828 int total_jobs_proc = 0;
829 for (int i=0; i<use_gpu_count; i++)
830 {
831 if (gpuid[i] == -1)
832 {
833 jobs_proc[i] = std::min(jobs_proc[i], cpu_count);
834 total_jobs_proc += 1;
835 }
836 else
837 {
838 total_jobs_proc += jobs_proc[i];
839 }
840 }
841
842 for (int i=0; i<use_gpu_count; i++)
843 {
844 if (tilesize[i] != 0)
845 continue;
846
847 if (gpuid[i] == -1)
848 {
849 // cpu only
850 tilesize[i] = 4000;
851 continue;
852 }
853
854 uint32_t heap_budget = ncnn::get_gpu_device(gpuid[i])->get_heap_budget();
855
856 // more fine-grained tilesize policy here
857 if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-cunet")) != path_t::npos)
858 {
859 if (heap_budget > 2600)
860 tilesize[i] = 400;
861 else if (heap_budget > 740)
862 tilesize[i] = 200;
863 else if (heap_budget > 250)
864 tilesize[i] = 100;
865 else
866 tilesize[i] = 32;
867 }
868 else if (model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_anime_style_art_rgb")) != path_t::npos
869 || model.find(PATHSTR("/usr/local/share/waifu2x-ncnn-vulkan/models-upconv_7_photo")) != path_t::npos)
870 {
871 if (heap_budget > 1900)
872 tilesize[i] = 400;
873 else if (heap_budget > 550)
874 tilesize[i] = 200;
875 else if (heap_budget > 190)
876 tilesize[i] = 100;
877 else
878 tilesize[i] = 32;
879 }
880 }
881
882 {
883 std::vector<Waifu2x*> waifu2x(use_gpu_count);
884
885 for (int i=0; i<use_gpu_count; i++)
886 {
887 int num_threads = gpuid[i] == -1 ? jobs_proc[i] : 1;
888
889 waifu2x[i] = new Waifu2x(gpuid[i], tta_mode, num_threads);
890
891 waifu2x[i]->load(paramfullpath, modelfullpath);
892
893 waifu2x[i]->noise = noise;
894 waifu2x[i]->scale = (scale >= 2) ? 2 : scale;
895 waifu2x[i]->tilesize = tilesize[i];
896 waifu2x[i]->prepadding = prepadding;
897 }
898
899 // main routine
900 {
901 // load image
902 LoadThreadParams ltp;
903 ltp.scale = scale;
904 ltp.jobs_load = jobs_load;
905 ltp.input_files = input_files;
906 ltp.output_files = output_files;
907
908 ncnn::Thread load_thread(load, (void*)<p);
909
910 // waifu2x proc
911 std::vector<ProcThreadParams> ptp(use_gpu_count);
912 for (int i=0; i<use_gpu_count; i++)
913 {
914 ptp[i].waifu2x = waifu2x[i];
915 }
916
917 std::vector<ncnn::Thread*> proc_threads(total_jobs_proc);
918 {
919 int total_jobs_proc_id = 0;
920 for (int i=0; i<use_gpu_count; i++)
921 {
922 if (gpuid[i] == -1)
923 {
924 proc_threads[total_jobs_proc_id++] = new ncnn::Thread(proc, (void*)&ptp[i]);
925 }
926 else
927 {
928 for (int j=0; j<jobs_proc[i]; j++)
929 {
930 proc_threads[total_jobs_proc_id++] = new ncnn::Thread(proc, (void*)&ptp[i]);
931 }
932 }
933 }
934 }
935
936 // save image
937 SaveThreadParams stp;
938 stp.verbose = verbose;
939
940 std::vector<ncnn::Thread*> save_threads(jobs_save);
941 for (int i=0; i<jobs_save; i++)
942 {
943 save_threads[i] = new ncnn::Thread(save, (void*)&stp);
944 }
945
946 // end
947 load_thread.join();
948
949 Task end;
950 end.id = -233;
951
952 for (int i=0; i<total_jobs_proc; i++)
953 {
954 toproc.put(end);
955 }
956
957 for (int i=0; i<total_jobs_proc; i++)
958 {
959 proc_threads[i]->join();
960 delete proc_threads[i];
961 }
962
963 for (int i=0; i<jobs_save; i++)
964 {
965 tosave.put(end);
966 }
967
968 for (int i=0; i<jobs_save; i++)
969 {
970 save_threads[i]->join();
971 delete save_threads[i];
972 }
973 }
974
975 for (int i=0; i<use_gpu_count; i++)
976 {
977 delete waifu2x[i];
978 }
979 waifu2x.clear();
980 }
981
982 ncnn::destroy_gpu_instance();
983
984 return 0;
985 }
986