1 /*
2   MIT License
3 
4   Copyright (c) 2018-2019 HolyWu
5 
6   Permission is hereby granted, free of charge, to any person obtaining a copy
7   of this software and associated documentation files (the "Software"), to deal
8   in the Software without restriction, including without limitation the rights
9   to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10   copies of the Software, and to permit persons to whom the Software is
11   furnished to do so, subject to the following conditions:
12 
13   The above copyright notice and this permission notice shall be included in all
14   copies or substantial portions of the Software.
15 
16   THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17   IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18   FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21   OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22   SOFTWARE.
23 */
24 
#include <climits>
#include <cmath>
#include <new>
#include <string>

#include <VapourSynth.h>
#include <VSHelper.h>

#include <w2xconv.h>
32 
33 struct Waifu2xData {
34     VSNodeRef * node;
35     VSVideoInfo vi;
36     int noise, scale, block;
37     int iterTimesTwiceScaling;
38     float * srcInterleaved, * dstInterleaved;
39     W2XConv * conv;
40 };
41 
// Returns true when `i` is a positive power of two (1, 2, 4, ...).
// Non-positive values are rejected up front, which also avoids the signed
// overflow that `i - 1` would cause for INT_MIN.
static constexpr bool isPowerOf2(const int i) noexcept {
    return i > 0 && (i & (i - 1)) == 0;
}
45 
filter(const VSFrameRef * src,VSFrameRef * dst,Waifu2xData * const VS_RESTRICT d,const VSAPI * vsapi)46 static bool filter(const VSFrameRef * src, VSFrameRef * dst, Waifu2xData * const VS_RESTRICT d, const VSAPI * vsapi) noexcept {
47     const int width = vsapi->getFrameWidth(src, 0);
48     const int height = vsapi->getFrameHeight(src, 0);
49     const int srcStride = vsapi->getStride(src, 0) / sizeof(float);
50     const int dstStride = vsapi->getStride(dst, 0) / sizeof(float);
51     const float * srcpR = reinterpret_cast<const float *>(vsapi->getReadPtr(src, 0));
52     const float * srcpG = reinterpret_cast<const float *>(vsapi->getReadPtr(src, 1));
53     const float * srcpB = reinterpret_cast<const float *>(vsapi->getReadPtr(src, 2));
54     float * VS_RESTRICT dstpR = reinterpret_cast<float *>(vsapi->getWritePtr(dst, 0));
55     float * VS_RESTRICT dstpG = reinterpret_cast<float *>(vsapi->getWritePtr(dst, 1));
56     float * VS_RESTRICT dstpB = reinterpret_cast<float *>(vsapi->getWritePtr(dst, 2));
57 
58     for (int y = 0; y < height; y++) {
59         for (int x = 0; x < width; x++) {
60             const int pos = (width * y + x) * 3;
61             d->srcInterleaved[pos + 0] = srcpR[x];
62             d->srcInterleaved[pos + 1] = srcpG[x];
63             d->srcInterleaved[pos + 2] = srcpB[x];
64         }
65 
66         srcpR += srcStride;
67         srcpG += srcStride;
68         srcpB += srcStride;
69     }
70 
71     if (w2xconv_convert_rgb_f32(d->conv,
72                                 reinterpret_cast<unsigned char *>(d->dstInterleaved),
73                                 d->vi.width * 3 * sizeof(float),
74                                 reinterpret_cast<unsigned char *>(d->srcInterleaved),
75                                 width * 3 * sizeof(float),
76                                 width, height, d->noise, d->scale, d->block) < 0)
77         return false;
78 
79     for (int y = 0; y < d->vi.height; y++) {
80         for (int x = 0; x < d->vi.width; x++) {
81             const int pos = (d->vi.width * y + x) * 3;
82             dstpR[x] = d->dstInterleaved[pos + 0];
83             dstpG[x] = d->dstInterleaved[pos + 1];
84             dstpB[x] = d->dstInterleaved[pos + 2];
85         }
86 
87         dstpR += dstStride;
88         dstpG += dstStride;
89         dstpB += dstStride;
90     }
91 
92     return true;
93 }
94 
waifu2xInit(VSMap * in,VSMap * out,void ** instanceData,VSNode * node,VSCore * core,const VSAPI * vsapi)95 static void VS_CC waifu2xInit(VSMap *in, VSMap *out, void **instanceData, VSNode *node, VSCore *core, const VSAPI *vsapi) {
96     Waifu2xData * d = static_cast<Waifu2xData *>(*instanceData);
97     vsapi->setVideoInfo(&d->vi, 1, node);
98 }
99 
waifu2xGetFrame(int n,int activationReason,void ** instanceData,void ** frameData,VSFrameContext * frameCtx,VSCore * core,const VSAPI * vsapi)100 static const VSFrameRef *VS_CC waifu2xGetFrame(int n, int activationReason, void **instanceData, void **frameData, VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi) {
101     Waifu2xData * d = static_cast<Waifu2xData *>(*instanceData);
102 
103     if (activationReason == arInitial) {
104         vsapi->requestFrameFilter(n, d->node, frameCtx);
105     } else if (activationReason == arAllFramesReady) {
106         const VSFrameRef * src = vsapi->getFrameFilter(n, d->node, frameCtx);
107         VSFrameRef * dst = vsapi->newVideoFrame(d->vi.format, d->vi.width, d->vi.height, src, core);
108 
109         if (!filter(src, dst, d, vsapi)) {
110             char * error = w2xconv_strerror(&d->conv->last_error);
111             vsapi->setFilterError((std::string{ "Waifu2x-w2xc: " } + error).c_str(), frameCtx);
112             w2xconv_free(error);
113             vsapi->freeFrame(src);
114             vsapi->freeFrame(dst);
115             return nullptr;
116         }
117 
118         vsapi->freeFrame(src);
119         return dst;
120     }
121 
122     return nullptr;
123 }
124 
waifu2xFree(void * instanceData,VSCore * core,const VSAPI * vsapi)125 static void VS_CC waifu2xFree(void *instanceData, VSCore *core, const VSAPI *vsapi) {
126     Waifu2xData * d = static_cast<Waifu2xData *>(instanceData);
127 
128     vsapi->freeNode(d->node);
129 
130     delete[] d->srcInterleaved;
131     delete[] d->dstInterleaved;
132 
133     w2xconv_fini(d->conv);
134 
135     delete d;
136 }
137 
waifu2xCreate(const VSMap * in,VSMap * out,void * userData,VSCore * core,const VSAPI * vsapi)138 static void VS_CC waifu2xCreate(const VSMap *in, VSMap *out, void *userData, VSCore *core, const VSAPI *vsapi) {
139     Waifu2xData d{};
140     int err;
141 
142     d.node = vsapi->propGetNode(in, "clip", 0, nullptr);
143     d.vi = *vsapi->getVideoInfo(d.node);
144 
145     try {
146         if (!isConstantFormat(&d.vi) || d.vi.format->colorFamily != cmRGB || d.vi.format->sampleType != stFloat || d.vi.format->bitsPerSample != 32)
147             throw std::string{ "only constant RGB format and 32 bit float input supported" };
148 
149         d.noise = int64ToIntS(vsapi->propGetInt(in, "noise", 0, &err));
150 
151         d.scale = int64ToIntS(vsapi->propGetInt(in, "scale", 0, &err));
152         if (err)
153             d.scale = 2;
154 
155         d.block = int64ToIntS(vsapi->propGetInt(in, "block", 0, &err));
156         if (err)
157             d.block = 512;
158 
159         const bool photo = !!vsapi->propGetInt(in, "photo", 0, &err);
160 
161         W2XConvGPUMode gpu = static_cast<W2XConvGPUMode>(int64ToIntS(vsapi->propGetInt(in, "gpu", 0, &err)));
162         if (err)
163             gpu = W2XCONV_GPU_AUTO;
164 
165         int processor = int64ToIntS(vsapi->propGetInt(in, "processor", 0, &err));
166         if (err)
167             processor = -1;
168 
169         const bool log = !!vsapi->propGetInt(in, "log", 0, &err);
170 
171         size_t numProcessors;
172         const W2XConvProcessor * processors = w2xconv_get_processor_list(&numProcessors);
173 
174         if (d.noise < -1 || d.noise > 3)
175             throw std::string{ "noise must be -1, 0, 1, 2, or 3" };
176 
177         if (d.scale < 1 || !isPowerOf2(d.scale))
178             throw std::string{ "scale must be greater than or equal to 1 and be a power of 2" };
179 
180         if (d.block < 1)
181             throw std::string{ "block must be greater than or equal to 1" };
182 
183         if (gpu < 0 || gpu > 2)
184             throw std::string{ "gpu must be 0, 1, or 2" };
185 
186         if (processor >= static_cast<int>(numProcessors))
187             throw std::string{ "the specified processor is not available" };
188 
189         if (!!vsapi->propGetInt(in, "list_proc", 0, &err)) {
190             std::string text;
191 
192             for (size_t i = 0; i < numProcessors; i++) {
193                 const W2XConvProcessor * p = &processors[i];
194                 const char * type;
195 
196                 switch (p->type) {
197                 case W2XCONV_PROC_HOST:
198                     switch (p->sub_type) {
199                     case W2XCONV_PROC_HOST_FMA:
200                         type = "FMA";
201                         break;
202                     case W2XCONV_PROC_HOST_AVX:
203                         type = "AVX";
204                         break;
205                     case W2XCONV_PROC_HOST_SSE3:
206                         type = "SSE3";
207                         break;
208                     default:
209                         type = "OpenCV";
210                     }
211                     break;
212 
213                 case W2XCONV_PROC_CUDA:
214                     type = "CUDA";
215                     break;
216 
217                 case W2XCONV_PROC_OPENCL:
218                     type = "OpenCL";
219                     break;
220 
221                 default:
222                     type = "unknown";
223                 }
224 
225                 text += std::to_string(i) + ": " + p->dev_name + " (" + type + ")\n";
226             }
227 
228             VSMap * args = vsapi->createMap();
229             vsapi->propSetNode(args, "clip", d.node, paReplace);
230             vsapi->freeNode(d.node);
231             vsapi->propSetData(args, "text", text.c_str(), -1, paReplace);
232 
233             VSMap * ret = vsapi->invoke(vsapi->getPluginById("com.vapoursynth.text", core), "Text", args);
234             if (vsapi->getError(ret)) {
235                 vsapi->setError(out, vsapi->getError(ret));
236                 vsapi->freeMap(args);
237                 vsapi->freeMap(ret);
238                 return;
239             }
240 
241             d.node = vsapi->propGetNode(ret, "clip", 0, nullptr);
242             vsapi->freeMap(args);
243             vsapi->freeMap(ret);
244             vsapi->propSetNode(out, "clip", d.node, paReplace);
245             vsapi->freeNode(d.node);
246             return;
247         }
248 
249         if (d.noise == -1 && d.scale == 1) {
250             vsapi->propSetNode(out, "clip", d.node, paReplace);
251             vsapi->freeNode(d.node);
252             return;
253         }
254 
255         if (d.scale != 1) {
256             d.vi.width *= d.scale;
257             d.vi.height *= d.scale;
258             d.iterTimesTwiceScaling = static_cast<int>(std::log2(d.scale));
259         }
260 
261         d.srcInterleaved = new (std::nothrow) float[vsapi->getVideoInfo(d.node)->width * vsapi->getVideoInfo(d.node)->height * 3];
262         d.dstInterleaved = new (std::nothrow) float[d.vi.width * d.vi.height * 3];
263         if (!d.srcInterleaved || !d.dstInterleaved)
264             throw std::string{ "malloc failure (srcInterleaved/dstInterleaved)" };
265 
266         const int numThreads = vsapi->getCoreInfo(core)->numThreads;
267         if (processor > -1)
268             d.conv = w2xconv_init_with_processor(processor, numThreads, log);
269         else
270             d.conv = w2xconv_init(gpu, numThreads, log);
271 
272         const std::string pluginPath{ vsapi->getPluginPath(vsapi->getPluginById("com.holywu.waifu2x-w2xc", core)) };
273         std::string modelPath{ pluginPath.substr(0, pluginPath.find_last_of('/')) };
274         if (photo)
275             modelPath += "/models/photo";
276         else
277             modelPath += "/models/anime_style_art_rgb";
278 
279         if (w2xconv_load_models(d.conv, modelPath.c_str()) < 0) {
280             char * error = w2xconv_strerror(&d.conv->last_error);
281             vsapi->setError(out, (std::string{ "Waifu2x-w2xc: " } + error).c_str());
282             w2xconv_free(error);
283             vsapi->freeNode(d.node);
284             w2xconv_fini(d.conv);
285             return;
286         }
287     } catch (const std::string & error) {
288         vsapi->setError(out, ("Waifu2x-w2xc: " + error).c_str());
289         vsapi->freeNode(d.node);
290         return;
291     }
292 
293     Waifu2xData * data = new Waifu2xData{ d };
294 
295     vsapi->createFilter(in, out, "Waifu2x-w2xc", waifu2xInit, waifu2xGetFrame, waifu2xFree, fmParallelRequests, 0, data, core);
296 }
297 
298 //////////////////////////////////////////
299 // Init
300 
VapourSynthPluginInit(VSConfigPlugin configFunc,VSRegisterFunction registerFunc,VSPlugin * plugin)301 VS_EXTERNAL_API(void) VapourSynthPluginInit(VSConfigPlugin configFunc, VSRegisterFunction registerFunc, VSPlugin *plugin) {
302     configFunc("com.holywu.waifu2x-w2xc", "w2xc", "Image Super-Resolution using Deep Convolutional Neural Networks", VAPOURSYNTH_API_VERSION, 1, plugin);
303     registerFunc("Waifu2x",
304                  "clip:clip;"
305                  "noise:int:opt;"
306                  "scale:int:opt;"
307                  "block:int:opt;"
308                  "photo:int:opt;"
309                  "gpu:int:opt;"
310                  "processor:int:opt;"
311                  "list_proc:int:opt;"
312                  "log:int:opt;",
313                  waifu2xCreate, nullptr, plugin);
314 }
315