1 //
2 //  MobilenetV2Utils.cpp
3 //  MNN
4 //
5 //  Created by MNN on 2020/01/08.
6 //  Copyright © 2018, Alibaba Group Holding Limited
7 //
8 
9 #include "MobilenetV2Utils.hpp"
10 #include <MNN/expr/Executor.hpp>
11 #include <MNN/expr/Optimizer.hpp>
12 #include <cmath>
13 #include <iostream>
14 #include <vector>
15 #include "DataLoader.hpp"
16 #include "DemoUnit.hpp"
17 #include "NN.hpp"
18 #include "SGD.hpp"
19 #define MNN_OPEN_TIME_TRACE
20 #include <MNN/AutoTime.hpp>
21 #include "ADAM.hpp"
22 #include "LearningRateScheduler.hpp"
23 #include "Loss.hpp"
24 #include "RandomGenerator.hpp"
25 #include "Transformer.hpp"
26 #include "ImageDataset.hpp"
27 #include "module/PipelineModule.hpp"
28 #include "cpp/ConvertToFullQuant.hpp"
29 
30 using namespace MNN;
31 using namespace MNN::Express;
32 using namespace MNN::Train;
33 
train(std::shared_ptr<Module> model,const int numClasses,const int addToLabel,std::string trainImagesFolder,std::string trainImagesTxt,std::string testImagesFolder,std::string testImagesTxt,const int quantBits)34 void MobilenetV2Utils::train(std::shared_ptr<Module> model, const int numClasses, const int addToLabel,
35                                 std::string trainImagesFolder, std::string trainImagesTxt,
36                                 std::string testImagesFolder, std::string testImagesTxt, const int quantBits) {
37     auto exe = Executor::getGlobalExecutor();
38     BackendConfig config;
39     exe->setGlobalExecutorConfig(MNN_FORWARD_USER_1, config, 2);
40     std::shared_ptr<SGD> solver(new SGD(model));
41     solver->setMomentum(0.9f);
42     // solver->setMomentum2(0.99f);
43     solver->setWeightDecay(0.00004f);
44 
45     auto converImagesToFormat  = CV::RGB;
46     int resizeHeight           = 224;
47     int resizeWidth            = 224;
48     std::vector<float> means = {127.5, 127.5, 127.5};
49     std::vector<float> scales = {1/127.5, 1/127.5, 1/127.5};
50     std::vector<float> cropFraction = {0.875, 0.875}; // center crop fraction for height and width
51     bool centerOrRandomCrop = false; // true for random crop
52     std::shared_ptr<ImageDataset::ImageConfig> datasetConfig(ImageDataset::ImageConfig::create(converImagesToFormat, resizeHeight, resizeWidth, scales, means,cropFraction, centerOrRandomCrop));
53     bool readAllImagesToMemory = false;
54     auto trainDataset = ImageDataset::create(trainImagesFolder, trainImagesTxt, datasetConfig.get(), readAllImagesToMemory);
55     auto testDataset = ImageDataset::create(testImagesFolder, testImagesTxt, datasetConfig.get(), readAllImagesToMemory);
56 
57     const int trainBatchSize = 32;
58     const int trainNumWorkers = 4;
59     const int testBatchSize = 10;
60     const int testNumWorkers = 0;
61 
62     auto trainDataLoader = trainDataset.createLoader(trainBatchSize, true, true, trainNumWorkers);
63     auto testDataLoader = testDataset.createLoader(testBatchSize, true, false, testNumWorkers);
64 
65     const int trainIterations = trainDataLoader->iterNumber();
66     const int testIterations = testDataLoader->iterNumber();
67 
68     // const int usedSize = 1000;
69     // const int testIterations = usedSize / testBatchSize;
70 
71     for (int epoch = 0; epoch < 50; ++epoch) {
72         model->clearCache();
73         exe->gc(Executor::FULL);
74         exe->resetProfile();
75         {
76             AUTOTIME;
77             trainDataLoader->reset();
78             model->setIsTraining(true);
79             for (int i = 0; i < trainIterations; i++) {
80                 AUTOTIME;
81                 auto trainData  = trainDataLoader->next();
82                 auto example    = trainData[0];
83 
84                 // Compute One-Hot
85                 auto newTarget = _OneHot(_Cast<int32_t>(_Squeeze(example.second[0] + _Scalar<int32_t>(addToLabel), {})),
86                                   _Scalar<int>(numClasses), _Scalar<float>(1.0f),
87                                          _Scalar<float>(0.0f));
88 
89                 auto predict = model->forward(_Convert(example.first[0], NC4HW4));
90                 auto loss    = _CrossEntropy(predict, newTarget);
91                 // float rate   = LrScheduler::inv(0.0001, solver->currentStep(), 0.0001, 0.75);
92                 float rate = 1e-5;
93                 solver->setLearningRate(rate);
94                 if (solver->currentStep() % 10 == 0) {
95                     std::cout << "train iteration: " << solver->currentStep();
96                     std::cout << " loss: " << loss->readMap<float>()[0];
97                     std::cout << " lr: " << rate << std::endl;
98                 }
99                 solver->step(loss);
100             }
101         }
102 
103         int correct = 0;
104         int sampleCount = 0;
105         testDataLoader->reset();
106         model->setIsTraining(false);
107         exe->gc(Executor::PART);
108 
109         AUTOTIME;
110         for (int i = 0; i < testIterations; i++) {
111             auto data       = testDataLoader->next();
112             auto example    = data[0];
113             auto predict    = model->forward(_Convert(example.first[0], NC4HW4));
114             predict         = _ArgMax(predict, 1); // (N, numClasses) --> (N)
115             auto label = _Squeeze(example.second[0]) + _Scalar<int32_t>(addToLabel);
116             sampleCount += label->getInfo()->size;
117             auto accu       = _Cast<int32_t>(_Equal(predict, label).sum({}));
118             correct += accu->readMap<int32_t>()[0];
119 
120             if ((i + 1) % 10 == 0) {
121                 std::cout << "test iteration: " << (i + 1) << " ";
122                 std::cout << "acc: " << correct << "/" << sampleCount << " = " << float(correct) / sampleCount * 100 << "%";
123                 std::cout << std::endl;
124             }
125         }
126         auto accu = (float)correct / testDataLoader->size();
127         // auto accu = (float)correct / usedSize;
128         std::cout << "epoch: " << epoch << "  accuracy: " << accu << std::endl;
129 
130         {
131             auto forwardInput = _Input({1, 3, resizeHeight, resizeWidth}, NC4HW4);
132             forwardInput->setName("data");
133             auto predict = model->forward(forwardInput);
134             Transformer::turnModelToInfer()->onExecute({predict});
135             predict->setName("prob");
136             std::string fileName = "temp.mobilenetv2.mnn";
137             Variable::save({predict}, fileName.c_str());
138             ConvertToFullQuant::convert(fileName);
139         }
140 
141         exe->dumpProfile();
142     }
143 }
144