1 //
2 // MobilenetV2Utils.cpp
3 // MNN
4 //
5 // Created by MNN on 2020/01/08.
6 // Copyright © 2018, Alibaba Group Holding Limited
7 //
8
9 #include "MobilenetV2Utils.hpp"
10 #include <MNN/expr/Executor.hpp>
11 #include <MNN/expr/Optimizer.hpp>
12 #include <cmath>
13 #include <iostream>
14 #include <vector>
15 #include "DataLoader.hpp"
16 #include "DemoUnit.hpp"
17 #include "NN.hpp"
18 #include "SGD.hpp"
19 #define MNN_OPEN_TIME_TRACE
20 #include <MNN/AutoTime.hpp>
21 #include "ADAM.hpp"
22 #include "LearningRateScheduler.hpp"
23 #include "Loss.hpp"
24 #include "RandomGenerator.hpp"
25 #include "Transformer.hpp"
26 #include "ImageDataset.hpp"
27 #include "module/PipelineModule.hpp"
28 #include "cpp/ConvertToFullQuant.hpp"
29
30 using namespace MNN;
31 using namespace MNN::Express;
32 using namespace MNN::Train;
33
train(std::shared_ptr<Module> model,const int numClasses,const int addToLabel,std::string trainImagesFolder,std::string trainImagesTxt,std::string testImagesFolder,std::string testImagesTxt,const int quantBits)34 void MobilenetV2Utils::train(std::shared_ptr<Module> model, const int numClasses, const int addToLabel,
35 std::string trainImagesFolder, std::string trainImagesTxt,
36 std::string testImagesFolder, std::string testImagesTxt, const int quantBits) {
37 auto exe = Executor::getGlobalExecutor();
38 BackendConfig config;
39 exe->setGlobalExecutorConfig(MNN_FORWARD_USER_1, config, 2);
40 std::shared_ptr<SGD> solver(new SGD(model));
41 solver->setMomentum(0.9f);
42 // solver->setMomentum2(0.99f);
43 solver->setWeightDecay(0.00004f);
44
45 auto converImagesToFormat = CV::RGB;
46 int resizeHeight = 224;
47 int resizeWidth = 224;
48 std::vector<float> means = {127.5, 127.5, 127.5};
49 std::vector<float> scales = {1/127.5, 1/127.5, 1/127.5};
50 std::vector<float> cropFraction = {0.875, 0.875}; // center crop fraction for height and width
51 bool centerOrRandomCrop = false; // true for random crop
52 std::shared_ptr<ImageDataset::ImageConfig> datasetConfig(ImageDataset::ImageConfig::create(converImagesToFormat, resizeHeight, resizeWidth, scales, means,cropFraction, centerOrRandomCrop));
53 bool readAllImagesToMemory = false;
54 auto trainDataset = ImageDataset::create(trainImagesFolder, trainImagesTxt, datasetConfig.get(), readAllImagesToMemory);
55 auto testDataset = ImageDataset::create(testImagesFolder, testImagesTxt, datasetConfig.get(), readAllImagesToMemory);
56
57 const int trainBatchSize = 32;
58 const int trainNumWorkers = 4;
59 const int testBatchSize = 10;
60 const int testNumWorkers = 0;
61
62 auto trainDataLoader = trainDataset.createLoader(trainBatchSize, true, true, trainNumWorkers);
63 auto testDataLoader = testDataset.createLoader(testBatchSize, true, false, testNumWorkers);
64
65 const int trainIterations = trainDataLoader->iterNumber();
66 const int testIterations = testDataLoader->iterNumber();
67
68 // const int usedSize = 1000;
69 // const int testIterations = usedSize / testBatchSize;
70
71 for (int epoch = 0; epoch < 50; ++epoch) {
72 model->clearCache();
73 exe->gc(Executor::FULL);
74 exe->resetProfile();
75 {
76 AUTOTIME;
77 trainDataLoader->reset();
78 model->setIsTraining(true);
79 for (int i = 0; i < trainIterations; i++) {
80 AUTOTIME;
81 auto trainData = trainDataLoader->next();
82 auto example = trainData[0];
83
84 // Compute One-Hot
85 auto newTarget = _OneHot(_Cast<int32_t>(_Squeeze(example.second[0] + _Scalar<int32_t>(addToLabel), {})),
86 _Scalar<int>(numClasses), _Scalar<float>(1.0f),
87 _Scalar<float>(0.0f));
88
89 auto predict = model->forward(_Convert(example.first[0], NC4HW4));
90 auto loss = _CrossEntropy(predict, newTarget);
91 // float rate = LrScheduler::inv(0.0001, solver->currentStep(), 0.0001, 0.75);
92 float rate = 1e-5;
93 solver->setLearningRate(rate);
94 if (solver->currentStep() % 10 == 0) {
95 std::cout << "train iteration: " << solver->currentStep();
96 std::cout << " loss: " << loss->readMap<float>()[0];
97 std::cout << " lr: " << rate << std::endl;
98 }
99 solver->step(loss);
100 }
101 }
102
103 int correct = 0;
104 int sampleCount = 0;
105 testDataLoader->reset();
106 model->setIsTraining(false);
107 exe->gc(Executor::PART);
108
109 AUTOTIME;
110 for (int i = 0; i < testIterations; i++) {
111 auto data = testDataLoader->next();
112 auto example = data[0];
113 auto predict = model->forward(_Convert(example.first[0], NC4HW4));
114 predict = _ArgMax(predict, 1); // (N, numClasses) --> (N)
115 auto label = _Squeeze(example.second[0]) + _Scalar<int32_t>(addToLabel);
116 sampleCount += label->getInfo()->size;
117 auto accu = _Cast<int32_t>(_Equal(predict, label).sum({}));
118 correct += accu->readMap<int32_t>()[0];
119
120 if ((i + 1) % 10 == 0) {
121 std::cout << "test iteration: " << (i + 1) << " ";
122 std::cout << "acc: " << correct << "/" << sampleCount << " = " << float(correct) / sampleCount * 100 << "%";
123 std::cout << std::endl;
124 }
125 }
126 auto accu = (float)correct / testDataLoader->size();
127 // auto accu = (float)correct / usedSize;
128 std::cout << "epoch: " << epoch << " accuracy: " << accu << std::endl;
129
130 {
131 auto forwardInput = _Input({1, 3, resizeHeight, resizeWidth}, NC4HW4);
132 forwardInput->setName("data");
133 auto predict = model->forward(forwardInput);
134 Transformer::turnModelToInfer()->onExecute({predict});
135 predict->setName("prob");
136 std::string fileName = "temp.mobilenetv2.mnn";
137 Variable::save({predict}, fileName.c_str());
138 ConvertToFullQuant::convert(fileName);
139 }
140
141 exe->dumpProfile();
142 }
143 }
144