1 /*
2 MIT License
3
4 Copyright (c) 2018-2019 HolyWu
5
6 Permission is hereby granted, free of charge, to any person obtaining a copy
7 of this software and associated documentation files (the "Software"), to deal
8 in the Software without restriction, including without limitation the rights
9 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 copies of the Software, and to permit persons to whom the Software is
11 furnished to do so, subject to the following conditions:
12
13 The above copyright notice and this permission notice shall be included in all
14 copies or substantial portions of the Software.
15
16 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 SOFTWARE.
23 */
24
25 #include <cmath>
26 #include <string>
27
28 #include <VapourSynth.h>
29 #include <VSHelper.h>
30
31 #include <w2xconv.h>
32
33 struct Waifu2xData {
34 VSNodeRef * node;
35 VSVideoInfo vi;
36 int noise, scale, block;
37 int iterTimesTwiceScaling;
38 float * srcInterleaved, * dstInterleaved;
39 W2XConv * conv;
40 };
41
// Returns true when i is non-zero and clearing its lowest set bit leaves zero,
// i.e. i is 1, 2, 4, 8, ... Zero is rejected explicitly.
static bool isPowerOf2(const int i) noexcept {
    if (i == 0)
        return false;

    // i & (i - 1) drops the lowest set bit; nothing remains for a single-bit value.
    const int remainingBits = i & (i - 1);
    return remainingBits == 0;
}
45
filter(const VSFrameRef * src,VSFrameRef * dst,Waifu2xData * const VS_RESTRICT d,const VSAPI * vsapi)46 static bool filter(const VSFrameRef * src, VSFrameRef * dst, Waifu2xData * const VS_RESTRICT d, const VSAPI * vsapi) noexcept {
47 const int width = vsapi->getFrameWidth(src, 0);
48 const int height = vsapi->getFrameHeight(src, 0);
49 const int srcStride = vsapi->getStride(src, 0) / sizeof(float);
50 const int dstStride = vsapi->getStride(dst, 0) / sizeof(float);
51 const float * srcpR = reinterpret_cast<const float *>(vsapi->getReadPtr(src, 0));
52 const float * srcpG = reinterpret_cast<const float *>(vsapi->getReadPtr(src, 1));
53 const float * srcpB = reinterpret_cast<const float *>(vsapi->getReadPtr(src, 2));
54 float * VS_RESTRICT dstpR = reinterpret_cast<float *>(vsapi->getWritePtr(dst, 0));
55 float * VS_RESTRICT dstpG = reinterpret_cast<float *>(vsapi->getWritePtr(dst, 1));
56 float * VS_RESTRICT dstpB = reinterpret_cast<float *>(vsapi->getWritePtr(dst, 2));
57
58 for (int y = 0; y < height; y++) {
59 for (int x = 0; x < width; x++) {
60 const int pos = (width * y + x) * 3;
61 d->srcInterleaved[pos + 0] = srcpR[x];
62 d->srcInterleaved[pos + 1] = srcpG[x];
63 d->srcInterleaved[pos + 2] = srcpB[x];
64 }
65
66 srcpR += srcStride;
67 srcpG += srcStride;
68 srcpB += srcStride;
69 }
70
71 if (w2xconv_convert_rgb_f32(d->conv,
72 reinterpret_cast<unsigned char *>(d->dstInterleaved),
73 d->vi.width * 3 * sizeof(float),
74 reinterpret_cast<unsigned char *>(d->srcInterleaved),
75 width * 3 * sizeof(float),
76 width, height, d->noise, d->scale, d->block) < 0)
77 return false;
78
79 for (int y = 0; y < d->vi.height; y++) {
80 for (int x = 0; x < d->vi.width; x++) {
81 const int pos = (d->vi.width * y + x) * 3;
82 dstpR[x] = d->dstInterleaved[pos + 0];
83 dstpG[x] = d->dstInterleaved[pos + 1];
84 dstpB[x] = d->dstInterleaved[pos + 2];
85 }
86
87 dstpR += dstStride;
88 dstpG += dstStride;
89 dstpB += dstStride;
90 }
91
92 return true;
93 }
94
waifu2xInit(VSMap * in,VSMap * out,void ** instanceData,VSNode * node,VSCore * core,const VSAPI * vsapi)95 static void VS_CC waifu2xInit(VSMap *in, VSMap *out, void **instanceData, VSNode *node, VSCore *core, const VSAPI *vsapi) {
96 Waifu2xData * d = static_cast<Waifu2xData *>(*instanceData);
97 vsapi->setVideoInfo(&d->vi, 1, node);
98 }
99
waifu2xGetFrame(int n,int activationReason,void ** instanceData,void ** frameData,VSFrameContext * frameCtx,VSCore * core,const VSAPI * vsapi)100 static const VSFrameRef *VS_CC waifu2xGetFrame(int n, int activationReason, void **instanceData, void **frameData, VSFrameContext *frameCtx, VSCore *core, const VSAPI *vsapi) {
101 Waifu2xData * d = static_cast<Waifu2xData *>(*instanceData);
102
103 if (activationReason == arInitial) {
104 vsapi->requestFrameFilter(n, d->node, frameCtx);
105 } else if (activationReason == arAllFramesReady) {
106 const VSFrameRef * src = vsapi->getFrameFilter(n, d->node, frameCtx);
107 VSFrameRef * dst = vsapi->newVideoFrame(d->vi.format, d->vi.width, d->vi.height, src, core);
108
109 if (!filter(src, dst, d, vsapi)) {
110 char * error = w2xconv_strerror(&d->conv->last_error);
111 vsapi->setFilterError((std::string{ "Waifu2x-w2xc: " } + error).c_str(), frameCtx);
112 w2xconv_free(error);
113 vsapi->freeFrame(src);
114 vsapi->freeFrame(dst);
115 return nullptr;
116 }
117
118 vsapi->freeFrame(src);
119 return dst;
120 }
121
122 return nullptr;
123 }
124
waifu2xFree(void * instanceData,VSCore * core,const VSAPI * vsapi)125 static void VS_CC waifu2xFree(void *instanceData, VSCore *core, const VSAPI *vsapi) {
126 Waifu2xData * d = static_cast<Waifu2xData *>(instanceData);
127
128 vsapi->freeNode(d->node);
129
130 delete[] d->srcInterleaved;
131 delete[] d->dstInterleaved;
132
133 w2xconv_fini(d->conv);
134
135 delete d;
136 }
137
waifu2xCreate(const VSMap * in,VSMap * out,void * userData,VSCore * core,const VSAPI * vsapi)138 static void VS_CC waifu2xCreate(const VSMap *in, VSMap *out, void *userData, VSCore *core, const VSAPI *vsapi) {
139 Waifu2xData d{};
140 int err;
141
142 d.node = vsapi->propGetNode(in, "clip", 0, nullptr);
143 d.vi = *vsapi->getVideoInfo(d.node);
144
145 try {
146 if (!isConstantFormat(&d.vi) || d.vi.format->colorFamily != cmRGB || d.vi.format->sampleType != stFloat || d.vi.format->bitsPerSample != 32)
147 throw std::string{ "only constant RGB format and 32 bit float input supported" };
148
149 d.noise = int64ToIntS(vsapi->propGetInt(in, "noise", 0, &err));
150
151 d.scale = int64ToIntS(vsapi->propGetInt(in, "scale", 0, &err));
152 if (err)
153 d.scale = 2;
154
155 d.block = int64ToIntS(vsapi->propGetInt(in, "block", 0, &err));
156 if (err)
157 d.block = 512;
158
159 const bool photo = !!vsapi->propGetInt(in, "photo", 0, &err);
160
161 W2XConvGPUMode gpu = static_cast<W2XConvGPUMode>(int64ToIntS(vsapi->propGetInt(in, "gpu", 0, &err)));
162 if (err)
163 gpu = W2XCONV_GPU_AUTO;
164
165 int processor = int64ToIntS(vsapi->propGetInt(in, "processor", 0, &err));
166 if (err)
167 processor = -1;
168
169 const bool log = !!vsapi->propGetInt(in, "log", 0, &err);
170
171 size_t numProcessors;
172 const W2XConvProcessor * processors = w2xconv_get_processor_list(&numProcessors);
173
174 if (d.noise < -1 || d.noise > 3)
175 throw std::string{ "noise must be -1, 0, 1, 2, or 3" };
176
177 if (d.scale < 1 || !isPowerOf2(d.scale))
178 throw std::string{ "scale must be greater than or equal to 1 and be a power of 2" };
179
180 if (d.block < 1)
181 throw std::string{ "block must be greater than or equal to 1" };
182
183 if (gpu < 0 || gpu > 2)
184 throw std::string{ "gpu must be 0, 1, or 2" };
185
186 if (processor >= static_cast<int>(numProcessors))
187 throw std::string{ "the specified processor is not available" };
188
189 if (!!vsapi->propGetInt(in, "list_proc", 0, &err)) {
190 std::string text;
191
192 for (size_t i = 0; i < numProcessors; i++) {
193 const W2XConvProcessor * p = &processors[i];
194 const char * type;
195
196 switch (p->type) {
197 case W2XCONV_PROC_HOST:
198 switch (p->sub_type) {
199 case W2XCONV_PROC_HOST_FMA:
200 type = "FMA";
201 break;
202 case W2XCONV_PROC_HOST_AVX:
203 type = "AVX";
204 break;
205 case W2XCONV_PROC_HOST_SSE3:
206 type = "SSE3";
207 break;
208 default:
209 type = "OpenCV";
210 }
211 break;
212
213 case W2XCONV_PROC_CUDA:
214 type = "CUDA";
215 break;
216
217 case W2XCONV_PROC_OPENCL:
218 type = "OpenCL";
219 break;
220
221 default:
222 type = "unknown";
223 }
224
225 text += std::to_string(i) + ": " + p->dev_name + " (" + type + ")\n";
226 }
227
228 VSMap * args = vsapi->createMap();
229 vsapi->propSetNode(args, "clip", d.node, paReplace);
230 vsapi->freeNode(d.node);
231 vsapi->propSetData(args, "text", text.c_str(), -1, paReplace);
232
233 VSMap * ret = vsapi->invoke(vsapi->getPluginById("com.vapoursynth.text", core), "Text", args);
234 if (vsapi->getError(ret)) {
235 vsapi->setError(out, vsapi->getError(ret));
236 vsapi->freeMap(args);
237 vsapi->freeMap(ret);
238 return;
239 }
240
241 d.node = vsapi->propGetNode(ret, "clip", 0, nullptr);
242 vsapi->freeMap(args);
243 vsapi->freeMap(ret);
244 vsapi->propSetNode(out, "clip", d.node, paReplace);
245 vsapi->freeNode(d.node);
246 return;
247 }
248
249 if (d.noise == -1 && d.scale == 1) {
250 vsapi->propSetNode(out, "clip", d.node, paReplace);
251 vsapi->freeNode(d.node);
252 return;
253 }
254
255 if (d.scale != 1) {
256 d.vi.width *= d.scale;
257 d.vi.height *= d.scale;
258 d.iterTimesTwiceScaling = static_cast<int>(std::log2(d.scale));
259 }
260
261 d.srcInterleaved = new (std::nothrow) float[vsapi->getVideoInfo(d.node)->width * vsapi->getVideoInfo(d.node)->height * 3];
262 d.dstInterleaved = new (std::nothrow) float[d.vi.width * d.vi.height * 3];
263 if (!d.srcInterleaved || !d.dstInterleaved)
264 throw std::string{ "malloc failure (srcInterleaved/dstInterleaved)" };
265
266 const int numThreads = vsapi->getCoreInfo(core)->numThreads;
267 if (processor > -1)
268 d.conv = w2xconv_init_with_processor(processor, numThreads, log);
269 else
270 d.conv = w2xconv_init(gpu, numThreads, log);
271
272 const std::string pluginPath{ vsapi->getPluginPath(vsapi->getPluginById("com.holywu.waifu2x-w2xc", core)) };
273 std::string modelPath{ pluginPath.substr(0, pluginPath.find_last_of('/')) };
274 if (photo)
275 modelPath += "/models/photo";
276 else
277 modelPath += "/models/anime_style_art_rgb";
278
279 if (w2xconv_load_models(d.conv, modelPath.c_str()) < 0) {
280 char * error = w2xconv_strerror(&d.conv->last_error);
281 vsapi->setError(out, (std::string{ "Waifu2x-w2xc: " } + error).c_str());
282 w2xconv_free(error);
283 vsapi->freeNode(d.node);
284 w2xconv_fini(d.conv);
285 return;
286 }
287 } catch (const std::string & error) {
288 vsapi->setError(out, ("Waifu2x-w2xc: " + error).c_str());
289 vsapi->freeNode(d.node);
290 return;
291 }
292
293 Waifu2xData * data = new Waifu2xData{ d };
294
295 vsapi->createFilter(in, out, "Waifu2x-w2xc", waifu2xInit, waifu2xGetFrame, waifu2xFree, fmParallelRequests, 0, data, core);
296 }
297
298 //////////////////////////////////////////
299 // Init
300
VapourSynthPluginInit(VSConfigPlugin configFunc,VSRegisterFunction registerFunc,VSPlugin * plugin)301 VS_EXTERNAL_API(void) VapourSynthPluginInit(VSConfigPlugin configFunc, VSRegisterFunction registerFunc, VSPlugin *plugin) {
302 configFunc("com.holywu.waifu2x-w2xc", "w2xc", "Image Super-Resolution using Deep Convolutional Neural Networks", VAPOURSYNTH_API_VERSION, 1, plugin);
303 registerFunc("Waifu2x",
304 "clip:clip;"
305 "noise:int:opt;"
306 "scale:int:opt;"
307 "block:int:opt;"
308 "photo:int:opt;"
309 "gpu:int:opt;"
310 "processor:int:opt;"
311 "list_proc:int:opt;"
312 "log:int:opt;",
313 waifu2xCreate, nullptr, plugin);
314 }
315