1 /* ============================================================
2 *
3 * This file is a part of digiKam project
4 * https://www.digikam.org
5 *
6 * Date : 2019-08-10
7 * Description : CLI tool to test and verify clustering for Face Recognition
8 *
9 * Copyright (C) 2019 by Thanh Trung Dinh <dinhthanhtrung1996 at gmail dot com>
10 *
11 * This program is free software; you can redistribute it
12 * and/or modify it under the terms of the GNU General
13 * Public License as published by the Free Software Foundation;
14 * either version 2, or (at your option)
15 * any later version.
16 *
17 * This program is distributed in the hope that it will be useful,
18 * but WITHOUT ANY WARRANTY; without even the implied warranty of
19 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20 * GNU General Public License for more details.
21 *
22 * ============================================================ */
23
24 // C++ includes
25
26 #include <set>
27
28 // Qt includes
29
30 #include <QCoreApplication>
31 #include <QDir>
32 #include <QImage>
33 #include <QElapsedTimer>
34 #include <QCommandLineParser>
35 #include <QList>
36
37 // Local includes
38
39 #include "digikam_debug.h"
40 #include "dimg.h"
41 #include "facescansettings.h"
42 #include "facedetector.h"
43 #include "coredbaccess.h"
44 #include "dbengineparameters.h"
45 #include "facialrecognition_wrapper.h"
46
47 using namespace Digikam;
48
49 // --------------------------------------------------------------------------------------------------
50
/**
 * Compute the elements common to v1 and v2.
 *
 * Both inputs must be sorted in ascending order, as required by
 * std::set_intersection(). The common elements are placed, in ascending
 * order, in front of whatever vout already contains.
 */
void intersection(const std::vector<int>& v1,
                  const std::vector<int>& v2,
                  std::vector<int>&       vout)
{
    // Gather the common elements in a scratch buffer first...

    std::vector<int> shared;
    std::set_intersection(v1.cbegin(), v1.cend(),
                          v2.cbegin(), v2.cend(),
                          std::inserter(shared, shared.end()));

    // ...then splice them at the head of the output vector

    vout.insert(vout.begin(), shared.cbegin(), shared.cend());
}
64
/**
 * Return the Jaccard distance between two sets of integers:
 *
 *     1 - |v1 n v2| / |v1 u v2|
 *
 * Both inputs must be sorted in ascending order and free of duplicates,
 * as required by std::set_intersection() for the sizes to be meaningful.
 *
 * The result is 0.0 for identical sets and 1.0 for disjoint sets.
 * Two empty sets are treated as identical (distance 0.0) instead of
 * evaluating 0 / 0, which would yield NaN.
 */
double jaccard_distance(const std::vector<int>& v1,
                        const std::vector<int>& v2)
{
    // Elements present in both sets

    std::vector<int> common;
    std::set_intersection(v1.begin(), v1.end(), v2.begin(), v2.end(),
                          std::inserter(common, common.begin()));

    // |A u B| = |A| + |B| - |A n B|

    const double sizeIn    = static_cast<double>(common.size());
    const double sizeUnion = static_cast<double>(v1.size()) +
                             static_cast<double>(v2.size()) - sizeIn;

    // Empty union means both sets are empty: avoid the 0 / 0 division

    if (sizeUnion <= 0.0)
    {
        return 0.0;
    }

    // Jaccard distance is the complement of the Jaccard index

    return (1.0 - (sizeIn / sizeUnion));
}
99
toPaths(char ** argv,int startIndex,int argc)100 QStringList toPaths(char** argv, int startIndex, int argc)
101 {
102 QStringList files;
103
104 for (int i = startIndex ; i < argc ; ++i)
105 {
106 files << QString::fromLatin1(argv[i]);
107 }
108
109 return files;
110 }
111
toImages(const QStringList & paths)112 QList<QImage> toImages(const QStringList& paths)
113 {
114 QList<QImage> images;
115
116 foreach (const QString& path, paths)
117 {
118 images << QImage(path);
119 }
120
121 return images;
122 }
123
prepareForTrain(QString datasetPath,QStringList & images,std::vector<int> & testClusteredIndices)124 int prepareForTrain(QString datasetPath,
125 QStringList& images,
126 std::vector<int>& testClusteredIndices)
127 {
128 if (!datasetPath.endsWith(QLatin1String("/")))
129 {
130 datasetPath.append(QLatin1String("/"));
131 }
132
133 QDir testSet(datasetPath);
134 QStringList subjects = testSet.entryList(QDir::Dirs | QDir::NoDotAndDotDot | QDir::NoSymLinks);
135 int nbOfClusters = subjects.size();
136
137 qCDebug(DIGIKAM_TESTS_LOG) << "Number of clusters to be defined" << nbOfClusters;
138
139 for (int i = 1 ; i <= nbOfClusters ; ++i)
140 {
141 QString subjectPath = QString::fromLatin1("%1%2")
142 .arg(datasetPath)
143 .arg(subjects.takeFirst());
144 QDir subjectDir(subjectPath);
145
146 QStringList files = subjectDir.entryList(QDir::Files);
147 unsigned int nbOfFacesPerClusters = files.size();
148
149 for (unsigned j = 1 ; j <= nbOfFacesPerClusters ; ++j)
150 {
151 QString path = QString::fromLatin1("%1/%2").arg(subjectPath)
152 .arg(files.takeFirst());
153
154 testClusteredIndices.push_back(i - 1);
155 images << path;
156 }
157 }
158
159 qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters (prepareForTrain) " << nbOfClusters;
160
161 return nbOfClusters;
162 }
163
processFaceDetection(const QString & imagePath,FaceDetector detector)164 QList<QRectF> processFaceDetection(const QString& imagePath, FaceDetector detector)
165 {
166 QList<QRectF> detectedFaces = detector.detectFaces(imagePath);
167
168 qCDebug(DIGIKAM_TESTS_LOG) << "(Input CV) Found " << detectedFaces.size() << " faces";
169
170 return detectedFaces;
171 }
172
/**
 * Crop one region out of each image: rects[i] is applied to images[i].
 *
 * Both lists are expected to have the same length. If they differ, only
 * the pairs available in both lists are processed and a warning is logged,
 * instead of indexing past the end of images.
 *
 * @return the cropped images, in input order
 */
QList<QImage> retrieveFaces(const QList<QImage>& images, const QList<QRectF>& rects)
{
    QList<QImage> faces;
    const int nbOfPairs = qMin(images.size(), rects.size());

    if (images.size() != rects.size())
    {
        qCWarning(DIGIKAM_TESTS_LOG) << "Number of images" << images.size()
                                     << "differs from number of rectangles" << rects.size();
    }

    for (int i = 0 ; i < nbOfPairs ; ++i)
    {
        DImg temp(images.at(i));
        faces << temp.copyQImage(rects.at(i));
    }

    return faces;
}
187
createClustersFromClusterIndices(const std::vector<int> & clusteredIndices,QList<std::vector<int>> & clusters)188 void createClustersFromClusterIndices(const std::vector<int>& clusteredIndices,
189 QList<std::vector<int>>& clusters)
190 {
191 int nbOfClusters = 0;
192
193 for (size_t i = 0 ; i < clusteredIndices.size() ; ++i)
194 {
195 int nb = clusteredIndices[i];
196
197 if (nb > nbOfClusters)
198 {
199 nbOfClusters = nb;
200 }
201 }
202
203 nbOfClusters++;
204
205 for (int i = 0 ; i < nbOfClusters ; ++i)
206 {
207 clusters << std::vector<int>();
208 }
209
210 qCDebug(DIGIKAM_TESTS_LOG) << "nbOfClusters " << clusters.size();
211
212 for (int i = 0 ; i < (int)clusteredIndices.size() ; ++i)
213 {
214 clusters[clusteredIndices[i]].push_back(i);
215 }
216 }
217
verifyClusteringResults(const std::vector<int> & clusteredIndices,const std::vector<int> & testClusteredIndices,const QStringList & dataset,QStringList & falsePositiveCases)218 void verifyClusteringResults(const std::vector<int>& clusteredIndices,
219 const std::vector<int>& testClusteredIndices,
220 const QStringList& dataset,
221 QStringList& falsePositiveCases)
222 {
223 QList<std::vector<int>> clusters, testClusters;
224 createClustersFromClusterIndices(clusteredIndices, clusters);
225 createClustersFromClusterIndices(testClusteredIndices, testClusters);
226
227 std::set<int> falsePositivePoints;
228 int testClustersSize = testClusters.size();
229 std::vector<float> visited(testClustersSize, 1.0);
230 std::vector<std::set<int>> lastVisit(testClustersSize, std::set<int>{});
231
232 for (int i = 0 ; i < testClustersSize ; ++i)
233 {
234 std::vector<int> refSet = testClusters.at(i);
235 double minDist = 1.0;
236 int indice = 0;
237
238 for (int j = 0 ; j < clusters.size() ; ++j)
239 {
240 double dist = jaccard_distance(refSet, clusters.at(j));
241
242 if (dist < minDist)
243 {
244 indice = j;
245 minDist = dist;
246 }
247 }
248
249 qCDebug(DIGIKAM_TESTS_LOG) << "testCluster " << i << " with group " << indice;
250
251 std::vector<int> similarSet = clusters.at(indice);
252
253 if (minDist < visited[indice])
254 {
255 visited[indice] = minDist;
256 std::set<int> lastVisitSet = lastVisit[indice];
257 std::set<int> newVisitSet;
258 std::set_symmetric_difference(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(),
259 std::inserter(newVisitSet, newVisitSet.begin()));
260
261 for (int elm: lastVisitSet)
262 {
263 falsePositivePoints.erase(elm);
264 }
265
266 lastVisit[indice] = newVisitSet;
267 falsePositivePoints.insert(newVisitSet.begin(), newVisitSet.end());
268 }
269 else
270 {
271 std::set_intersection(refSet.begin(), refSet.end(), similarSet.begin(), similarSet.end(),
272 std::inserter(falsePositivePoints, falsePositivePoints.begin()));
273 }
274 }
275
276 for (auto indx: falsePositivePoints)
277 {
278 falsePositiveCases << dataset[indx];
279 }
280 }
281
282 // --------------------------------------------------------------------------------------------------
283
main(int argc,char * argv[])284 int main(int argc, char* argv[])
285 {
286 QCoreApplication app(argc, argv);
287 app.setApplicationName(QString::fromLatin1("digikam")); // for DB init.
288
289 // Options for commandline parser
290
291 QCommandLineParser parser;
292 parser.addOption(QCommandLineOption(QLatin1String("db"),
293 QLatin1String("Faces database"),
294 QLatin1String("path to db folder")));
295 parser.addHelpOption();
296 parser.process(app);
297
298 // Parse arguments
299
300 bool optionErrors = false;
301
302 if (parser.optionNames().empty())
303 {
304 qCWarning(DIGIKAM_TESTS_LOG) << "No options!!!";
305 optionErrors = true;
306 }
307 else if (!parser.isSet(QLatin1String("db")))
308 {
309 qCWarning(DIGIKAM_TESTS_LOG) << "Missing database for test!!!";
310 optionErrors = true;
311 }
312
313 if (optionErrors)
314 {
315 parser.showHelp();
316 return 1;
317 }
318
319 QString facedb = parser.value(QLatin1String("db"));
320
321 // Init config for digiKam
322
323 DbEngineParameters prm = DbEngineParameters::parametersFromConfig();
324 CoreDbAccess::setParameters(prm, CoreDbAccess::MainApplication);
325 FacialRecognitionWrapper recognizer;
326
327 //db.setRecognizerThreshold(0.91F); // This is sensitive for the performance of face clustering
328
329 // Construct test set, data set
330
331 QStringList dataset;
332 std::vector<int> testClusteredIndices;
333 int nbOfClusters = prepareForTrain(facedb, dataset, testClusteredIndices);
334
335 // Init FaceDetector used for detecting faces and bounding box
336 // before recognizing
337
338 FaceDetector detector;
339
340 // Evaluation metrics
341
342 unsigned totalClustered = 0;
343 unsigned elapsedClustering = 0;
344
345 QStringList undetectedFaces;
346
347 QList<QImage> detectedFaces;
348 QList<QRectF> bboxes;
349 QList<QImage> rawImages = toImages(dataset);
350
351 foreach (const QImage& image, rawImages)
352 {
353 QString imagePath = dataset.takeFirst();
354 QList<QRectF> detectedBoundingBox = processFaceDetection(imagePath, detector);
355
356 if (detectedBoundingBox.size())
357 {
358 detectedFaces << image;
359 bboxes << detectedBoundingBox.first();
360 dataset << imagePath;
361
362 ++totalClustered;
363 }
364 else
365 {
366 undetectedFaces << imagePath;
367 }
368 }
369
370 std::vector<int> clusteredIndices(dataset.size(), -1);
371 QList<QImage> faces = retrieveFaces(detectedFaces, bboxes);
372
373 QElapsedTimer timer;
374
375 timer.start();
376 /*
377 TODO: port to new API
378 db.clusterFaces(faces, clusteredIndices, dataset, nbOfClusters);
379 */
380 elapsedClustering += timer.elapsed();
381
382 // Verify clustering
383
384 QStringList falsePositiveCases;
385 verifyClusteringResults(clusteredIndices, testClusteredIndices, dataset, falsePositiveCases);
386
387 // Display results
388
389 unsigned nbUndetectedFaces = undetectedFaces.size();
390 qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbUndetectedFaces << " / " << dataset.size() + nbUndetectedFaces
391 << " (" << float(nbUndetectedFaces) / (dataset.size() + nbUndetectedFaces) * 100 << "%)"
392 << " faces cannot be detected";
393
394 foreach (const QString& path, undetectedFaces)
395 {
396 qCDebug(DIGIKAM_TESTS_LOG) << path;
397 }
398
399 unsigned nbOfFalsePositiveCases = falsePositiveCases.size();
400 qCDebug(DIGIKAM_TESTS_LOG) << "\nFalse positive cases";
401 qCDebug(DIGIKAM_TESTS_LOG) << "\n" << nbOfFalsePositiveCases << " / " << dataset.size()
402 << " (" << float(nbOfFalsePositiveCases*100) / dataset.size()<< "%)"
403 << " faces were wrongly clustered";
404
405 foreach (const QString& imagePath, falsePositiveCases)
406 {
407 qCDebug(DIGIKAM_TESTS_LOG) << imagePath;
408 }
409
410 qCDebug(DIGIKAM_TESTS_LOG) << "\n Time for clustering " << elapsedClustering << " ms";
411
412 return 0;
413 }
414