1 #include <algorithm>
2 #include <cctype>
3 #include <cmath>
4 #include <iostream>
5 #include <limits>
6 #include <numeric>
7 #include <stdexcept>
8 #include <string>
9 #include <vector>
10 
11 #include <opencv2/gapi.hpp>
12 #include <opencv2/gapi/core.hpp>
13 #include <opencv2/gapi/cpu/gcpukernel.hpp>
14 #include <opencv2/gapi/infer.hpp>
15 #include <opencv2/gapi/infer/ie.hpp>
16 #include <opencv2/gapi/streaming/cap.hpp>
17 
18 #include <opencv2/highgui.hpp>
19 #include <opencv2/core/utility.hpp>
20 
// Banner shown at the top of the --help output.
const std::string about{
    "This is an OpenCV-based version of OMZ Text Detection example"};

// cv::CommandLineParser specification; every entry has the form "{ name | default | help }".
const std::string keys{
    "{ h help |                           | Print this help message }"
    "{ input  |                           | Path to the input video file }"
    "{ tdm    | text-detection-0004.xml   | Path to OpenVINO text detection model (.xml), versions 0003 and 0004 work }"
    "{ tdd    | CPU                       | Target device for the text detector (e.g. CPU, GPU, VPU, ...) }"
    "{ trm    | text-recognition-0012.xml | Path to OpenVINO text recognition model (.xml) }"
    "{ trd    | CPU                       | Target device for the text recognition (e.g. CPU, GPU, VPU, ...) }"
    "{ bw     | 0                         | CTC beam search decoder bandwidth, if 0, a CTC greedy decoder is used}"
    "{ sset   | 0123456789abcdefghijklmnopqrstuvwxyz | Symbol set to use with text recognition decoder. Shouldn't contain symbol #. }"
    "{ thr    | 0.2                       | Text recognition confidence threshold}"};
34 
35 namespace {
weights_path(const std::string & model_path)36 std::string weights_path(const std::string &model_path) {
37     const auto EXT_LEN = 4u;
38     const auto sz = model_path.size();
39     CV_Assert(sz > EXT_LEN);
40 
41     const auto ext = model_path.substr(sz - EXT_LEN);
42     CV_Assert(cv::toLowerCase(ext) == ".xml");
43     return model_path.substr(0u, sz - EXT_LEN) + ".bin";
44 }
45 
46 //////////////////////////////////////////////////////////////////////
47 // Taken from OMZ samples as-is
// Finds the most likely class in [begin, end) and its softmax probability.
//   argmax - [out] index of the largest score (the first one on ties)
//   prob   - [out] softmax probability of that score, i.e. 1 / sum_j exp(x_j - max)
// Expects a non-empty range. Throws std::logic_error if the normaliser is (near) zero.
template<typename Iter>
void softmax_and_choose(Iter begin, Iter end, int *argmax, float *prob) {
    const Iter winner = std::max_element(begin, end);
    *argmax = static_cast<int>(std::distance(begin, winner));

    // Shift by the maximum so the winner contributes exp(0) == 1 to the normaliser.
    const float top = *winner;
    double denominator = 0;
    for (Iter cur = begin; cur != end; ++cur) {
        denominator += std::exp((*cur) - top);
    }
    if (std::fabs(denominator) < std::numeric_limits<double>::epsilon()) {
        throw std::logic_error("sum can't be equal to zero");
    }
    *prob = 1.0f / static_cast<float>(denominator);
}
62 
// Numerically stable softmax over [begin, end).
// Returns one probability per input score; the result sums to 1 (up to rounding).
// An empty range yields an empty vector.
// NOTE: unlike the original OMZ helper, the maximum score is subtracted before exp();
// the result is mathematically identical but no longer overflows to inf/NaN for large logits.
template<typename Iter>
std::vector<float> softmax(Iter begin, Iter end) {
    std::vector<float> prob(static_cast<std::size_t>(end - begin), 0.f);
    if (prob.empty()) {
        return prob;
    }
    const float max_val = *std::max_element(begin, end);
    std::transform(begin, end, prob.begin(),
                   [max_val](float x) { return std::exp(x - max_val); });
    // The maximum contributes exp(0) == 1, so the sum is >= 1 and the division is safe.
    const float sum = std::accumulate(prob.begin(), prob.end(), 0.0f);
    for (auto &p : prob) {
        p /= sum;
    }
    return prob;
}
72 
// A single hypothesis tracked by the CTC beam search decoder.
// Kept as a plain aggregate: callers build it as BeamElement{sentence, prob_blank, prob_not_blank},
// so the member order below is part of its interface.
struct BeamElement {
    // Class indices emitted so far by this hypothesis (never contains the blank class).
    std::vector<int> sentence;

    // Probability mass of this hypothesis' paths whose last CTC frame is the blank.
    float prob_blank;

    // Probability mass of this hypothesis' paths whose last CTC frame is NOT the blank.
    float prob_not_blank;

    // Overall probability of the hypothesis (both path groups together).
    float prob() const {
        return prob_not_blank + prob_blank;
    }
};
86 
CTCGreedyDecoder(const float * data,const std::size_t sz,const std::string & alphabet,const char pad_symbol,double * conf)87 std::string CTCGreedyDecoder(const float *data,
88                              const std::size_t sz,
89                              const std::string &alphabet,
90                              const char pad_symbol,
91                              double *conf) {
92     std::string res = "";
93     bool prev_pad = false;
94     *conf = 1;
95 
96     const auto num_classes = alphabet.length();
97     for (auto it = data; it != (data+sz); it += num_classes) {
98         int argmax = 0;
99         float prob = 0.f;
100 
101         softmax_and_choose(it, it + num_classes, &argmax, &prob);
102         (*conf) *= prob;
103 
104         auto symbol = alphabet[argmax];
105         if (symbol != pad_symbol) {
106             if (res.empty() || prev_pad || (!res.empty() && symbol != res.back())) {
107                 prev_pad = false;
108                 res += symbol;
109             }
110         } else {
111             prev_pad = true;
112         }
113     }
114     return res;
115 }
116 
CTCBeamSearchDecoder(const float * data,const std::size_t sz,const std::string & alphabet,double * conf,int bandwidth)117 std::string CTCBeamSearchDecoder(const float *data,
118                                  const std::size_t sz,
119                                  const std::string &alphabet,
120                                  double *conf,
121                                  int bandwidth) {
122     const auto num_classes = alphabet.length();
123 
124     std::vector<BeamElement> curr;
125     std::vector<BeamElement> last;
126 
127     last.push_back(BeamElement{std::vector<int>(), 1.f, 0.f});
128 
129     for (auto it = data; it != (data+sz); it += num_classes) {
130         curr.clear();
131 
132         std::vector<float> prob = softmax(it, it + num_classes);
133 
134         for(const auto& candidate: last) {
135             float prob_not_blank = 0.f;
136             const std::vector<int>& candidate_sentence = candidate.sentence;
137             if (!candidate_sentence.empty()) {
138                 int n = candidate_sentence.back();
139                 prob_not_blank = candidate.prob_not_blank * prob[n];
140             }
141             float prob_blank = candidate.prob() * prob[num_classes - 1];
142 
143             auto check_res = std::find_if(curr.begin(),
144                                           curr.end(),
145                                           [&candidate_sentence](const BeamElement& n) {
146                                               return n.sentence == candidate_sentence;
147                                           });
148             if (check_res == std::end(curr)) {
149                 curr.push_back(BeamElement{candidate.sentence, prob_blank, prob_not_blank});
150             } else {
151                 check_res->prob_not_blank  += prob_not_blank;
152                 if (check_res->prob_blank != 0.f) {
153                     throw std::logic_error("Probability that the last char in CTC-sequence "
154                                            "is the special blank char must be zero here");
155                 }
156                 check_res->prob_blank = prob_blank;
157             }
158 
159             for (int i = 0; i < static_cast<int>(num_classes) - 1; i++) {
160                 auto extend = candidate_sentence;
161                 extend.push_back(i);
162 
163                 if (candidate_sentence.size() > 0 && candidate.sentence.back() == i) {
164                     prob_not_blank = prob[i] * candidate.prob_blank;
165                 } else {
166                     prob_not_blank = prob[i] * candidate.prob();
167                 }
168 
169                 auto check_res2 = std::find_if(curr.begin(),
170                                               curr.end(),
171                                               [&extend](const BeamElement &n) {
172                                                   return n.sentence == extend;
173                                               });
174                 if (check_res2 == std::end(curr)) {
175                     curr.push_back(BeamElement{extend, 0.f, prob_not_blank});
176                 } else {
177                     check_res2->prob_not_blank += prob_not_blank;
178                 }
179             }
180         }
181 
182         sort(curr.begin(), curr.end(), [](const BeamElement &a, const BeamElement &b) -> bool {
183             return a.prob() > b.prob();
184         });
185 
186         last.clear();
187         int num_to_copy = std::min(bandwidth, static_cast<int>(curr.size()));
188         for (int b = 0; b < num_to_copy; b++) {
189             last.push_back(curr[b]);
190         }
191     }
192 
193     *conf = last[0].prob();
194     std::string res="";
195     for (const auto& idx: last[0].sentence) {
196         res += alphabet[idx];
197     }
198 
199     return res;
200 }
201 
202 //////////////////////////////////////////////////////////////////////
203 } // anonymous namespace
204 
205 namespace custom {
206 namespace {
207 
208 //////////////////////////////////////////////////////////////////////
// Define networks for this sample.
// Text detector: one image in, two tensors out. main() binds them to the
// "model/link_logits_/add" and "model/segm_logits/add" output layers and feeds
// them to PostProcess as (link, segm).
using GMat2 = std::tuple<cv::GMat, cv::GMat>;
G_API_NET(TextDetection,
          <GMat2(cv::GMat)>,
          "sample.custom.text_detect");

// Text recognizer: one cropped text image in, one tensor out. main() runs it per crop
// (via infer2) and decodes each output on the host with a CTC decoder.
G_API_NET(TextRecognition,
          <cv::GMat(cv::GMat)>,
          "sample.custom.text_recogn");
218 
// Define custom operations
using GSize = cv::GOpaque<cv::Size>;
using GRRects = cv::GArray<cv::RotatedRect>;
// PostProcess(link, segm, frame_size, link_threshold, segm_threshold) -> text boxes.
// Converts the two text detector outputs into rotated rectangles in frame coordinates;
// implemented on CPU by OCVPostProcess.
G_API_OP(PostProcess,
        <GRRects(cv::GMat,cv::GMat,GSize,float,float)>,
        "sample.custom.text.post_proc") {
    // The number of boxes is data dependent, so the output is described by an
    // empty array descriptor.
    static cv::GArrayDesc outMeta(const cv::GMatDesc &,
                                  const cv::GMatDesc &,
                                  const cv::GOpaqueDesc &,
                                  float,
                                  float) {
        return cv::empty_array_desc();
    }
};
233 
using GMats = cv::GArray<cv::GMat>;
// CropLabels(frame, boxes, out_size) -> one recognizer input blob per box.
// Each rotated box is warped to out_size and converted to a float grayscale blob;
// implemented on CPU by OCVCropLabels.
G_API_OP(CropLabels,
         <GMats(cv::GMat,GRRects,GSize)>,
         "sample.custom.text.crop") {
    // One output per detection, so the array length is only known at run time.
    static cv::GArrayDesc outMeta(const cv::GMatDesc &,
                                  const cv::GArrayDesc &,
                                  const cv::GOpaqueDesc &) {
        return cv::empty_array_desc();
    }
};
244 
245 //////////////////////////////////////////////////////////////////////
246 // Implement custom operations
// OpenCV (CPU) kernel for custom::PostProcess: turns the text detector's link and
// segmentation outputs into rotated text boxes in the coordinates of the input frame.
GAPI_OCV_KERNEL(OCVPostProcess, PostProcess) {
    // link           - link scores; indexed below as a 4-D blob (N, C, H, W) whose C holds
    //                  one (negative, positive) pair per neighbour (assumed layout - TODO confirm)
    // segm           - pixel text/non-text scores, same (N, C, H, W) indexing with one pair per
    //                  pixel (assumed - TODO confirm against the model spec)
    // img_size       - size of the source frame; the label mask is resized to it
    // link_threshold - minimum positive-link probability to join two neighbouring pixels
    // segm_threshold - minimum text probability for a pixel to be kept
    // out            - [out] detected boxes that pass the area / height filters
    static void run(const cv::Mat &link,
                    const cv::Mat &segm,
                    const cv::Size &img_size,
                    const float link_threshold,
                    const float segm_threshold,
                    std::vector<cv::RotatedRect> &out) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        // Box filters, measured on the mask after resizing to img_size (frame pixels).
        const int kMinArea = 300;
        const int kMinHeight = 10;

        // Link scores: move channels last, softmax each (negative, positive) pair and
        // keep the positive probability -> shape (N, H, W, C/2).
        const float *link_data_pointer = link.ptr<float>();
        std::vector<float> link_data(link_data_pointer, link_data_pointer + link.total());
        link_data = transpose4d(link_data, dimsToShape(link.size), {0, 2, 3, 1});
        softmax(link_data);
        link_data = sliceAndGetSecondChannel(link_data);
        std::vector<int> new_link_data_shape = {
            link.size[0],
            link.size[2],
            link.size[3],
            link.size[1]/2,
        };

        // Segmentation scores: same treatment -> per-pixel text probability.
        const float *cls_data_pointer = segm.ptr<float>();
        std::vector<float> cls_data(cls_data_pointer, cls_data_pointer + segm.total());
        cls_data = transpose4d(cls_data, dimsToShape(segm.size), {0, 2, 3, 1});
        softmax(cls_data);
        cls_data = sliceAndGetSecondChannel(cls_data);
        std::vector<int> new_cls_data_shape = {
            segm.size[0],
            segm.size[2],
            segm.size[3],
            segm.size[1]/2,
        };

        // Group linked text pixels into labelled components, then fit one box per label.
        out = maskToBoxes(decodeImageByJoin(cls_data, new_cls_data_shape,
                                            link_data, new_link_data_shape,
                                            segm_threshold, link_threshold),
                          static_cast<float>(kMinArea),
                          static_cast<float>(kMinHeight),
                          img_size);
    }

    // Copies the dimensions of a cv::MatSize into a vector (one entry per dimension).
    static std::vector<std::size_t> dimsToShape(const cv::MatSize &sz) {
        const int n_dims = sz.dims();
        std::vector<std::size_t> result;
        result.reserve(n_dims);

        // cv::MatSize is not iterable...
        for (int i = 0; i < n_dims; i++) {
            result.emplace_back(static_cast<std::size_t>(sz[i]));
        }
        return result;
    }

    // In-place two-class softmax over consecutive pairs (rdata[i], rdata[i+1]).
    // Assumes rdata.size() is even; the max is subtracted for numerical stability.
    static void softmax(std::vector<float> &rdata) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        const size_t last_dim = 2;
        for (size_t i = 0 ; i < rdata.size(); i+=last_dim) {
            float m = std::max(rdata[i], rdata[i+1]);
            rdata[i] = std::exp(rdata[i] - m);
            rdata[i + 1] = std::exp(rdata[i + 1] - m);
            float s = rdata[i] + rdata[i + 1];
            rdata[i] /= s;
            rdata[i + 1] /= s;
        }
    }

    // Permutes a dense row-major 4-D tensor: output dimension k is input dimension axes[k]
    // (e.g. {0, 2, 3, 1} moves the second dimension last).
    // Only 4-D shapes are supported (shape[0..3] / axes[0..3] are used unconditionally).
    // Throws std::runtime_error if shape/axes sizes differ or an axis is out of range.
    static std::vector<float> transpose4d(const std::vector<float> &data,
                                          const std::vector<size_t> &shape,
                                          const std::vector<size_t> &axes) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        if (shape.size() != axes.size())
            throw std::runtime_error("Shape and axes must have the same dimension.");

        for (size_t a : axes) {
            if (a >= shape.size())
                throw std::runtime_error("Axis must be less than dimension of shape.");
        }
        size_t total_size = shape[0]*shape[1]*shape[2]*shape[3];
        // Row-major strides of the OUTPUT tensor (whose dims are shape[axes[k]]).
        std::vector<size_t> steps {
            shape[axes[1]]*shape[axes[2]]*shape[axes[3]],
            shape[axes[2]]*shape[axes[3]],
            shape[axes[3]],
            1
         };

        // Walk the input in its own row-major order and scatter into the output.
        size_t source_data_idx = 0;
        std::vector<float> new_data(total_size, 0);
        std::vector<size_t> ids(shape.size());
        for (ids[0] = 0; ids[0] < shape[0]; ids[0]++) {
            for (ids[1] = 0; ids[1] < shape[1]; ids[1]++) {
                for (ids[2] = 0; ids[2] < shape[2]; ids[2]++) {
                    for (ids[3]= 0; ids[3] < shape[3]; ids[3]++) {
                        size_t new_data_idx = ids[axes[0]]*steps[0] + ids[axes[1]]*steps[1] +
                            ids[axes[2]]*steps[2] + ids[axes[3]]*steps[3];
                        new_data[new_data_idx] = data[source_data_idx++];
                    }
                }
            }
        }
        return new_data;
    }

    // Keeps every odd-indexed element. After a channel-last transpose and the pairwise
    // softmax above, this is the probability of the second (positive) class of each pair.
    static std::vector<float> sliceAndGetSecondChannel(const std::vector<float> &data) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        std::vector<float> new_data(data.size() / 2, 0);
        for (size_t i = 0; i < data.size() / 2; i++) {
            new_data[i] = data[2 * i + 1];
        }
        return new_data;
    }

    // Union-find "union": attaches the root of p1's group to the root of p2's group.
    // Both points must already be keys of group_mask (findRoot uses .at()).
    static void join(const int p1,
                     const int p2,
                     std::unordered_map<int, int> &group_mask) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        const int root1 = findRoot(p1, group_mask);
        const int root2 = findRoot(p2, group_mask);
        if (root1 != root2) {
            group_mask[root1] = root2;
        }
    }

    // Builds a label mask (h x w, CV_32S) from per-pixel text probabilities and
    // per-neighbour link probabilities:
    //   1. pixels with cls >= cls_conf_threshold become text pixels (own group);
    //   2. a text pixel is joined with each of its 8 neighbours that is also a text pixel
    //      and whose link probability is >= link_conf_threshold;
    //   3. connected groups are numbered 1..K by get_all (0 = background).
    // Only the first h*w entries of cls_data are used (i.e. batch 0, assumed N == 1).
    static cv::Mat decodeImageByJoin(const std::vector<float> &cls_data,
                                     const std::vector<int>   &cls_data_shape,
                                     const std::vector<float> &link_data,
                                     const std::vector<int>   &link_data_shape,
                                     float cls_conf_threshold,
                                     float link_conf_threshold) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        const int h = cls_data_shape[1];
        const int w = cls_data_shape[2];

        std::vector<uchar> pixel_mask(h * w, 0);
        std::unordered_map<int, int> group_mask;
        std::vector<cv::Point> points;
        for (int i = 0; i < static_cast<int>(pixel_mask.size()); i++) {
            pixel_mask[i] = cls_data[i] >= cls_conf_threshold;
            if (pixel_mask[i]) {
                points.emplace_back(i % w, i / w);
                group_mask[i] = -1; // -1 marks a group root
            }
        }
        std::vector<uchar> link_mask(link_data.size(), 0);
        for (size_t i = 0; i < link_mask.size(); i++) {
            link_mask[i] = link_data[i] >= link_conf_threshold;
        }
        size_t neighbours = size_t(link_data_shape[3]);
        for (const auto &point : points) {
            // Neighbour index runs row-major over the 3x3 window, skipping the centre;
            // it advances even for out-of-image positions to stay aligned with link_mask.
            size_t neighbour = 0;
            for (int ny = point.y - 1; ny <= point.y + 1; ny++) {
                for (int nx = point.x - 1; nx <= point.x + 1; nx++) {
                    if (nx == point.x && ny == point.y)
                        continue;
                    if (nx >= 0 && nx < w && ny >= 0 && ny < h) {
                        uchar pixel_value = pixel_mask[size_t(ny) * size_t(w) + size_t(nx)];
                        uchar link_value = link_mask[(size_t(point.y) * size_t(w) + size_t(point.x))
                                                     *neighbours + neighbour];
                        if (pixel_value && link_value) {
                            join(point.x + point.y * w, nx + ny * w, group_mask);
                        }
                    }
                    neighbour++;
                }
            }
        }
        return get_all(points, w, h, group_mask);
    }

    // Rasterises the union-find groups into an h x w CV_32S mask: every text pixel gets
    // the label of its group (consecutive numbers starting at 1), background stays 0.
    static cv::Mat get_all(const std::vector<cv::Point> &points,
                           const int w,
                           const int h,
                           std::unordered_map<int, int> &group_mask) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        std::unordered_map<int, int> root_map;
        cv::Mat mask(h, w, CV_32S, cv::Scalar(0));
        for (const auto &point : points) {
            int point_root = findRoot(point.x + point.y * w, group_mask);
            if (root_map.find(point_root) == root_map.end()) {
                root_map.emplace(point_root, static_cast<int>(root_map.size() + 1));
            }
            mask.at<int>(point.x + point.y * w) = root_map[point_root];
        }
        return mask;
    }

    // Union-find "find": follows parent links until a -1 entry (the root).
    // Only the queried point is re-pointed directly at the root (partial path compression).
    // Throws std::out_of_range (from unordered_map::at) if a visited key is missing.
    static int findRoot(const int point,
                        std::unordered_map<int, int> &group_mask) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        int root = point;
        bool update_parent = false;
        while (group_mask.at(root) != -1) {
            root = group_mask.at(root);
            update_parent = true;
        }
        if (update_parent) {
            group_mask[point] = root;
        }
        return root;
    }

    // Fits a rotated rectangle to each label of `mask` after resizing it (nearest-neighbour)
    // to image_size. Only the first contour found for a label is used. Boxes whose shorter
    // side is below min_height or whose area is below min_area are dropped.
    static std::vector<cv::RotatedRect> maskToBoxes(const cv::Mat &mask,
                                                    const float min_area,
                                                    const float min_height,
                                                    const cv::Size &image_size) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        std::vector<cv::RotatedRect> bboxes;
        double min_val = 0.;
        double max_val = 0.;
        cv::minMaxLoc(mask, &min_val, &max_val);
        // Labels are consecutive, so the maximum value is the number of labels.
        int max_bbox_idx = static_cast<int>(max_val);
        cv::Mat resized_mask;
        cv::resize(mask, resized_mask, image_size, 0, 0, cv::INTER_NEAREST);

        for (int i = 1; i <= max_bbox_idx; i++) {
            cv::Mat bbox_mask = resized_mask == i;
            std::vector<std::vector<cv::Point>> contours;

            cv::findContours(bbox_mask, contours, cv::RETR_CCOMP, cv::CHAIN_APPROX_SIMPLE);
            if (contours.empty())
                continue;
            cv::RotatedRect r = cv::minAreaRect(contours[0]);
            if (std::min(r.size.width, r.size.height) < min_height)
                continue;
            if (r.size.area() < min_area)
                continue;
            bboxes.emplace_back(r);
        }
        return bboxes;
    }
}; // GAPI_OCV_KERNEL(PostProcess)
479 
// OpenCV (CPU) kernel for custom::CropLabels: produces one recognizer input per
// detected rotated box.
GAPI_OCV_KERNEL(OCVCropLabels, CropLabels) {
    // image      - source frame; converted with COLOR_BGR2GRAY, so a 3-channel BGR image is expected
    // detections - rotated text boxes in frame coordinates
    // outSize    - recognizer input size (width x height)
    // out        - [out] one CV_32F blob of shape 1x1xHxW per detection, in detection order
    static void run(const cv::Mat &image,
                    const std::vector<cv::RotatedRect> &detections,
                    const cv::Size &outSize,
                    std::vector<cv::Mat> &out) {
        out.clear();
        out.reserve(detections.size());
        // Scratch buffers reused for every detection; each output gets its own fresh
        // `blob` below, so outputs never share memory with these buffers.
        cv::Mat crop(outSize, CV_8UC3, cv::Scalar(0));
        cv::Mat gray(outSize, CV_8UC1, cv::Scalar(0));
        std::vector<int> blob_shape = {1,1,outSize.height,outSize.width};

        for (auto &&rr : detections) {
            std::vector<cv::Point2f> points(4);
            rr.points(points.data());

            // Three consecutive corners, starting from the "top-left" one, are mapped to
            // the top-left, top-right and bottom-right corners of the output crop.
            const auto top_left_point_idx = topLeftPointIdx(points);
            cv::Point2f point0 = points[static_cast<size_t>(top_left_point_idx)];
            cv::Point2f point1 = points[(top_left_point_idx + 1) % 4];
            cv::Point2f point2 = points[(top_left_point_idx + 2) % 4];

            std::vector<cv::Point2f> from{point0, point1, point2};
            std::vector<cv::Point2f> to{
                cv::Point2f(0.0f, 0.0f),
                cv::Point2f(static_cast<float>(outSize.width-1), 0.0f),
                cv::Point2f(static_cast<float>(outSize.width-1),
                            static_cast<float>(outSize.height-1))
            };
            cv::Mat M = cv::getAffineTransform(from, to);
            cv::warpAffine(image, crop, M, outSize);
            cv::cvtColor(crop, gray, cv::COLOR_BGR2GRAY);

            cv::Mat blob;
            gray.convertTo(blob, CV_32F);
            out.push_back(blob.reshape(1, blob_shape)); // pass as 1,1,H,W instead of H,W
        }
    }

    // Picks the "top-left" corner of a quadrilateral: among the two points with the
    // smallest x, returns the index of the one with the smaller y (the leftmost one
    // wins when the y values are equal). Returns -1 for an empty input.
    static int topLeftPointIdx(const std::vector<cv::Point2f> &points) {
        // NOTE: Taken from the OMZ text detection sample almost as-is
        cv::Point2f most_left(std::numeric_limits<float>::max(),
                              std::numeric_limits<float>::max());
        cv::Point2f almost_most_left(std::numeric_limits<float>::max(),
                                     std::numeric_limits<float>::max());
        int most_left_idx = -1;
        int almost_most_left_idx = -1;

        for (size_t i = 0; i < points.size() ; i++) {
            if (most_left.x > points[i].x) {
                // The previous leftmost point (if any) becomes the runner-up.
                if (most_left.x < std::numeric_limits<float>::max()) {
                    almost_most_left = most_left;
                    almost_most_left_idx = most_left_idx;
                }
                most_left = points[i];
                most_left_idx = static_cast<int>(i);
            }
            if (almost_most_left.x > points[i].x && points[i] != most_left) {
                almost_most_left = points[i];
                almost_most_left_idx = static_cast<int>(i);
            }
        }

        if (almost_most_left.y < most_left.y) {
            most_left = almost_most_left;
            most_left_idx = almost_most_left_idx;
        }
        return most_left_idx;
    }

}; // GAPI_OCV_KERNEL(CropLabels)
549 
550 } // anonymous namespace
551 } // namespace custom
552 
553 namespace vis {
554 namespace {
555 
drawRotatedRect(cv::Mat & m,const cv::RotatedRect & rc)556 void drawRotatedRect(cv::Mat &m, const cv::RotatedRect &rc) {
557     std::vector<cv::Point2f> tmp_points(5);
558     rc.points(tmp_points.data());
559     tmp_points[4] = tmp_points[0];
560     auto prev = tmp_points.begin(), it = prev+1;
561     for (; it != tmp_points.end(); ++it) {
562         cv::line(m, *prev, *it, cv::Scalar(50, 205, 50), 2);
563         prev = it;
564     }
565 }
566 
drawText(cv::Mat & m,const cv::RotatedRect & rc,const std::string & str)567 void drawText(cv::Mat &m, const cv::RotatedRect &rc, const std::string &str) {
568     const int    fface   = cv::FONT_HERSHEY_SIMPLEX;
569     const double scale   = 0.7;
570     const int    thick   = 1;
571           int    base    = 0;
572     const auto text_size = cv::getTextSize(str, fface, scale, thick, &base);
573 
574     std::vector<cv::Point2f> tmp_points(4);
575     rc.points(tmp_points.data());
576     const auto tl_point_idx = custom::OCVCropLabels::topLeftPointIdx(tmp_points);
577     cv::Point text_pos = tmp_points[tl_point_idx];
578     text_pos.x = std::max(0, text_pos.x);
579     text_pos.y = std::max(text_size.height, text_pos.y);
580 
581     cv::rectangle(m,
582                   text_pos + cv::Point{0, base},
583                   text_pos + cv::Point{text_size.width, -text_size.height},
584                   CV_RGB(50, 205, 50),
585                   cv::FILLED);
586     const auto white = CV_RGB(255, 255, 255);
587     cv::putText(m, str, text_pos, fface, scale, white, thick, 8);
588 }
589 
590 } // anonymous namespace
591 } // namespace vis
592 
main(int argc,char * argv[])593 int main(int argc, char *argv[])
594 {
595     cv::CommandLineParser cmd(argc, argv, keys);
596     cmd.about(about);
597     if (cmd.has("help")) {
598         cmd.printMessage();
599         return 0;
600     }
601     const auto input_file_name = cmd.get<std::string>("input");
602     const auto tdet_model_path = cmd.get<std::string>("tdm");
603     const auto trec_model_path = cmd.get<std::string>("trm");
604     const auto tdet_target_dev = cmd.get<std::string>("tdd");
605     const auto trec_target_dev = cmd.get<std::string>("trd");
606     const auto ctc_beam_dec_bw = cmd.get<int>("bw");
607     const auto dec_conf_thresh = cmd.get<double>("thr");
608 
609     const auto pad_symbol      = '#';
610     const auto symbol_set      = cmd.get<std::string>("sset") + pad_symbol;
611 
612     cv::GMat in;
613     cv::GOpaque<cv::Size> in_rec_sz;
614     cv::GMat link, segm;
615     std::tie(link, segm) = cv::gapi::infer<custom::TextDetection>(in);
616     cv::GOpaque<cv::Size> size = cv::gapi::streaming::size(in);
617     cv::GArray<cv::RotatedRect> rrs = custom::PostProcess::on(link, segm, size, 0.8f, 0.8f);
618     cv::GArray<cv::GMat> labels = custom::CropLabels::on(in, rrs, in_rec_sz);
619     cv::GArray<cv::GMat> text = cv::gapi::infer2<custom::TextRecognition>(in, labels);
620 
621     cv::GComputation graph(cv::GIn(in, in_rec_sz),
622                            cv::GOut(cv::gapi::copy(in), rrs, text));
623 
624     // Text detection network
625     auto tdet_net = cv::gapi::ie::Params<custom::TextDetection> {
626         tdet_model_path,                // path to topology IR
627         weights_path(tdet_model_path),  // path to weights
628         tdet_target_dev,                // device specifier
629     }.cfgOutputLayers({"model/link_logits_/add", "model/segm_logits/add"});
630 
631     auto trec_net = cv::gapi::ie::Params<custom::TextRecognition> {
632         trec_model_path,                // path to topology IR
633         weights_path(trec_model_path),  // path to weights
634         trec_target_dev,                // device specifier
635     };
636     auto networks = cv::gapi::networks(tdet_net, trec_net);
637 
638     auto kernels = cv::gapi::kernels< custom::OCVPostProcess
639                                     , custom::OCVCropLabels
640                                     >();
641     auto pipeline = graph.compileStreaming(cv::compile_args(kernels, networks));
642 
643     std::cout << "Reading " << input_file_name << std::endl;
644 
645     // Input stream
646     auto in_src = cv::gapi::wip::make_src<cv::gapi::wip::GCaptureSource>(input_file_name);
647 
648     // Text recognition input size (also an input parameter to the graph)
649     auto in_rsz = cv::Size{ 120, 32 };
650 
651     // Set the pipeline source & start the pipeline
652     pipeline.setSource(cv::gin(in_src, in_rsz));
653     pipeline.start();
654 
655     // Declare the output data & run the processing loop
656     cv::TickMeter tm;
657     cv::Mat image;
658     std::vector<cv::RotatedRect> out_rcs;
659     std::vector<cv::Mat> out_text;
660 
661     tm.start();
662     int frames = 0;
663     while (pipeline.pull(cv::gout(image, out_rcs, out_text))) {
664         frames++;
665 
666         CV_Assert(out_rcs.size() == out_text.size());
667         const auto num_labels = out_rcs.size();
668 
669         std::vector<cv::Point2f> tmp_points(4);
670         for (std::size_t l = 0; l < num_labels; l++) {
671             // Decode the recognized text in the rectangle
672             const auto &blob = out_text[l];
673             const float *data = blob.ptr<float>();
674             const auto sz = blob.total();
675             double conf = 1.0;
676             const std::string res = ctc_beam_dec_bw == 0
677                 ? CTCGreedyDecoder(data, sz, symbol_set, pad_symbol, &conf)
678                 : CTCBeamSearchDecoder(data, sz, symbol_set, &conf, ctc_beam_dec_bw);
679 
680             // Draw a bounding box for this rotated rectangle
681             const auto &rc = out_rcs[l];
682             vis::drawRotatedRect(image, rc);
683 
684             // Draw text, if decoded
685             if (conf >= dec_conf_thresh) {
686                 vis::drawText(image, rc, res);
687             }
688         }
689         tm.stop();
690         cv::imshow("Out", image);
691         cv::waitKey(1);
692         tm.start();
693     }
694     tm.stop();
695     std::cout << "Processed " << frames << " frames"
696               << " (" << frames / tm.getTimeSec() << " FPS)" << std::endl;
697     return 0;
698 }
699