/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Xin Li yakumolx@gmail.com
 */
#include <chrono>
#include "utils.h"
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

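// Build a simple multi-layer perceptron: one FullyConnected layer per entry in
// `layers`, a ReLU activation after every hidden layer, and SoftmaxOutput on top.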
Symbol mlp(const std::vector<int> &layers) {
  auto x = Symbol::Variable("X");
  auto label = Symbol::Variable("label");

  std::vector<Symbol> weights(layers.size());
  std::vector<Symbol> biases(layers.size());
  std::vector<Symbol> outputs(layers.size());

  for (size_t i = 0; i < layers.size(); ++i) {
    weights[i] = Symbol::Variable("w" + std::to_string(i));
    biases[i] = Symbol::Variable("b" + std::to_string(i));
    Symbol fc = FullyConnected(
      i == 0 ? x : outputs[i-1],  // data
      weights[i],
      biases[i],
      layers[i]);
    // No activation on the last layer: SoftmaxOutput applies the softmax itself.
    outputs[i] = i == layers.size()-1 ? fc : Activation(fc, ActivationActType::kRelu);
  }

  return SoftmaxOutput(outputs.back(), label);
}

int main(int argc, char** argv) {
  const int image_size = 28;
  const std::vector<int> layers{128, 64, 10};
  const int batch_size = 100;
  const int max_epoch = 10;
  const float learning_rate = 0.1;
  const float weight_decay = 1e-2;
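
  // The four MNIST files below are assumed to have been downloaded into
  // ./data/mnist_data beforehand (e.g., with the data-fetching script that
  // ships alongside these examples).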
  std::vector<std::string> data_files = { "./data/mnist_data/train-images-idx3-ubyte",
                                          "./data/mnist_data/train-labels-idx1-ubyte",
                                          "./data/mnist_data/t10k-images-idx3-ubyte",
                                          "./data/mnist_data/t10k-labels-idx1-ubyte"
                                        };

  auto train_iter = MXDataIter("MNISTIter");
  if (!setDataIter(&train_iter, "Train", data_files, batch_size)) {
    return 1;
  }

  auto val_iter = MXDataIter("MNISTIter");
  if (!setDataIter(&val_iter, "Label", data_files, batch_size)) {
    return 1;
  }
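  // setDataIter (from utils.h) points the iterator at the training files for
  // mode "Train" and at the t10k test files for mode "Label", so val_iter
  // walks the 10000-image test set.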

  TRY
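  // TRY/CATCH are helper macros (defined in this example's utils.h) that wrap
  // the body in a try/catch block and log any MXNet error before returning.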
  auto net = mlp(layers);

  Context ctx = Context::gpu();  // Use GPU for training
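  // If no GPU (or no CUDA-enabled MXNet build) is available, Context::cpu()
  // works here as well; the rest of the code is device-agnostic.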

  std::map<std::string, NDArray> args;
  args["X"] = NDArray(Shape(batch_size, image_size*image_size), ctx);
  args["label"] = NDArray(Shape(batch_size), ctx);
  // Let MXNet infer shapes of other parameters such as weights
  net.InferArgsMap(ctx, &args, args);
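  // InferArgsMap fills `args` with correctly-shaped NDArrays for every symbol
  // argument not supplied above (w0, b0, w1, ...), using the shapes of "X"
  // and "label" as the starting point for shape inference.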

  // Initialize all parameters with uniform distribution U(-0.01, 0.01)
  auto initializer = Uniform(0.01);
  for (auto& arg : args) {
    // arg.first is parameter name, and arg.second is the value
    initializer(arg.first, &arg.second);
  }

  // Create sgd optimizer
  Optimizer* opt = OptimizerRegistry::Find("sgd");
  opt->SetParam("rescale_grad", 1.0/batch_size)
     ->SetParam("lr", learning_rate)
     ->SetParam("wd", weight_decay);
  std::unique_ptr<LRScheduler> lr_sch(new FactorScheduler(5000, 0.1));
  opt->SetLRScheduler(std::move(lr_sch));
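  // FactorScheduler(5000, 0.1) multiplies the learning rate by 0.1 after every
  // 5000 optimizer updates; rescale_grad = 1/batch_size averages the gradient
  // over the batch.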

  // Create executor by binding parameters to the model
  auto *exec = net.SimpleBind(ctx, args);
  auto arg_names = net.ListArguments();
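  // exec->arg_arrays and exec->grad_arrays are index-aligned with arg_names,
  // which is what lets the update loop below address parameters by index and
  // skip the input ("X") and label arrays.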

  // Create metrics
  Accuracy train_acc, val_acc;

  // Start training
  for (int iter = 0; iter < max_epoch; ++iter) {
    int samples = 0;
    train_iter.Reset();
    train_acc.Reset();

    auto tic = std::chrono::system_clock::now();
    while (train_iter.Next()) {
      samples += batch_size;
      auto data_batch = train_iter.GetDataBatch();
      // Batches provided by the DataIter live in CPU memory and must be
      // copied to the training device first.
      data_batch.data.CopyTo(&args["X"]);
      data_batch.label.CopyTo(&args["label"]);
      // CopyTo is asynchronous; wait for the copies to complete.
      NDArray::WaitAll();

      // Compute gradients
      exec->Forward(true);
      exec->Backward();

      // Update parameters
      for (size_t i = 0; i < arg_names.size(); ++i) {
        if (arg_names[i] == "X" || arg_names[i] == "label") continue;
        opt->Update(i, exec->arg_arrays[i], exec->grad_arrays[i]);
      }
      // Update metric
      train_acc.Update(data_batch.label, exec->outputs[0]);
    }
    // One epoch of training is finished
    auto toc = std::chrono::system_clock::now();
    float duration = std::chrono::duration_cast<std::chrono::milliseconds>
                     (toc - tic).count() / 1000.0;
    LG << "Epoch[" << iter << "] " << samples/duration
       << " samples/sec " << "Train-Accuracy=" << train_acc.Get();

    val_iter.Reset();
    val_acc.Reset();
    while (val_iter.Next()) {
      auto data_batch = val_iter.GetDataBatch();
      data_batch.data.CopyTo(&args["X"]);
      data_batch.label.CopyTo(&args["label"]);
      NDArray::WaitAll();

      // A forward pass alone is enough: no gradients are needed for evaluation
      exec->Forward(false);
      val_acc.Update(data_batch.label, exec->outputs[0]);
    }
    LG << "Epoch[" << iter << "] Val-Accuracy=" << val_acc.Get();
  }

  delete exec;
  delete opt;
  MXNotifyShutdown();
  CATCH
  return 0;
}