1 /* ============================================================
2  *
3  * This file is a part of digiKam project
4  * https://www.digikam.org
5  *
6  * Date        : 2019-08-10
7  * Description : CLI tool to test and verify clustering for Face Recognition
8  *
9  * Copyright (C) 2019 by Thanh Trung Dinh <dinhthanhtrung1996 at gmail dot com>
10  *
11  * This program is free software; you can redistribute it
12  * and/or modify it under the terms of the GNU General
13  * Public License as published by the Free Software Foundation;
14  * either version 2, or (at your option)
15  * any later version.
16  *
17  * This program is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20  * GNU General Public License for more details.
21  *
22  * ============================================================ */
23 
// C++ includes

#include <algorithm>
#include <iterator>
#include <set>
#include <vector>

// Qt includes

#include <QCoreApplication>
#include <QDir>
#include <QImage>
#include <QElapsedTimer>
#include <QCommandLineParser>
#include <QList>

// Local includes

#include "digikam_debug.h"
#include "dimg.h"
#include "facescansettings.h"
#include "facedetector.h"
#include "coredbaccess.h"
#include "dbengineparameters.h"
#include "facialrecognition_wrapper.h"
47 using namespace Digikam;
48 
49 // --------------------------------------------------------------------------------------------------
50 
/**
 * Collect the values present in both v1 and v2 and place them,
 * in order, at the front of vout (ahead of anything vout already holds).
 *
 * Both inputs are handed to std::set_intersection as-is, so they
 * must already be sorted.
 */
void intersection(const std::vector<int>& v1,
                  const std::vector<int>& v2,
                  std::vector<int>& vout)
{
    // Gather the common values in a scratch buffer first

    std::vector<int> common;

    std::set_intersection(v1.begin(), v1.end(),
                          v2.begin(), v2.end(),
                          std::inserter(common, common.end()));

    // Then splice them in front of the caller's content in one go

    vout.insert(vout.begin(), common.begin(), common.end());
}
64 
/**
 * Return the Jaccard distance of two index sets:
 *
 *     1 - |intersection(v1, v2)| / |union(v1, v2)|
 *
 * Both inputs must be sorted, as required by std::set_intersection.
 * The result is 0 for identical sets and 1 for disjoint ones.
 * Two empty sets are treated as identical and yield 0; the plain
 * formula would divide 0 by 0 and return NaN in that case.
 */
double jaccard_distance(const std::vector<int>& v1,
                        const std::vector<int>& v2)
{
    // Elements shared by both sets

    std::vector<int> common;
    std::set_intersection(v1.begin(), v1.end(),
                          v2.begin(), v2.end(),
                          std::back_inserter(common));

    const double sizeIn    = static_cast<double>(common.size());

    // |union| = |v1| + |v2| - |intersection|

    const double sizeUnion = static_cast<double>(v1.size()) +
                             static_cast<double>(v2.size()) - sizeIn;

    // Both sets are empty: nothing distinguishes them

    if (sizeUnion <= 0.0)
    {
        return 0.0;
    }

    // Jaccard distance is the complement of the Jaccard index

    return (1.0 - (sizeIn / sizeUnion));
}
99 
toPaths(char ** argv,int startIndex,int argc)100 QStringList toPaths(char** argv, int startIndex, int argc)
101 {
102     QStringList files;
103 
104     for (int i = startIndex ; i < argc ; ++i)
105     {
106         files << QString::fromLatin1(argv[i]);
107     }
108 
109     return files;
110 }
111 
toImages(const QStringList & paths)112 QList<QImage> toImages(const QStringList& paths)
113 {
114     QList<QImage> images;
115 
116     foreach (const QString& path, paths)
117     {
118         images << QImage(path);
119     }
120 
121     return images;
122 }
123 
prepareForTrain(QString datasetPath,QStringList & images,std::vector<int> & testClusteredIndices)124 int prepareForTrain(QString datasetPath,
125                     QStringList& images,
126                     std::vector<int>& testClusteredIndices)
127 {
128     if (!datasetPath.endsWith(QLatin1String("/")))
129     {
130         datasetPath.append(QLatin1String("/"));
131     }
132 
133     QDir testSet(datasetPath);
134     QStringList subjects = testSet.entryList(QDir::Dirs | QDir::NoDotAndDotDot | QDir::NoSymLinks);
135     int nbOfClusters     = subjects.size();
136 
137     qCDebug(DIGIKAM_TESTS_LOG) << "Number of clusters to be defined" << nbOfClusters;
138 
139     for (int i = 1 ; i <= nbOfClusters ; ++i)
140     {
141         QString subjectPath               = QString::fromLatin1("%1%2")
142                                                      .arg(datasetPath)
143                                                      .arg(subjects.takeFirst());
144         QDir subjectDir(subjectPath);
145 
146         QStringList files                 = subjectDir.entryList(QDir::Files);
147         unsigned int nbOfFacesPerClusters = files.size();
148 
149         for (unsigned j = 1 ; j <= nbOfFacesPerClusters ; ++j)
150         {
151             QString path = QString::fromLatin1("%1/%2").arg(subjectPath)
152                                                        .arg(files.takeFirst());
153 
154             testClusteredIndices.push_back(i - 1);
155             images << path;
156         }
157     }
158 
159     qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters (prepareForTrain) " << nbOfClusters;
160 
161     return nbOfClusters;
162 }
163 
processFaceDetection(const QString & imagePath,FaceDetector detector)164 QList<QRectF> processFaceDetection(const QString& imagePath, FaceDetector detector)
165 {
166     QList<QRectF> detectedFaces = detector.detectFaces(imagePath);
167 
168     qCDebug(DIGIKAM_TESTS_LOG) << "(Input CV) Found " << detectedFaces.size() << " faces";
169 
170     return detectedFaces;
171 }
172 
/**
 * Crop one face out of each image: rects[i] is applied to images[i].
 *
 * The two lists are expected to have the same length. If they differ,
 * only the leading entries present in both are processed, instead of
 * reading past the end of images.
 *
 * @param images source images
 * @param rects  rectangle to copy from the image at the same position
 * @return the cropped images, in input order
 */
QList<QImage> retrieveFaces(const QList<QImage>& images, const QList<QRectF>& rects)
{
    QList<QImage> faces;

    // Only walk the range where both an image and its rectangle exist

    const int count = static_cast<int>((images.size() < rects.size()) ? images.size()
                                                                      : rects.size());

    for (int i = 0 ; i < count ; ++i)
    {
        DImg temp(images.at(i));
        faces << temp.copyQImage(rects.at(i));
    }

    return faces;
}
187 
createClustersFromClusterIndices(const std::vector<int> & clusteredIndices,QList<std::vector<int>> & clusters)188 void createClustersFromClusterIndices(const std::vector<int>& clusteredIndices,
189                                       QList<std::vector<int>>& clusters)
190 {
191     int nbOfClusters = 0;
192 
193     for (size_t i = 0 ; i < clusteredIndices.size() ; ++i)
194     {
195         int nb = clusteredIndices[i];
196 
197         if (nb > nbOfClusters)
198         {
199             nbOfClusters = nb;
200         }
201     }
202 
203     nbOfClusters++;
204 
205     for (int i = 0 ; i < nbOfClusters ; ++i)
206     {
207         clusters << std::vector<int>();
208     }
209 
210     qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters " << clusters.size();
211 
212     for (int i = 0 ; i < (int)clusteredIndices.size() ; ++i)
213     {
214         clusters[clusteredIndices[i]].push_back(i);
215     }
216 }
217 
verifyClusteringResults(const std::vector<int> & clusteredIndices,const std::vector<int> & testClusteredIndices,const QStringList & dataset,QStringList & falsePositiveCases)218 void verifyClusteringResults(const std::vector<int>& clusteredIndices,
219                              const std::vector<int>& testClusteredIndices,
220                              const QStringList& dataset,
221                              QStringList& falsePositiveCases)
222 {
223     QList<std::vector<int>> clusters, testClusters;
224     createClustersFromClusterIndices(clusteredIndices, clusters);
225     createClustersFromClusterIndices(testClusteredIndices, testClusters);
226 
227     std::set<int> falsePositivePoints;
228     int testClustersSize = testClusters.size();
229     std::vector<float> visited(testClustersSize, 1.0);
230     std::vector<std::set<int>> lastVisit(testClustersSize, std::set<int>{});
231 
232     for (int i = 0 ; i < testClustersSize ; ++i)
233     {
234         std::vector<int> refSet = testClusters.at(i);
235         double minDist          = 1.0;
236         int indice              = 0;
237 
238         for (int j = 0 ; j < clusters.size() ; ++j)
239         {
240             double dist = jaccard_distance(refSet, clusters.at(j));
241 
242             if (dist < minDist)
243             {
244                 indice  = j;
245                 minDist = dist;
246             }
247         }
248 
249         qCDebug(DIGIKAM_TESTS_LOG) << "testCluster " << i << " with group " << indice;
250 
251         std::vector<int> similarSet = clusters.at(indice);
252 
253         if (minDist < visited[indice])
254         {
255             visited[indice]            = minDist;
256             std::set<int> lastVisitSet = lastVisit[indice];
257             std::set<int> newVisitSet;
258             std::set_symmetric_difference(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(),
259                                           std::inserter(newVisitSet, newVisitSet.begin()));
260 
261             for (int elm: lastVisitSet)
262             {
263                 falsePositivePoints.erase(elm);
264             }
265 
266             lastVisit[indice] = newVisitSet;
267             falsePositivePoints.insert(newVisitSet.begin(), newVisitSet.end());
268         }
269         else
270         {
271             std::set_intersection(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(),
272                                   std::inserter(falsePositivePoints, falsePositivePoints.begin()));
273         }
274     }
275 
276     for (auto indx: falsePositivePoints)
277     {
278         falsePositiveCases << dataset[indx];
279     }
280 }
281 
282 // --------------------------------------------------------------------------------------------------
283 
main(int argc,char * argv[])284 int main(int argc, char* argv[])
285 {
286     QCoreApplication app(argc, argv);
287     app.setApplicationName(QString::fromLatin1("digikam"));          // for DB init.
288 
289     // Options for commandline parser
290 
291     QCommandLineParser parser;
292     parser.addOption(QCommandLineOption(QLatin1String("db"),
293                      QLatin1String("Faces database"),
294                      QLatin1String("path to db folder")));
295     parser.addHelpOption();
296     parser.process(app);
297 
298     // Parse arguments
299 
300     bool optionErrors = false;
301 
302     if      (parser.optionNames().empty())
303     {
304         qCWarning(DIGIKAM_TESTS_LOG) << "No options!!!";
305         optionErrors = true;
306     }
307     else if (!parser.isSet(QLatin1String("db")))
308     {
309         qCWarning(DIGIKAM_TESTS_LOG) << "Missing database for test!!!";
310         optionErrors = true;
311     }
312 
313     if (optionErrors)
314     {
315         parser.showHelp();
316         return 1;
317     }
318 
319     QString facedb         = parser.value(QLatin1String("db"));
320 
321     // Init config for digiKam
322 
323     DbEngineParameters prm = DbEngineParameters::parametersFromConfig();
324     CoreDbAccess::setParameters(prm, CoreDbAccess::MainApplication);
325     FacialRecognitionWrapper recognizer;
326 
327     //db.setRecognizerThreshold(0.91F);       // This is sensitive for the performance of face clustering
328 
329     // Construct test set, data set
330 
331     QStringList dataset;
332     std::vector<int> testClusteredIndices;
333     int nbOfClusters = prepareForTrain(facedb, dataset, testClusteredIndices);
334 
335     // Init FaceDetector used for detecting faces and bounding box
336     // before recognizing
337 
338     FaceDetector detector;
339 
340     // Evaluation metrics
341 
342     unsigned totalClustered    = 0;
343     unsigned elapsedClustering = 0;
344 
345     QStringList undetectedFaces;
346 
347     QList<QImage> detectedFaces;
348     QList<QRectF> bboxes;
349     QList<QImage> rawImages    = toImages(dataset);
350 
351     foreach (const QImage& image, rawImages)
352     {
353         QString imagePath                 = dataset.takeFirst();
354         QList<QRectF> detectedBoundingBox = processFaceDetection(imagePath, detector);
355 
356         if (detectedBoundingBox.size())
357         {
358             detectedFaces << image;
359             bboxes        << detectedBoundingBox.first();
360             dataset       << imagePath;
361 
362             ++totalClustered;
363         }
364         else
365         {
366             undetectedFaces << imagePath;
367         }
368     }
369 
370     std::vector<int> clusteredIndices(dataset.size(), -1);
371     QList<QImage> faces = retrieveFaces(detectedFaces, bboxes);
372 
373     QElapsedTimer timer;
374 
375     timer.start();
376 /*
377     TODO: port to new API
378     db.clusterFaces(faces, clusteredIndices, dataset, nbOfClusters);
379 */
380     elapsedClustering  += timer.elapsed();
381 
382     // Verify clustering
383 
384     QStringList falsePositiveCases;
385     verifyClusteringResults(clusteredIndices, testClusteredIndices, dataset, falsePositiveCases);
386 
387     // Display results
388 
389     unsigned nbUndetectedFaces = undetectedFaces.size();
390     qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbUndetectedFaces << " / " << dataset.size() + nbUndetectedFaces
391              << " (" << float(nbUndetectedFaces) / (dataset.size() + nbUndetectedFaces) * 100 << "%)"
392              << " faces cannot be detected";
393 
394     foreach (const QString& path, undetectedFaces)
395     {
396         qCDebug(DIGIKAM_TESTS_LOG) << path;
397     }
398 
399     unsigned nbOfFalsePositiveCases = falsePositiveCases.size();
400     qCDebug(DIGIKAM_TESTS_LOG) << "\nFalse positive cases";
401     qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbOfFalsePositiveCases << " / " << dataset.size()
402              << " (" << float(nbOfFalsePositiveCases*100) / dataset.size()<< "%)"
403              << " faces were wrongly clustered";
404 
405     foreach (const QString& imagePath, falsePositiveCases)
406     {
407         qCDebug(DIGIKAM_TESTS_LOG) << imagePath;
408     }
409 
410     qCDebug(DIGIKAM_TESTS_LOG) << "\n Time for clustering " << elapsedClustering << " ms";
411 
412     return 0;
413 }
414