/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file io.h
 * \brief definition of io, such as DataIter
 * \author Zhang Chen
 */
#ifndef MXNET_CPP_IO_H_
#define MXNET_CPP_IO_H_

#include <map>
#include <string>
#include <vector>
#include <sstream>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/ndarray.h"
#include "dmlc/logging.h"

namespace mxnet {
namespace cpp {
/*!
 * \brief Default object for holding a mini-batch of data and related
 * information.
41 */ 42 class DataBatch { 43 public: 44 NDArray data; 45 NDArray label; 46 int pad_num; 47 std::vector<int> index; 48 }; 49 class DataIter { 50 public: 51 virtual void BeforeFirst(void) = 0; 52 virtual bool Next(void) = 0; 53 virtual NDArray GetData(void) = 0; 54 virtual NDArray GetLabel(void) = 0; 55 virtual int GetPadNum(void) = 0; 56 virtual std::vector<int> GetIndex(void) = 0; 57 GetDataBatch()58 DataBatch GetDataBatch() { 59 return DataBatch{GetData(), GetLabel(), GetPadNum(), GetIndex()}; 60 } Reset()61 void Reset() { BeforeFirst(); } 62 63 virtual ~DataIter() = default; 64 }; 65 66 class MXDataIterMap { 67 public: MXDataIterMap()68 inline MXDataIterMap() { 69 mx_uint num_data_iter_creators = 0; 70 DataIterCreator *data_iter_creators = nullptr; 71 int r = MXListDataIters(&num_data_iter_creators, &data_iter_creators); 72 CHECK_EQ(r, 0); 73 for (mx_uint i = 0; i < num_data_iter_creators; i++) { 74 const char *name; 75 const char *description; 76 mx_uint num_args; 77 const char **arg_names; 78 const char **arg_type_infos; 79 const char **arg_descriptions; 80 r = MXDataIterGetIterInfo(data_iter_creators[i], &name, &description, 81 &num_args, &arg_names, &arg_type_infos, 82 &arg_descriptions); 83 CHECK_EQ(r, 0); 84 mxdataiter_creators_[name] = data_iter_creators[i]; 85 } 86 } GetMXDataIterCreator(const std::string & name)87 inline DataIterCreator GetMXDataIterCreator(const std::string &name) { 88 return mxdataiter_creators_[name]; 89 } 90 91 private: 92 std::map<std::string, DataIterCreator> mxdataiter_creators_; 93 }; 94 95 struct MXDataIterBlob { 96 public: MXDataIterBlobMXDataIterBlob97 MXDataIterBlob() : handle_(nullptr) {} MXDataIterBlobMXDataIterBlob98 explicit MXDataIterBlob(DataIterHandle handle) : handle_(handle) {} ~MXDataIterBlobMXDataIterBlob99 ~MXDataIterBlob() { MXDataIterFree(handle_); } 100 DataIterHandle handle_; 101 102 private: 103 MXDataIterBlob &operator=(const MXDataIterBlob &); 104 }; 105 106 class MXDataIter : public DataIter { 107 public: 
108 explicit MXDataIter(const std::string &mxdataiter_type); MXDataIter(const MXDataIter & other)109 MXDataIter(const MXDataIter &other) { 110 creator_ = other.creator_; 111 params_ = other.params_; 112 blob_ptr_ = other.blob_ptr_; 113 } 114 void BeforeFirst(); 115 bool Next(); 116 NDArray GetData(); 117 NDArray GetLabel(); 118 int GetPadNum(); 119 std::vector<int> GetIndex(); 120 MXDataIter CreateDataIter(); 121 /*! 122 * \brief set config parameters 123 * \param name name of the config parameter 124 * \param value value of the config parameter 125 * \return reference of self 126 */ 127 template <typename T> SetParam(const std::string & name,const T & value)128 MXDataIter &SetParam(const std::string &name, const T &value) { 129 std::string value_str; 130 std::stringstream ss; 131 ss << value; 132 ss >> value_str; 133 134 params_[name] = value_str; 135 return *this; 136 } 137 138 private: 139 DataIterCreator creator_; 140 std::map<std::string, std::string> params_; 141 std::shared_ptr<MXDataIterBlob> blob_ptr_; 142 static MXDataIterMap*& mxdataiter_map(); 143 }; 144 } // namespace cpp 145 } // namespace mxnet 146 147 #endif // MXNET_CPP_IO_H_ 148 149