1 /*
2  *  Copyright (c) 2017 Eugene Ingerman
3  *
4  *  This program is free software; you can redistribute it and/or modify
5  *  it under the terms of the GNU General Public License as published by
6  *  the Free Software Foundation; either version 2 of the License, or
7  *  (at your option) any later version.
8  *
9  *  This program is distributed in the hope that it will be useful,
10  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *  GNU General Public License for more details.
13  *
14  *  You should have received a copy of the GNU General Public License
15  *  along with this program; if not, write to the Free Software
16  *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
17  */
18 
19 /**
20  * Inpaint using the PatchMatch Algorithm
21  *
22  * | PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
23  * | by Connelly Barnes and Eli Shechtman and Adam Finkelstein and Dan B Goldman
24  * | ACM Transactions on Graphics (Proc. SIGGRAPH), vol.28, aug-2009
25  *
26  * Original author Xavier Philippeau
27  * Code adopted from: David Chatting https://github.com/davidchatting/PatchMatch
28  */
29 
#include <boost/multi_array.hpp>
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iostream>
#include <random>
#include <utility>
34 
35 
36 #include "kis_paint_device.h"
37 #include "kis_painter.h"
38 #include "kis_selection.h"
39 
40 #include "kis_debug.h"
41 #include "kis_paint_device_debug_utils.h"
42 //#include "kis_random_accessor_ng.h"
43 
44 #include <QList>
45 #include <kis_transform_worker.h>
46 #include <kis_filter_strategy.h>
47 #include "KoColor.h"
48 #include "KoColorSpace.h"
49 #include "KoChannelInfo.h"
50 #include "KoMixColorsOp.h"
51 #include "KoColorModelStandardIds.h"
52 #include "KoColorSpaceRegistry.h"
53 #include "KoColorSpaceTraits.h"
54 
55 const int MAX_DIST = 65535;
56 const quint8 MASK_SET = 255;
57 const quint8 MASK_CLEAR = 0;
58 
59 class MaskedImage; //forward decl for the forward decl below
60 template <typename T> float distance_impl(const MaskedImage& my, int x, int y, const MaskedImage& other, int xo, int yo);
61 
62 
63 class ImageView
64 {
65 
66 protected:
67     quint8* m_data;
68     int m_imageWidth;
69     int m_imageHeight;
70     int m_pixelSize;
71 
72 public:
Init(quint8 * _data,int _imageWidth,int _imageHeight,int _pixelSize)73     void Init(quint8* _data, int _imageWidth, int _imageHeight, int _pixelSize)
74     {
75         m_data = _data;
76         m_imageWidth = _imageWidth;
77         m_imageHeight = _imageHeight;
78         m_pixelSize = _pixelSize;
79     }
80 
ImageView()81     ImageView() : m_data(nullptr)
82 
83     {
84         m_imageHeight =  m_imageWidth = m_pixelSize = 0;
85     }
86 
87 
ImageView(quint8 * _data,int _imageWidth,int _imageHeight,int _pixelSize)88     ImageView(quint8* _data, int _imageWidth, int _imageHeight, int _pixelSize)
89     {
90         Init(_data, _imageWidth, _imageHeight, _pixelSize);
91     }
92 
operator ()(int x,int y) const93     quint8* operator()(int x, int y) const
94     {
95         Q_ASSERT(m_data);
96         Q_ASSERT((x >= 0) && (x < m_imageWidth) && (y >= 0) && (y < m_imageHeight));
97         return (m_data + x * m_pixelSize + y * m_imageWidth * m_pixelSize);
98     }
99 
operator =(const ImageView & other)100     ImageView& operator=(const ImageView& other)
101     {
102         if (this != &other) {
103             if (other.num_bytes() != num_bytes()) {
104                 delete[] m_data;
105                 m_data = nullptr; //to preserve invariance if next line throws exception
106                 m_data = new quint8[other.num_bytes()];
107 
108             }
109             std::copy(other.data(), other.data() + other.num_bytes(), m_data);
110             m_imageHeight = other.m_imageHeight;
111             m_imageWidth = other.m_imageWidth;
112             m_pixelSize = other.m_pixelSize;
113         }
114         return *this;
115     }
116 
117     //move assignment operator
operator =(ImageView && other)118     ImageView& operator=(ImageView&& other) noexcept
119     {
120         if (this != &other) {
121             delete[] m_data;
122             m_data = nullptr;
123             Init(other.data(), other.m_imageWidth, other.m_imageHeight, other.m_pixelSize);
124             other.m_data = nullptr;
125         }
126         return *this;
127     }
128 
~ImageView()129     virtual ~ImageView() {} //this class doesn't own m_data, so it ain't going to delete it either.
130 
data(void) const131     quint8* data(void) const
132     {
133         return m_data;
134     }
135 
num_elements(void) const136     inline int num_elements(void) const
137     {
138         return m_imageHeight * m_imageWidth;
139     }
140 
num_bytes(void) const141     inline int num_bytes(void) const
142     {
143         return m_imageHeight * m_imageWidth * m_pixelSize;
144     }
145 
pixel_size(void) const146     inline int pixel_size(void) const
147     {
148         return m_pixelSize;
149     }
150 
saveToDevice(KisPaintDeviceSP outDev,QRect rect)151     void saveToDevice(KisPaintDeviceSP outDev, QRect rect)
152     {
153         Q_ASSERT(outDev->colorSpace()->pixelSize() == (quint32) m_pixelSize);
154         outDev->writeBytes(m_data, rect);
155     }
156 
DebugDump(const QString & fnamePrefix)157     void DebugDump(const QString& fnamePrefix)
158     {
159         QRect imSize(QPoint(0, 0), QSize(m_imageWidth, m_imageHeight));
160         const KoColorSpace* cs = (m_pixelSize == 1) ?
161                                  KoColorSpaceRegistry::instance()->alpha8() : (m_pixelSize == 3) ? KoColorSpaceRegistry::instance()->colorSpace("RGB", "U8", "") :
162                                  KoColorSpaceRegistry::instance()->colorSpace("RGBA", "U8", "");
163         KisPaintDeviceSP dbout = new KisPaintDevice(cs);
164         saveToDevice(dbout, imSize);
165         KIS_DUMP_DEVICE_2(dbout, imSize, fnamePrefix, "./");
166     }
167 };
168 
169 class ImageData : public ImageView
170 {
171 
172 public:
ImageData()173     ImageData() : ImageView() {}
174 
Init(int _imageWidth,int _imageHeight,int _pixelSize)175     void Init(int _imageWidth, int _imageHeight, int _pixelSize)
176     {
177         m_data = new quint8[ _imageWidth * _imageHeight * _pixelSize ];
178         ImageView::Init(m_data, _imageWidth, _imageHeight, _pixelSize);
179     }
180 
ImageData(int _imageWidth,int _imageHeight,int _pixelSize)181     ImageData(int _imageWidth, int _imageHeight, int _pixelSize) : ImageView()
182     {
183         Init(_imageWidth, _imageHeight, _pixelSize);
184     }
185 
Init(KisPaintDeviceSP imageDev,const QRect & imageSize)186     void Init(KisPaintDeviceSP imageDev, const QRect& imageSize)
187     {
188         const KoColorSpace* cs = imageDev->colorSpace();
189         m_pixelSize = cs->pixelSize();
190 
191         m_data = new quint8[ imageSize.width()*imageSize.height()*cs->pixelSize() ];
192         imageDev->readBytes(m_data, imageSize.x(), imageSize.y(), imageSize.width(), imageSize.height());
193         ImageView::Init(m_data, imageSize.width(), imageSize.height(), m_pixelSize);
194     }
195 
ImageData(KisPaintDeviceSP imageDev,const QRect & imageSize)196     ImageData(KisPaintDeviceSP imageDev, const QRect& imageSize) : ImageView()
197     {
198         Init(imageDev, imageSize);
199     }
200 
~ImageData()201     ~ImageData() override
202     {
203         delete[] m_data; //ImageData owns m_data, so it has to delete it
204     }
205 
206 };
207 
208 
209 
210 class MaskedImage : public KisShared
211 {
212 private:
213 
214     template <typename T> friend float distance_impl(const MaskedImage& my, int x, int y, const MaskedImage& other, int xo, int yo);
215 
216     QRect imageSize;
217     int nChannels;
218 
219     const KoColorSpace* cs;
220     const KoColorSpace* csMask;
221 
222     ImageData maskData;
223     ImageData imageData;
224 
225 
cacheImage(KisPaintDeviceSP imageDev,QRect rect)226     void cacheImage(KisPaintDeviceSP imageDev, QRect rect)
227     {
228         cs = imageDev->colorSpace();
229         nChannels = cs->channelCount();
230         imageData.Init(imageDev, rect);
231         imageSize = rect;
232     }
233 
234 
cacheMask(KisPaintDeviceSP maskDev,QRect rect)235     void cacheMask(KisPaintDeviceSP maskDev, QRect rect)
236     {
237         Q_ASSERT(maskDev->colorSpace()->pixelSize() == 1);
238         csMask = maskDev->colorSpace();
239         maskData.Init(maskDev, rect);
240 
241         //hard threshold for the initial mask
242         //may be optional. needs testing
243         std::for_each(maskData.data(), maskData.data() + maskData.num_bytes(), [](quint8 & v) {
244             v = (v > MASK_CLEAR) ? MASK_SET : MASK_CLEAR;
245         });
246     }
247 
MaskedImage()248     MaskedImage() {}
249 
250 public:
251     std::function< float(const MaskedImage&, int, int, const MaskedImage& , int , int ) > distance;
252 
toPaintDevice(KisPaintDeviceSP imageDev,QRect rect,KisSelectionSP selection)253     void toPaintDevice(KisPaintDeviceSP imageDev, QRect rect, KisSelectionSP selection)
254     {
255         if (!selection) {
256             imageData.saveToDevice(imageDev, rect);
257         } else {
258             KisPaintDeviceSP dev = new KisPaintDevice(imageDev->colorSpace());
259             dev->setDefaultBounds(imageDev->defaultBounds());
260 
261             imageData.saveToDevice(dev, rect);
262 
263             KisPainter::copyAreaOptimized(rect.topLeft(), dev, imageDev, rect, selection);
264         }
265     }
266 
DebugDump(const QString & name)267     void DebugDump(const QString& name)
268     {
269         imageData.DebugDump(name + "_img");
270         maskData.DebugDump(name + "_mask");
271     }
272 
clearMask(void)273     void clearMask(void)
274     {
275         std::fill(maskData.data(), maskData.data() + maskData.num_bytes(), MASK_CLEAR);
276     }
277 
initialize(KisPaintDeviceSP _imageDev,KisPaintDeviceSP _maskDev,QRect _maskRect)278     void initialize(KisPaintDeviceSP _imageDev, KisPaintDeviceSP _maskDev, QRect _maskRect)
279     {
280         cacheImage(_imageDev, _maskRect);
281         cacheMask(_maskDev, _maskRect);
282 
283         //distance function is the only that needs to know the type
284         //For performance reasons we can't use functions provided by color space
285         KoID colorDepthId =  _imageDev->colorSpace()->colorDepthId();
286 
287         //Use RGB traits to assign actual pixel data types.
288         distance = &distance_impl<KoRgbU8Traits::channels_type>;
289 
290         if( colorDepthId == Integer16BitsColorDepthID )
291             distance = &distance_impl<KoRgbU16Traits::channels_type>;
292 #ifdef HAVE_OPENEXR
293         if( colorDepthId == Float16BitsColorDepthID )
294             distance = &distance_impl<KoRgbF16Traits::channels_type>;
295 #endif
296         if( colorDepthId == Float32BitsColorDepthID )
297             distance = &distance_impl<KoRgbF32Traits::channels_type>;
298 
299         if( colorDepthId == Float64BitsColorDepthID )
300             distance = &distance_impl<KoRgbF64Traits::channels_type>;
301     }
302 
MaskedImage(KisPaintDeviceSP _imageDev,KisPaintDeviceSP _maskDev,QRect _maskRect)303     MaskedImage(KisPaintDeviceSP _imageDev, KisPaintDeviceSP _maskDev, QRect _maskRect)
304     {
305         initialize(_imageDev, _maskDev, _maskRect);
306     }
307 
downsample2x(void)308     void downsample2x(void)
309     {
310         int H = imageSize.height();
311         int W = imageSize.width();
312         int newW = W / 2, newH = H / 2;
313 
314         KisPaintDeviceSP imageDev = new KisPaintDevice(cs);
315         KisPaintDeviceSP maskDev = new KisPaintDevice(csMask);
316         imageDev->writeBytes(imageData.data(), 0, 0, W, H);
317         maskDev->writeBytes(maskData.data(), 0, 0, W, H);
318 
319         ImageData newImage(newW, newH, cs->pixelSize());
320         ImageData newMask(newW, newH, 1);
321 
322         KoDummyUpdater updater;
323         KisTransformWorker worker(imageDev, 1. / 2., 1. / 2., 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
324                                   &updater, KisFilterStrategyRegistry::instance()->value("Bicubic"));
325         worker.run();
326 
327         KisTransformWorker workerMask(maskDev, 1. / 2., 1. / 2., 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
328                                       &updater, KisFilterStrategyRegistry::instance()->value("Bicubic"));
329         workerMask.run();
330 
331         imageDev->readBytes(newImage.data(), 0, 0, newW, newH);
332         maskDev->readBytes(newMask.data(), 0, 0, newW, newH);
333         imageData = std::move(newImage);
334         maskData = std::move(newMask);
335 
336         for (int i = 0; i < imageData.num_elements(); ++i) {
337             quint8* maskPix = maskData.data() + i * maskData.pixel_size();
338             if (*maskPix == MASK_SET) {
339                 for (int k = 0; k < imageData.pixel_size(); k++)
340                     *(imageData.data() + i * imageData.pixel_size() + k) = 0;
341             } else {
342                 *maskPix = MASK_CLEAR;
343             }
344         }
345         imageSize = QRect(0, 0, newW, newH);
346     }
347 
upscale(int newW,int newH)348     void upscale(int newW, int newH)
349     {
350         int H = imageSize.height();
351         int W = imageSize.width();
352 
353         ImageData newImage(newW, newH, cs->pixelSize());
354         ImageData newMask(newW, newH, 1);
355 
356         QVector<float> colors(nChannels, 0.f);
357         QVector<float> v(nChannels, 0.f);
358 
359         for (int y = 0; y < newH; ++y) {
360             for (int x = 0; x < newW; ++x) {
361 
362                 // original pixel
363                 int xs = (x * W) / newW;
364                 int ys = (y * H) / newH;
365 
366                 // copy to new image
367                 if (!isMasked(xs, ys)) {
368                     std::copy(imageData(xs, ys), imageData(xs, ys) + imageData.pixel_size(), newImage(x, y));
369                     *newMask(x, y) = MASK_CLEAR;
370                 } else {
371                     std::fill(newImage(x, y), newImage(x, y) + newImage.pixel_size(), 0);
372                     *newMask(x, y) = MASK_SET;
373                 }
374             }
375         }
376 
377         imageData = std::move(newImage);
378         maskData = std::move(newMask);
379         imageSize = QRect(0, 0, newW, newH);
380     }
381 
size()382     QRect size()
383     {
384         return imageSize;
385     }
386 
copy(void)387     KisSharedPtr<MaskedImage> copy(void)
388     {
389         KisSharedPtr<MaskedImage> clone = new MaskedImage();
390         clone->imageSize = this->imageSize;
391         clone->nChannels = this->nChannels;
392         clone->maskData = this->maskData;
393         clone->imageData = this->imageData;
394         clone->cs = this->cs;
395         clone->csMask = this->csMask;
396         clone->distance = this->distance;
397         return clone;
398     }
399 
countMasked(void)400     int countMasked(void)
401     {
402         int count = std::count_if(maskData.data(), maskData.data() + maskData.num_elements(), [](quint8 v) {
403             return v > MASK_CLEAR;
404         });
405         return count;
406     }
407 
isMasked(int x,int y)408     inline bool isMasked(int x, int y)
409     {
410         return (*maskData(x, y) > MASK_CLEAR);
411     }
412 
413     //returns true if the patch contains a masked pixel
containsMasked(int x,int y,int S)414     bool containsMasked(int x, int y, int S)
415     {
416         for (int dy = -S; dy <= S; ++dy) {
417             int ys = y + dy;
418             if (ys < 0 || ys >= imageSize.height())
419                 continue;
420 
421             for (int dx = -S; dx <= S; ++dx) {
422                 int xs = x + dx;
423                 if (xs < 0 || xs >= imageSize.width())
424                     continue;
425                 if (isMasked(xs, ys))
426                     return true;
427             }
428         }
429         return false;
430     }
431 
getImagePixelU8(int x,int y,int chan) const432     inline quint8 getImagePixelU8(int x, int y, int chan) const
433     {
434         return cs->scaleToU8(imageData(x, y), chan);
435     }
436 
getImagePixels(int x,int y) const437     inline QVector<float> getImagePixels(int x, int y) const
438     {
439         QVector<float> v(cs->channelCount());
440         cs->normalisedChannelsValue(imageData(x, y), v);
441         return v;
442     }
443 
getImagePixel(int x,int y)444     inline quint8* getImagePixel(int x, int y)
445     {
446         return imageData(x, y);
447     }
448 
setImagePixels(int x,int y,QVector<float> & value)449     inline void setImagePixels(int x, int y, QVector<float>& value)
450     {
451         cs->fromNormalisedChannelsValue(imageData(x, y), value);
452     }
453 
mixColors(std::vector<quint8 * > pixels,std::vector<float> w,float wsum,quint8 * dst)454     inline void mixColors(std::vector< quint8* > pixels, std::vector< float > w, float wsum,  quint8* dst)
455     {
456         const KoMixColorsOp* mixOp = cs->mixColorsOp();
457 
458         size_t n = w.size();
459         assert(pixels.size() == n);
460         std::vector< qint16 > weights;
461         weights.clear();
462 
463         float dif = 0;
464 
465         float scale = 255 / (wsum + 0.001);
466 
467         for (auto& v : w) {
468             //compensated summation to increase accuracy
469             float v1 = v * scale + dif;
470             float v2 = std::round(v1);
471             dif = v1 - v2;
472             weights.push_back(v2);
473         }
474 
475         mixOp->mixColors(pixels.data(), weights.data(), n, dst);
476     }
477 
setMask(int x,int y,quint8 v)478     inline void setMask(int x, int y, quint8 v)
479     {
480         *(maskData(x, y)) = v;
481     }
482 
channelCount(void) const483     inline int channelCount(void) const
484     {
485         return cs->channelCount();
486     }
487 };
488 
489 
490 //Generic version of the distance function. produces distance between colors in the range [0, MAX_DIST]. This
491 //is a fast distance computation. More accurate, but very slow implementation is to use color space operations.
distance_impl(const MaskedImage & my,int x,int y,const MaskedImage & other,int xo,int yo)492 template <typename T> float distance_impl(const MaskedImage& my, int x, int y, const MaskedImage& other, int xo, int yo)
493 {
494     float dsq = 0;
495     quint32 nchannels = my.channelCount();
496     T *v1 = reinterpret_cast<T*>(my.imageData(x, y));
497     T *v2 = reinterpret_cast<T*>(other.imageData(xo, yo));
498 
499     for (quint32 chan = 0; chan < nchannels; chan++) {
500         //It's very important not to lose precision in the next line
501         float v = ((float)(*(v1 + chan)) - (float)(*(v2 + chan)));
502         dsq += v * v;
503     }
504     return dsq / ((float)KoColorSpaceMathsTraits<T>::unitValue * KoColorSpaceMathsTraits<T>::unitValue / MAX_DIST );
505 }
506 
507 
508 typedef KisSharedPtr<MaskedImage> MaskedImageSP;
509 
510 struct NNPixel {
511     int x;
512     int y;
513     int distance;
514 };
515 typedef boost::multi_array<NNPixel, 2> NNArray_type;
516 
517 struct Vote_elem {
518     QVector<float> channel_values;
519     float w;
520 };
521 typedef boost::multi_array<Vote_elem, 2> Vote_type;
522 
523 
524 
525 class NearestNeighborField : public KisShared
526 {
527 
528 private:
randomInt(T range)529     template< typename T> T randomInt(T range)
530     {
531         return rand() % range;
532     }
533 
534     //compute initial value of the distance term
initialize(void)535     void initialize(void)
536     {
537         for (int y = 0; y < imSize.height(); y++) {
538             for (int x = 0; x < imSize.width(); x++) {
539                 field[x][y].distance = distance(x, y, field[x][y].x, field[x][y].y);
540 
541                 //if the distance is "infinity", try to find a better link
542                 int iter = 0;
543                 const int maxretry = 20;
544                 while (field[x][y].distance == MAX_DIST && iter < maxretry) {
545                     field[x][y].x = randomInt(imSize.width() + 1);
546                     field[x][y].y = randomInt(imSize.height() + 1);
547                     field[x][y].distance = distance(x, y, field[x][y].x, field[x][y].y);
548                     iter++;
549                 }
550             }
551         }
552     }
553 
init_similarity_curve(void)554     void init_similarity_curve(void)
555     {
556         float s_zero = 0.999;
557         float t_halfmax = 0.10;
558 
559         float x  = (s_zero - 0.5) * 2;
560         float invtanh = 0.5 * std::log((1. + x) / (1. - x));
561         float coef = invtanh / t_halfmax;
562 
563         similarity.resize(MAX_DIST + 1);
564         for (int i = 0; i < (int)similarity.size(); i++) {
565             float t = (float)i / similarity.size();
566             similarity[i] = 0.5 - 0.5 * std::tanh(coef * (t - t_halfmax));
567         }
568     }
569 
570 
571 private:
572     int patchSize; //patch size
573 public:
574     MaskedImageSP input;
575     MaskedImageSP output;
576     QRect imSize;
577     NNArray_type field;
578     std::vector<float> similarity;
579     quint32 nColors;
580     QList<KoChannelInfo *> channels;
581 
582 public:
NearestNeighborField(const MaskedImageSP _input,MaskedImageSP _output,int _patchsize)583     NearestNeighborField(const MaskedImageSP _input, MaskedImageSP _output, int _patchsize) : patchSize(_patchsize), input(_input), output(_output)
584     {
585         imSize = input->size();
586         field.resize(boost::extents[imSize.width()][imSize.height()]);
587         init_similarity_curve();
588 
589         nColors = input->channelCount(); //only color count, doesn't include alpha channels
590     }
591 
randomize(void)592     void randomize(void)
593     {
594         for (int y = 0; y < imSize.height(); y++) {
595             for (int x = 0; x < imSize.width(); x++) {
596                 field[x][y].x = randomInt(imSize.width() + 1);
597                 field[x][y].y = randomInt(imSize.height() + 1);
598                 field[x][y].distance = MAX_DIST;
599             }
600         }
601         initialize();
602     }
603 
604     //initialize field from an existing (possibly smaller) nearest neighbor field
initialize(const NearestNeighborField & nnf)605     void initialize(const NearestNeighborField& nnf)
606     {
607         float xscale = qreal(imSize.width()) / nnf.imSize.width();
608         float yscale = qreal(imSize.height()) / nnf.imSize.height();
609 
610         for (int y = 0; y < imSize.height(); y++) {
611             for (int x = 0; x < imSize.width(); x++) {
612                 int xlow = std::min((int)(x / xscale), nnf.imSize.width() - 1);
613                 int ylow = std::min((int)(y / yscale), nnf.imSize.height() - 1);
614 
615                 field[x][y].x = nnf.field[xlow][ylow].x * xscale;
616                 field[x][y].y = nnf.field[xlow][ylow].y * yscale;
617                 field[x][y].distance = MAX_DIST;
618             }
619         }
620         initialize();
621     }
622 
623     //multi-pass NN-field minimization (see "PatchMatch" paper referenced above - page 4)
minimize(int pass)624     void minimize(int pass)
625     {
626         int min_x = 0;
627         int min_y = 0;
628         int max_x = imSize.width() - 1;
629         int max_y = imSize.height() - 1;
630 
631         for (int i = 0; i < pass; i++) {
632             //scanline order
633             for (int y = min_y; y < max_y; y++)
634                 for (int x = min_x; x <= max_x; x++)
635                     if (field[x][y].distance > 0)
636                         minimizeLink(x, y, 1);
637 
638             //reverse scanline order
639             for (int y = max_y; y >= min_y; y--)
640                 for (int x = max_x; x >= min_x; x--)
641                     if (field[x][y].distance > 0)
642                         minimizeLink(x, y, -1);
643         }
644     }
645 
minimizeLink(int x,int y,int dir)646     void minimizeLink(int x, int y, int dir)
647     {
648         int xp, yp, dp;
649 
650         //Propagation Left/Right
651         if (x - dir > 0 && x - dir < imSize.width()) {
652             xp = field[x - dir][y].x + dir;
653             yp = field[x - dir][y].y;
654             dp = distance(x, y, xp, yp);
655             if (dp < field[x][y].distance) {
656                 field[x][y].x = xp;
657                 field[x][y].y = yp;
658                 field[x][y].distance = dp;
659             }
660         }
661 
662         //Propagation Up/Down
663         if (y - dir > 0 && y - dir < imSize.height()) {
664             xp = field[x][y - dir].x;
665             yp = field[x][y - dir].y + dir;
666             dp = distance(x, y, xp, yp);
667             if (dp < field[x][y].distance) {
668                 field[x][y].x = xp;
669                 field[x][y].y = yp;
670                 field[x][y].distance = dp;
671             }
672         }
673 
674         //Random search
675         int wi = std::max(output->size().width(), output->size().height());
676         int xpi = field[x][y].x;
677         int ypi = field[x][y].y;
678         while (wi > 0) {
679             xp = xpi + randomInt(2 * wi) - wi;
680             yp = ypi + randomInt(2 * wi) - wi;
681             xp = std::max(0, std::min(output->size().width() - 1, xp));
682             yp = std::max(0, std::min(output->size().height() - 1, yp));
683 
684             dp = distance(x, y, xp, yp);
685             if (dp < field[x][y].distance) {
686                 field[x][y].x = xp;
687                 field[x][y].y = yp;
688                 field[x][y].distance = dp;
689             }
690             wi /= 2;
691         }
692     }
693 
694     //compute distance between two patches
distance(int x,int y,int xp,int yp)695     int distance(int x, int y, int xp, int yp)
696     {
697         float distance = 0;
698         float wsum = 0;
699         float ssdmax = nColors * 255 * 255;
700 
701         //for each pixel in the source patch
702         for (int dy = -patchSize; dy <= patchSize; dy++) {
703             for (int dx = -patchSize; dx <= patchSize; dx++) {
704                 wsum += ssdmax;
705                 int xks = x + dx;
706                 int yks = y + dy;
707 
708                 if (xks < 0 || xks >= input->size().width()) {
709                     distance += ssdmax;
710                     continue;
711                 }
712 
713                 if (yks < 0 || yks >= input->size().height()) {
714                     distance += ssdmax;
715                     continue;
716                 }
717 
718                 //cannot use masked pixels as a valid source of information
719                 if (input->isMasked(xks, yks)) {
720                     distance += ssdmax;
721                     continue;
722                 }
723 
724                 //corresponding pixel in target patch
725                 int xkt = xp + dx;
726                 int ykt = yp + dy;
727                 if (xkt < 0 || xkt >= output->size().width()) {
728                     distance += ssdmax;
729                     continue;
730                 }
731                 if (ykt < 0 || ykt >= output->size().height()) {
732                     distance += ssdmax;
733                     continue;
734                 }
735 
736                 //cannot use masked pixels as a valid source of information
737                 if (output->isMasked(xkt, ykt)) {
738                     distance += ssdmax;
739                     continue;
740                 }
741 
742                 //SSD distance between pixels
743                 float ssd = input->distance(*input, xks, yks, *output, xkt, ykt);
744                 distance += ssd;
745 
746             }
747         }
748         return (int)(MAX_DIST * (distance / wsum));
749     }
750 
751     static MaskedImageSP ExpectationMaximization(KisSharedPtr<NearestNeighborField> TargetToSource, int level, int radius, QList<MaskedImageSP>& pyramid);
752 
753     static void ExpectationStep(KisSharedPtr<NearestNeighborField> nnf, MaskedImageSP source, MaskedImageSP target, bool upscale);
754 
755     void EM_Step(MaskedImageSP source, MaskedImageSP target, int R, bool upscaled);
756 };
757 typedef KisSharedPtr<NearestNeighborField> NearestNeighborFieldSP;
758 
759 
760 class Inpaint
761 {
762 private:
763     KisPaintDeviceSP devCache;
764     MaskedImageSP initial;
765     NearestNeighborFieldSP nnf_TargetToSource;
766     NearestNeighborFieldSP nnf_SourceToTarget;
767     int radius;
768     QList<MaskedImageSP> pyramid;
769 
770 
771 public:
Inpaint(KisPaintDeviceSP dev,KisPaintDeviceSP devMask,int _radius,QRect maskRect)772     Inpaint(KisPaintDeviceSP dev, KisPaintDeviceSP devMask, int _radius, QRect maskRect)
773     : devCache(dev)
774     , initial(new MaskedImage(dev, devMask, maskRect))
775     , radius(_radius)
776     {
777     }
778     MaskedImageSP patch(void);
779     MaskedImageSP patch_simple(void);
780 };
781 
782 
783 
patch()784 MaskedImageSP Inpaint::patch()
785 {
786     MaskedImageSP source = initial->copy();
787 
788     pyramid.append(initial);
789 
790     QRect size = source->size();
791 
792     //qDebug() << "countMasked: " <<  source->countMasked() << "\n";
793     while ((size.width() > radius) && (size.height() > radius) && source->countMasked() > 0) {
794         source->downsample2x();
795         //source->DebugDump("Pyramid");
796         //qDebug() << "countMasked1: " <<  source->countMasked() << "\n";
797         pyramid.append(source->copy());
798         size = source->size();
799     }
800     int maxlevel = pyramid.size();
801     //qDebug() << "MaxLevel: " <<  maxlevel << "\n";
802 
803     // The initial target is the same as the smallest source.
804     // We consider that this target contains no masked pixels
805     MaskedImageSP target = source->copy();
806     target->clearMask();
807 
808     //recursively building nearest neighbor field
809     for (int level = maxlevel - 1; level > 0; level--) {
810         source = pyramid.at(level);
811 
812         if (level == maxlevel - 1) {
813             //random initial guess
814             nnf_TargetToSource = new NearestNeighborField(target, source, radius);
815             nnf_TargetToSource->randomize();
816         } else {
817             // then, we use the rebuilt (upscaled) target
818             // and reuse the previous NNF as initial guess
819 
820             NearestNeighborFieldSP new_nnf_rev = new NearestNeighborField(target, source, radius);
821             new_nnf_rev->initialize(*nnf_TargetToSource);
822             nnf_TargetToSource = new_nnf_rev;
823         }
824 
825         //Build an upscaled target by EM-like algorithm (see "PatchMatch" paper referenced above - page 6)
826         target = NearestNeighborField::ExpectationMaximization(nnf_TargetToSource, level, radius, pyramid);
827         //target->DebugDump( "target" );
828     }
829     return target;
830 }
831 
832 
833 //EM-Like algorithm (see "PatchMatch" - page 6)
834 //Returns a float sized target image
ExpectationMaximization(NearestNeighborFieldSP nnf_TargetToSource,int level,int radius,QList<MaskedImageSP> & pyramid)835 MaskedImageSP NearestNeighborField::ExpectationMaximization(NearestNeighborFieldSP nnf_TargetToSource, int level, int radius, QList<MaskedImageSP>& pyramid)
836 {
837     int iterEM = std::min(2 * level, 4);
838     int iterNNF = std::min(5, 1 + level);
839 
840     MaskedImageSP source = nnf_TargetToSource->output;
841     MaskedImageSP target = nnf_TargetToSource->input;
842     MaskedImageSP newtarget = nullptr;
843 
844     //EM loop
845     for (int emloop = 1; emloop <= iterEM; emloop++) {
846         //set the new target as current target
847         if (!newtarget.isNull()) {
848             nnf_TargetToSource->input = newtarget;
849             target = newtarget;
850             newtarget = nullptr;
851         }
852 
853         for (int x = 0; x < target->size().width(); ++x) {
854             for (int y = 0; y < target->size().height(); ++y) {
855                 if (!source->containsMasked(x, y, radius)) {
856                     nnf_TargetToSource->field[x][y].x = x;
857                     nnf_TargetToSource->field[x][y].y = y;
858                     nnf_TargetToSource->field[x][y].distance = 0;
859                 }
860             }
861         }
862 
863         //minimize the NNF
864         nnf_TargetToSource->minimize(iterNNF);
865 
866         //Now we rebuild the target using best patches from source
867         MaskedImageSP newsource = nullptr;
868         bool upscaled = false;
869 
870         // Instead of upsizing the final target, we build the last target from the next level source image
871         // So the final target is less blurry (see "Space-Time Video Completion" - page 5)
872         if (level >= 1 && (emloop == iterEM)) {
873             newsource = pyramid.at(level - 1);
874             QRect sz = newsource->size();
875             newtarget = target->copy();
876             newtarget->upscale(sz.width(), sz.height());
877             upscaled = true;
878         } else {
879             newsource = pyramid.at(level);
880             newtarget = target->copy();
881             upscaled = false;
882         }
883         //EM Step
884 
885         //EM_Step(newsource, newtarget, radius, upscaled);
886         ExpectationStep(nnf_TargetToSource, newsource, newtarget, upscaled);
887     }
888 
889     return newtarget;
890 }
891 
892 
ExpectationStep(NearestNeighborFieldSP nnf,MaskedImageSP source,MaskedImageSP target,bool upscale)893 void NearestNeighborField::ExpectationStep(NearestNeighborFieldSP nnf, MaskedImageSP source, MaskedImageSP target, bool upscale)
894 {
895     //int*** field = nnf->field;
896     int R = nnf->patchSize;
897     if (upscale)
898         R *= 2;
899 
900     int H_nnf = nnf->input->size().height();
901     int W_nnf = nnf->input->size().width();
902     int H_target = target->size().height();
903     int W_target = target->size().width();
904     int H_source = source->size().height();
905     int W_source = source->size().width();
906 
907     std::vector< quint8* > pixels;
908     std::vector< float > weights;
909     pixels.reserve(R * R);
910     weights.reserve(R * R);
911     for (int x = 0 ; x < W_target ; ++x) {
912         for (int y = 0 ; y < H_target; ++y) {
913             float wsum = 0;
914             pixels.clear();
915             weights.clear();
916 
917 
918             if (!source->containsMasked(x, y, R + 4) /*&& upscale*/) {
919                 //speedup computation by copying parts that are not masked.
920                 pixels.push_back(source->getImagePixel(x, y));
921                 weights.push_back(1.f);
922                 target->mixColors(pixels, weights, 1.f, target->getImagePixel(x, y));
923             } else {
924                 for (int dx = -R ; dx <= R; ++dx) {
925                     for (int dy = -R ; dy <= R ; ++dy) {
926                         // xpt,ypt = center pixel of the target patch
927                         int xpt = x + dx;
928                         int ypt = y + dy;
929 
930                         int xst, yst;
931                         float w;
932 
933                         if (!upscale) {
934                             if (xpt < 0 || xpt >= W_nnf || ypt < 0 || ypt >= H_nnf)
935                                 continue;
936 
937                             xst = nnf->field[xpt][ypt].x;
938                             yst = nnf->field[xpt][ypt].y;
939                             float dp = nnf->field[xpt][ypt].distance;
940                             // similarity measure between the two patches
941                             w = nnf->similarity[dp];
942 
943                         } else {
944                             if (xpt < 0 || (xpt / 2) >= W_nnf || ypt < 0 || (ypt / 2) >= H_nnf)
945                                 continue;
946                             xst = 2 * nnf->field[xpt / 2][ypt / 2].x + (xpt % 2);
947                             yst = 2 * nnf->field[xpt / 2][ypt / 2].y + (ypt % 2);
948                             float dp = nnf->field[xpt / 2][ypt / 2].distance;
949                             // similarity measure between the two patches
950                             w = nnf->similarity[dp];
951                         }
952 
953                         int xs = xst - dx;
954                         int ys = yst - dy;
955 
956                         if (xs < 0 || xs >= W_source || ys < 0 || ys >= H_source)
957                             continue;
958 
959                         if (source->isMasked(xs, ys))
960                             continue;
961 
962                         pixels.push_back(source->getImagePixel(xs, ys));
963                         weights.push_back(w);
964                         wsum += w;
965                     }
966                 }
967 
968                 if (wsum < 1)
969                     continue;
970 
971                 target->mixColors(pixels, weights, wsum, target->getImagePixel(x, y));
972             }
973         }
974     }
975 }
976 
getMaskBoundingBox(KisPaintDeviceSP maskDev)977 QRect getMaskBoundingBox(KisPaintDeviceSP maskDev)
978 {
979     QRect maskRect = maskDev->nonDefaultPixelArea();
980     return maskRect;
981 }
982 
983 
patchImage(const KisPaintDeviceSP imageDev,const KisPaintDeviceSP maskDev,int patchRadius,int accuracy,KisSelectionSP selection)984 QRect patchImage(const KisPaintDeviceSP imageDev, const KisPaintDeviceSP maskDev, int patchRadius, int accuracy, KisSelectionSP selection)
985 {
986     QRect maskRect = getMaskBoundingBox(maskDev);
987     QRect imageRect = imageDev->exactBounds();
988 
989     float scale = 1.0 + (accuracy / 25.0); //higher accuracy means we include more surrounding area around the patch. Minimum 2x padding.
990     int dx = maskRect.width() * scale;
991     int dy = maskRect.height() * scale;
992     maskRect.adjust(-dx, -dy, dx, dy);
993     maskRect = maskRect.intersected(imageRect);
994 
995     if (!maskRect.isEmpty()) {
996         Inpaint inpaint(imageDev, maskDev, patchRadius, maskRect);
997         MaskedImageSP output = inpaint.patch();
998         output->toPaintDevice(imageDev, maskRect, selection);
999     }
1000 
1001     return maskRect;
1002 }
1003 
1004