1 //
2 //  ImageDataset.hpp
3 //  MNN
4 //
5 //  Created by MNN on 2019/12/30.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #ifndef ImageDataset_hpp
10 #define ImageDataset_hpp
11 
12 #include <string>
13 #include <utility>
14 #include <vector>
15 #include "train/source/data/Dataset.hpp"
16 #include "train/source/data/Example.hpp"
17 #include <MNN/ImageProcess.hpp>
18 
19 //
// ImageDataset reads images stored on disk and uses them as input data.
// Construct an ImageDataset from 'pathToImages' and a txt file that lists the images.
// Each line of the txt file names one image, followed by a space and its comma-separated labels:
//      image1.jpg label1,label2,...
//      image2.jpg label3,label4,...
//      ...
// ImageDataset then reads the images from:
//      pathToImages/image1.jpg
//      pathToImages/image2.jpg
//      ...
30 //
31 
32 namespace MNN {
33 namespace Train {
34 class MNN_PUBLIC ImageDataset : public Dataset {
35 public:
36     class ImageConfig {
37     public:
create(CV::ImageFormat destFmt=CV::GRAY,int resizeH=0,int resizeW=0,std::vector<float> s={1, 1, 1, 1},std::vector<float> m={0, 0, 0, 0},std::vector<float> cropFract={1 , 1 },const bool centerOrRandom=false)38         static ImageConfig* create(CV::ImageFormat destFmt = CV::GRAY, int resizeH = 0, int resizeW = 0,
39                     std::vector<float> s = {1, 1, 1, 1}, std::vector<float> m = {0, 0, 0, 0},
40                     std::vector<float> cropFract = {1/*height*/, 1/*width*/}, const bool centerOrRandom = false/*false:center*/) {
41             auto config = new ImageConfig;
42             config->destFormat   = destFmt;
43             config->resizeHeight = resizeH;
44             config->resizeWidth  = resizeW;
45             config->scale = s;
46             config->mean = m;
47             MNN_ASSERT(cropFract.size() == 2);
48             MNN_ASSERT(cropFract[0] > 0 && cropFract[0] <= 1);
49             MNN_ASSERT(cropFract[1] > 0 && cropFract[1] <= 1);
50             config->cropFraction = cropFract;
51             config->centerOrRandomCrop = centerOrRandom;
52             return config;
53         }
54         CV::ImageFormat destFormat;
55         int resizeHeight;
56         int resizeWidth;
57         std::vector<float> scale;
58         std::vector<float> mean;
59         std::vector<float> cropFraction;
60         bool centerOrRandomCrop;
61     };
62 
63     static DatasetPtr create(const std::string pathToImages, const std::string pathToImageTxt,
64                           const ImageConfig* cfg, bool readAllToMemory = false);
65     static Express::VARP convertImage(const std::string& imageName, const ImageConfig& config, const MNN::CV::ImageProcess::Config& cvConfig);
66 
67     Example get(size_t index) override;
68 
69     size_t size() override;
70 
71 private:
ImageDataset()72     ImageDataset(){}
73     bool mReadAllToMemory;
74     std::vector<std::pair<std::string, std::vector<int> > > mAllTxtLines;
75     std::vector<std::pair<VARP, VARP> > mDataAndLabels;
76     ImageConfig mConfig;
77     MNN::CV::ImageProcess::Config mProcessConfig;
78 
79     void getAllDataAndLabelsFromTxt(const std::string pathToImages, std::string pathToImageTxt);
80     std::pair<VARP, VARP> getDataAndLabelsFrom(std::pair<std::string, std::vector<int> > dataAndLabels);
81 };
82 } // namespace Train
83 } // namespace MNN
84 
85 #endif // ImageDataset_hpp
86