1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
* \file metric.h
* \brief evaluation metrics (accuracy, log-loss, MAE, MSE, RMSE, PSNR)
23 * \author Zhang Chen
24 */
25 
26 #ifndef MXNET_CPP_METRIC_H_
27 #define MXNET_CPP_METRIC_H_
28 
29 #include <cmath>
30 #include <string>
31 #include <vector>
32 #include <algorithm>
33 #include "mxnet-cpp/ndarray.h"
34 #include "dmlc/logging.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 
39 class EvalMetric {
40  public:
41   explicit EvalMetric(const std::string& name, int num = 0)
name(name)42       : name(name), num(num) {}
43   virtual void Update(NDArray labels, NDArray preds) = 0;
Reset()44   void Reset() {
45     num_inst = 0;
46     sum_metric = 0.0f;
47   }
Get()48   float Get() { return sum_metric / num_inst; }
49   void GetNameValue();
50 
51  protected:
52   std::string name;
53   int num;
54   float sum_metric = 0.0f;
55   int num_inst = 0;
56 
57   static void CheckLabelShapes(NDArray labels, NDArray preds,
58                                bool strict = false) {
59     if (strict) {
60       CHECK_EQ(Shape(labels.GetShape()), Shape(preds.GetShape()));
61     } else {
62       CHECK_EQ(labels.Size(), preds.Size());
63     }
64   }
65 };
66 
67 class Accuracy : public EvalMetric {
68  public:
Accuracy()69   Accuracy() : EvalMetric("accuracy") {}
70 
Update(NDArray labels,NDArray preds)71   void Update(NDArray labels, NDArray preds) override {
72     CHECK_EQ(labels.GetShape().size(), 1);
73     mx_uint len = labels.GetShape()[0];
74     std::vector<mx_float> pred_data(len);
75     std::vector<mx_float> label_data(len);
76     preds.ArgmaxChannel().SyncCopyToCPU(&pred_data, len);
77     labels.SyncCopyToCPU(&label_data, len);
78     for (mx_uint i = 0; i < len; ++i) {
79       sum_metric += (pred_data[i] == label_data[i]) ? 1 : 0;
80       num_inst += 1;
81     }
82   }
83 };
84 
85 class LogLoss : public EvalMetric {
86  public:
LogLoss()87   LogLoss() : EvalMetric("logloss") {}
88 
Update(NDArray labels,NDArray preds)89   void Update(NDArray labels, NDArray preds) override {
90     static const float epsilon = 1e-15;
91     mx_uint len = labels.GetShape()[0];
92     mx_uint m = preds.GetShape()[1];
93     std::vector<mx_float> pred_data(len * m);
94     std::vector<mx_float> label_data(len);
95     preds.SyncCopyToCPU(&pred_data, pred_data.size());
96     labels.SyncCopyToCPU(&label_data, len);
97     for (mx_uint i = 0; i < len; ++i) {
98       sum_metric +=
99           -std::log(std::max(pred_data[i * m + label_data[i]], epsilon));
100       num_inst += 1;
101     }
102   }
103 };
104 
105 class MAE : public EvalMetric {
106  public:
MAE()107   MAE() : EvalMetric("mae") {}
108 
Update(NDArray labels,NDArray preds)109   void Update(NDArray labels, NDArray preds) override {
110     CheckLabelShapes(labels, preds);
111 
112     std::vector<mx_float> pred_data;
113     preds.SyncCopyToCPU(&pred_data);
114     std::vector<mx_float> label_data;
115     labels.SyncCopyToCPU(&label_data);
116 
117     size_t len = preds.Size();
118     mx_float sum = 0;
119     for (size_t i = 0; i < len; ++i) {
120       sum += std::abs(pred_data[i] - label_data[i]);
121     }
122     sum_metric += sum / len;
123     ++num_inst;
124   }
125 };
126 
127 class MSE : public EvalMetric {
128  public:
MSE()129   MSE() : EvalMetric("mse") {}
130 
Update(NDArray labels,NDArray preds)131   void Update(NDArray labels, NDArray preds) override {
132     CheckLabelShapes(labels, preds);
133 
134     std::vector<mx_float> pred_data;
135     preds.SyncCopyToCPU(&pred_data);
136     std::vector<mx_float> label_data;
137     labels.SyncCopyToCPU(&label_data);
138 
139     size_t len = preds.Size();
140     mx_float sum = 0;
141     for (size_t i = 0; i < len; ++i) {
142       mx_float diff = pred_data[i] - label_data[i];
143       sum += diff * diff;
144     }
145     sum_metric += sum / len;
146     ++num_inst;
147   }
148 };
149 
150 class RMSE : public EvalMetric {
151  public:
RMSE()152   RMSE() : EvalMetric("rmse") {}
153 
Update(NDArray labels,NDArray preds)154   void Update(NDArray labels, NDArray preds) override {
155     CheckLabelShapes(labels, preds);
156 
157     std::vector<mx_float> pred_data;
158     preds.SyncCopyToCPU(&pred_data);
159     std::vector<mx_float> label_data;
160     labels.SyncCopyToCPU(&label_data);
161 
162     size_t len = preds.Size();
163     mx_float sum = 0;
164     for (size_t i = 0; i < len; ++i) {
165       mx_float diff = pred_data[i] - label_data[i];
166       sum += diff * diff;
167     }
168     sum_metric += std::sqrt(sum / len);
169     ++num_inst;
170   }
171 };
172 
173 class PSNR : public EvalMetric {
174  public:
PSNR()175   PSNR() : EvalMetric("psnr") {
176   }
177 
Update(NDArray labels,NDArray preds)178   void Update(NDArray labels, NDArray preds) override {
179     CheckLabelShapes(labels, preds);
180 
181     std::vector<mx_float> pred_data;
182     preds.SyncCopyToCPU(&pred_data);
183     std::vector<mx_float> label_data;
184     labels.SyncCopyToCPU(&label_data);
185 
186     size_t len = preds.Size();
187     mx_float sum = 0;
188     for (size_t i = 0; i < len; ++i) {
189       mx_float diff = pred_data[i] - label_data[i];
190       sum += diff * diff;
191     }
192     mx_float mse = sum / len;
193     if (mse > 0) {
194       sum_metric += 10 * std::log(255.0f / mse) / log10_;
195     } else {
196       sum_metric += 99.0f;
197     }
198     ++num_inst;
199   }
200 
201  private:
202   mx_float log10_ = std::log(10.0f);
203 };
204 
205 }  // namespace cpp
206 }  // namespace mxnet
207 
208 #endif  // MXNET_CPP_METRIC_H_
209 
210