1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 20 /*! 21 * \file base.h 22 * \brief metrics defined 23 * \author Zhang Chen 24 */ 25 26 #ifndef MXNET_CPP_METRIC_H_ 27 #define MXNET_CPP_METRIC_H_ 28 29 #include <cmath> 30 #include <string> 31 #include <vector> 32 #include <algorithm> 33 #include "mxnet-cpp/ndarray.h" 34 #include "dmlc/logging.h" 35 36 namespace mxnet { 37 namespace cpp { 38 39 class EvalMetric { 40 public: 41 explicit EvalMetric(const std::string& name, int num = 0) name(name)42 : name(name), num(num) {} 43 virtual void Update(NDArray labels, NDArray preds) = 0; Reset()44 void Reset() { 45 num_inst = 0; 46 sum_metric = 0.0f; 47 } Get()48 float Get() { return sum_metric / num_inst; } 49 void GetNameValue(); 50 51 protected: 52 std::string name; 53 int num; 54 float sum_metric = 0.0f; 55 int num_inst = 0; 56 57 static void CheckLabelShapes(NDArray labels, NDArray preds, 58 bool strict = false) { 59 if (strict) { 60 CHECK_EQ(Shape(labels.GetShape()), Shape(preds.GetShape())); 61 } else { 62 CHECK_EQ(labels.Size(), preds.Size()); 63 } 64 } 65 }; 66 67 class Accuracy : public EvalMetric { 68 public: Accuracy()69 Accuracy() : EvalMetric("accuracy") {} 70 Update(NDArray labels,NDArray preds)71 void Update(NDArray labels, NDArray preds) override { 72 CHECK_EQ(labels.GetShape().size(), 1); 73 mx_uint len = labels.GetShape()[0]; 74 std::vector<mx_float> pred_data(len); 75 std::vector<mx_float> label_data(len); 76 preds.ArgmaxChannel().SyncCopyToCPU(&pred_data, len); 77 labels.SyncCopyToCPU(&label_data, len); 78 for (mx_uint i = 0; i < len; ++i) { 79 sum_metric += (pred_data[i] == label_data[i]) ? 1 : 0; 80 num_inst += 1; 81 } 82 } 83 }; 84 85 class LogLoss : public EvalMetric { 86 public: LogLoss()87 LogLoss() : EvalMetric("logloss") {} 88 Update(NDArray labels,NDArray preds)89 void Update(NDArray labels, NDArray preds) override { 90 static const float epsilon = 1e-15; 91 mx_uint len = labels.GetShape()[0]; 92 mx_uint m = preds.GetShape()[1]; 93 std::vector<mx_float> pred_data(len * m); 94 std::vector<mx_float> label_data(len); 95 preds.SyncCopyToCPU(&pred_data, pred_data.size()); 96 labels.SyncCopyToCPU(&label_data, len); 97 for (mx_uint i = 0; i < len; ++i) { 98 sum_metric += 99 -std::log(std::max(pred_data[i * m + label_data[i]], epsilon)); 100 num_inst += 1; 101 } 102 } 103 }; 104 105 class MAE : public EvalMetric { 106 public: MAE()107 MAE() : EvalMetric("mae") {} 108 Update(NDArray labels,NDArray preds)109 void Update(NDArray labels, NDArray preds) override { 110 CheckLabelShapes(labels, preds); 111 112 std::vector<mx_float> pred_data; 113 preds.SyncCopyToCPU(&pred_data); 114 std::vector<mx_float> label_data; 115 labels.SyncCopyToCPU(&label_data); 116 117 size_t len = preds.Size(); 118 mx_float sum = 0; 119 for (size_t i = 0; i < len; ++i) { 120 sum += std::abs(pred_data[i] - label_data[i]); 121 } 122 sum_metric += sum / len; 123 ++num_inst; 124 } 125 }; 126 127 class MSE : public EvalMetric { 128 public: MSE()129 MSE() : EvalMetric("mse") {} 130 Update(NDArray labels,NDArray preds)131 void Update(NDArray labels, NDArray preds) override { 132 CheckLabelShapes(labels, preds); 133 134 std::vector<mx_float> pred_data; 135 preds.SyncCopyToCPU(&pred_data); 136 std::vector<mx_float> label_data; 137 labels.SyncCopyToCPU(&label_data); 138 139 size_t len = preds.Size(); 140 mx_float sum = 0; 141 for (size_t i = 0; i < len; ++i) { 142 mx_float diff = pred_data[i] - label_data[i]; 143 sum += diff * diff; 144 } 145 sum_metric += sum / len; 146 ++num_inst; 147 } 148 }; 149 150 class RMSE : public EvalMetric { 151 public: RMSE()152 RMSE() : EvalMetric("rmse") {} 153 Update(NDArray labels,NDArray preds)154 void Update(NDArray labels, NDArray preds) override { 155 CheckLabelShapes(labels, preds); 156 157 std::vector<mx_float> pred_data; 158 preds.SyncCopyToCPU(&pred_data); 159 std::vector<mx_float> label_data; 160 labels.SyncCopyToCPU(&label_data); 161 162 size_t len = preds.Size(); 163 mx_float sum = 0; 164 for (size_t i = 0; i < len; ++i) { 165 mx_float diff = pred_data[i] - label_data[i]; 166 sum += diff * diff; 167 } 168 sum_metric += std::sqrt(sum / len); 169 ++num_inst; 170 } 171 }; 172 173 class PSNR : public EvalMetric { 174 public: PSNR()175 PSNR() : EvalMetric("psnr") { 176 } 177 Update(NDArray labels,NDArray preds)178 void Update(NDArray labels, NDArray preds) override { 179 CheckLabelShapes(labels, preds); 180 181 std::vector<mx_float> pred_data; 182 preds.SyncCopyToCPU(&pred_data); 183 std::vector<mx_float> label_data; 184 labels.SyncCopyToCPU(&label_data); 185 186 size_t len = preds.Size(); 187 mx_float sum = 0; 188 for (size_t i = 0; i < len; ++i) { 189 mx_float diff = pred_data[i] - label_data[i]; 190 sum += diff * diff; 191 } 192 mx_float mse = sum / len; 193 if (mse > 0) { 194 sum_metric += 10 * std::log(255.0f / mse) / log10_; 195 } else { 196 sum_metric += 99.0f; 197 } 198 ++num_inst; 199 } 200 201 private: 202 mx_float log10_ = std::log(10.0f); 203 }; 204 205 } // namespace cpp 206 } // namespace mxnet 207 208 #endif // MXNET_CPP_METRIC_H_ 209 210