1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
/*!
* \file io.h
* \brief definitions for data I/O, such as DataIter
* \author Zhang Chen
*/
25 #ifndef MXNET_CPP_IO_H_
26 #define MXNET_CPP_IO_H_
27 
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>
#include "mxnet-cpp/base.h"
#include "mxnet-cpp/ndarray.h"
#include "dmlc/logging.h"
35 
36 namespace mxnet {
37 namespace cpp {
38 /*!
39 * \brief Default object for holding a mini-batch of data and related
40 * information.
41 */
42 class DataBatch {
43  public:
44   NDArray data;
45   NDArray label;
46   int pad_num;
47   std::vector<int> index;
48 };
49 class DataIter {
50  public:
51   virtual void BeforeFirst(void) = 0;
52   virtual bool Next(void) = 0;
53   virtual NDArray GetData(void) = 0;
54   virtual NDArray GetLabel(void) = 0;
55   virtual int GetPadNum(void) = 0;
56   virtual std::vector<int> GetIndex(void) = 0;
57 
GetDataBatch()58   DataBatch GetDataBatch() {
59     return DataBatch{GetData(), GetLabel(), GetPadNum(), GetIndex()};
60   }
Reset()61   void Reset() { BeforeFirst(); }
62 
63   virtual ~DataIter() = default;
64 };
65 
66 class MXDataIterMap {
67  public:
MXDataIterMap()68   inline MXDataIterMap() {
69     mx_uint num_data_iter_creators = 0;
70     DataIterCreator *data_iter_creators = nullptr;
71     int r = MXListDataIters(&num_data_iter_creators, &data_iter_creators);
72     CHECK_EQ(r, 0);
73     for (mx_uint i = 0; i < num_data_iter_creators; i++) {
74       const char *name;
75       const char *description;
76       mx_uint num_args;
77       const char **arg_names;
78       const char **arg_type_infos;
79       const char **arg_descriptions;
80       r = MXDataIterGetIterInfo(data_iter_creators[i], &name, &description,
81                                 &num_args, &arg_names, &arg_type_infos,
82                                 &arg_descriptions);
83       CHECK_EQ(r, 0);
84       mxdataiter_creators_[name] = data_iter_creators[i];
85     }
86   }
GetMXDataIterCreator(const std::string & name)87   inline DataIterCreator GetMXDataIterCreator(const std::string &name) {
88     return mxdataiter_creators_[name];
89   }
90 
91  private:
92   std::map<std::string, DataIterCreator> mxdataiter_creators_;
93 };
94 
95 struct MXDataIterBlob {
96  public:
MXDataIterBlobMXDataIterBlob97   MXDataIterBlob() : handle_(nullptr) {}
MXDataIterBlobMXDataIterBlob98   explicit MXDataIterBlob(DataIterHandle handle) : handle_(handle) {}
~MXDataIterBlobMXDataIterBlob99   ~MXDataIterBlob() { MXDataIterFree(handle_); }
100   DataIterHandle handle_;
101 
102  private:
103   MXDataIterBlob &operator=(const MXDataIterBlob &);
104 };
105 
106 class MXDataIter : public DataIter {
107  public:
108   explicit MXDataIter(const std::string &mxdataiter_type);
MXDataIter(const MXDataIter & other)109   MXDataIter(const MXDataIter &other) {
110     creator_ = other.creator_;
111     params_ = other.params_;
112     blob_ptr_ = other.blob_ptr_;
113   }
114   void BeforeFirst();
115   bool Next();
116   NDArray GetData();
117   NDArray GetLabel();
118   int GetPadNum();
119   std::vector<int> GetIndex();
120   MXDataIter CreateDataIter();
121   /*!
122    * \brief set config parameters
123    * \param name name of the config parameter
124    * \param value value of the config parameter
125    * \return reference of self
126    */
127   template <typename T>
SetParam(const std::string & name,const T & value)128   MXDataIter &SetParam(const std::string &name, const T &value) {
129     std::string value_str;
130     std::stringstream ss;
131     ss << value;
132     ss >> value_str;
133 
134     params_[name] = value_str;
135     return *this;
136   }
137 
138  private:
139   DataIterCreator creator_;
140   std::map<std::string, std::string> params_;
141   std::shared_ptr<MXDataIterBlob> blob_ptr_;
142   static MXDataIterMap*& mxdataiter_map();
143 };
144 }  // namespace cpp
145 }  // namespace mxnet
146 
147 #endif  // MXNET_CPP_IO_H_
148 
149