1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file iter_csv.cc
22  * \brief define a CSV Reader to read in arrays
23  */
24 #include <mxnet/io.h>
25 #include <dmlc/base.h>
26 #include <dmlc/logging.h>
27 #include <dmlc/parameter.h>
28 #include <dmlc/data.h>
29 #include "./iter_prefetcher.h"
30 #include "./iter_batchloader.h"
31 
32 namespace mxnet {
33 namespace io {
34 // CSV parameters
35 struct CSVIterParam : public dmlc::Parameter<CSVIterParam> {
36   /*! \brief path to data csv file */
37   std::string data_csv;
38   /*! \brief data shape */
39   mxnet::TShape data_shape;
40   /*! \brief path to label csv file */
41   std::string label_csv;
42   /*! \brief label shape */
43   mxnet::TShape label_shape;
44   // declare parameters
DMLC_DECLARE_PARAMETERmxnet::io::CSVIterParam45   DMLC_DECLARE_PARAMETER(CSVIterParam) {
46     DMLC_DECLARE_FIELD(data_csv)
47         .describe("The input CSV file or a directory path.");
48     DMLC_DECLARE_FIELD(data_shape)
49         .describe("The shape of one example.");
50     DMLC_DECLARE_FIELD(label_csv).set_default("NULL")
51         .describe("The input CSV file or a directory path. "
52                   "If NULL, all labels will be returned as 0.");
53     index_t shape1[] = {1};
54     DMLC_DECLARE_FIELD(label_shape).set_default(mxnet::TShape(shape1, shape1 + 1))
55         .describe("The shape of one label.");
56   }
57 };
58 
59 class CSVIterBase: public IIterator<DataInst> {
60  public:
CSVIterBase()61   CSVIterBase() {
62     out_.data.resize(2);
63   }
~CSVIterBase()64   virtual ~CSVIterBase() {}
65 
66   // initialize iterator loads data in
67   virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
68   /*! \brief reset the iterator */
69   virtual void BeforeFirst(void) = 0;
70   /*! \brief move to next item */
71   virtual bool Next(void) = 0;
72   /*! \brief get current data */
Value(void) const73   virtual const DataInst &Value(void) const {
74     return out_;
75   }
76 
77  protected:
78   CSVIterParam param_;
79 
80   DataInst out_;
81 
82   // internal instance counter
83   unsigned inst_counter_{0};
84   // at end
85   bool end_{false};
86 
87   // label parser
88   size_t label_ptr_{0}, label_size_{0};
89   size_t data_ptr_{0}, data_size_{0};
90 };
91 
92 template <typename DType>
93 class CSVIterTyped: public CSVIterBase {
94  public:
~CSVIterTyped()95   virtual ~CSVIterTyped() {}
96   // intialize iterator loads data in
Init(const std::vector<std::pair<std::string,std::string>> & kwargs)97   virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
98     param_.InitAllowUnknown(kwargs);
99     data_parser_.reset(dmlc::Parser<uint32_t, DType>::Create(param_.data_csv.c_str(), 0, 1, "csv"));
100     if (param_.label_csv != "NULL") {
101       label_parser_.reset(
102         dmlc::Parser<uint32_t, DType>::Create(param_.label_csv.c_str(), 0, 1, "csv"));
103     } else {
104       dummy_label.set_pad(false);
105       dummy_label.Resize(mshadow::Shape1(1));
106       dummy_label = 0;
107     }
108   }
109 
BeforeFirst()110   virtual void BeforeFirst() {
111     data_parser_->BeforeFirst();
112     if (label_parser_.get() != nullptr) {
113       label_parser_->BeforeFirst();
114     }
115     data_ptr_ = label_ptr_ = 0;
116     data_size_ = label_size_ = 0;
117     inst_counter_ = 0;
118     end_ = false;
119   }
120 
Next()121   virtual bool Next() {
122     if (end_) return false;
123     while (data_ptr_ >= data_size_) {
124       if (!data_parser_->Next()) {
125         end_ = true; return false;
126       }
127       data_ptr_ = 0;
128       data_size_ = data_parser_->Value().size;
129     }
130     out_.index = inst_counter_++;
131     CHECK_LT(data_ptr_, data_size_);
132     out_.data[0] = AsTBlob(data_parser_->Value()[data_ptr_++], param_.data_shape);
133 
134     if (label_parser_.get() != nullptr) {
135       while (label_ptr_ >= label_size_) {
136         CHECK(label_parser_->Next())
137             << "Data CSV's row is smaller than the number of rows in label_csv";
138         label_ptr_ = 0;
139         label_size_ = label_parser_->Value().size;
140       }
141       CHECK_LT(label_ptr_, label_size_);
142       out_.data[1] = AsTBlob(label_parser_->Value()[label_ptr_++], param_.label_shape);
143     } else {
144       out_.data[1] = dummy_label;
145     }
146     return true;
147   }
148 
149  private:
AsTBlob(const dmlc::Row<uint32_t,DType> & row,const mxnet::TShape & shape)150   inline TBlob AsTBlob(const dmlc::Row<uint32_t, DType>& row, const mxnet::TShape& shape) {
151     CHECK_EQ(row.length, shape.Size())
152         << "The data size in CSV do not match size of shape: "
153         << "specified shape=" << shape << ", the csv row-length=" << row.length;
154     const DType* ptr = row.value;
155     return TBlob((DType*)ptr, shape, cpu::kDevMask, 0);  // NOLINT(*)
156   }
157   // dummy label
158   mshadow::TensorContainer<cpu, 1, DType> dummy_label;
159   std::unique_ptr<dmlc::Parser<uint32_t, DType> > label_parser_;
160   std::unique_ptr<dmlc::Parser<uint32_t, DType> > data_parser_;
161 };
162 
163 class CSVIter: public IIterator<DataInst> {
164  public:
CSVIter()165   CSVIter() {}
~CSVIter()166   virtual ~CSVIter() {}
167 
168   // intialize iterator loads data in
Init(const std::vector<std::pair<std::string,std::string>> & kwargs)169   virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
170     param_.InitAllowUnknown(kwargs);
171     bool dtype_has_value = false;
172     int target_dtype = -1;
173     for (const auto& arg : kwargs) {
174       if (arg.first == "dtype") {
175         dtype_has_value = true;
176         if (arg.second == "int32") {
177           target_dtype = mshadow::kInt32;
178         } else if (arg.second == "int64") {
179           target_dtype = mshadow::kInt64;
180         } else if (arg.second == "float32") {
181           target_dtype = mshadow::kFloat32;
182         } else {
183           CHECK(false) << arg.second << " is not supported for CSVIter";
184         }
185       }
186     }
187     if (dtype_has_value && target_dtype == mshadow::kInt32) {
188       iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int32_t>()));
189     } else if (dtype_has_value && target_dtype == mshadow::kInt64) {
190       iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int64_t>()));
191     } else if (!dtype_has_value || target_dtype == mshadow::kFloat32) {
192       iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<float>()));
193     }
194     iterator_->Init(kwargs);
195   }
196 
BeforeFirst()197   virtual void BeforeFirst() {
198     iterator_->BeforeFirst();
199   }
200 
Next()201   virtual bool Next() {
202     return iterator_->Next();
203   }
204 
Value(void) const205   virtual const DataInst &Value(void) const {
206     return iterator_->Value();
207   }
208 
209  private:
210   CSVIterParam param_;
211   std::unique_ptr<CSVIterBase> iterator_;
212 };
213 
214 
// Registration counterpart of the DMLC_DECLARE_PARAMETER block inside CSVIterParam.
DMLC_REGISTER_PARAMETER(CSVIterParam);
216 
217 MXNET_REGISTER_IO_ITER(CSVIter)
218 .describe(R"code(Returns the CSV file iterator.
219 
220 In this function, the `data_shape` parameter is used to set the shape of each line of the input data.
221 If a row in an input file is `1,2,3,4,5,6`` and `data_shape` is (3,2), that row
222 will be reshaped, yielding the array [[1,2],[3,4],[5,6]] of shape (3,2).
223 
224 By default, the `CSVIter` has `round_batch` parameter set to ``True``. So, if `batch_size`
225 is 3 and there are 4 total rows in CSV file, 2 more examples
226 are consumed at the first round. If `reset` function is called after first round,
227 the call is ignored and remaining examples are returned in the second round.
228 
229 If one wants all the instances in the second round after calling `reset`, make sure
230 to set `round_batch` to False.
231 
232 If ``data_csv = 'data/'`` is set, then all the files in this directory will be read.
233 
234 ``reset()`` is expected to be called only after a complete pass of data.
235 
236 By default, the CSVIter parses all entries in the data file as float32 data type,
237 if `dtype` argument is set to be 'int32' or 'int64' then CSVIter will parse all entries in the file
238 as int32 or int64 data type accordingly.
239 
240 Examples::
241 
242   // Contents of CSV file ``data/data.csv``.
243   1,2,3
244   2,3,4
245   3,4,5
246   4,5,6
247 
248   // Creates a `CSVIter` with `batch_size`=2 and default `round_batch`=True.
249   CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
250   batch_size = 2)
251 
252   // Two batches read from the above iterator are as follows:
253   [[ 1.  2.  3.]
254   [ 2.  3.  4.]]
255   [[ 3.  4.  5.]
256   [ 4.  5.  6.]]
257 
258   // Creates a `CSVIter` with default `round_batch` set to True.
259   CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
260   batch_size = 3)
261 
262   // Two batches read from the above iterator in the first pass are as follows:
263   [[1.  2.  3.]
264   [2.  3.  4.]
265   [3.  4.  5.]]
266 
267   [[4.  5.  6.]
268   [1.  2.  3.]
269   [2.  3.  4.]]
270 
271   // Now, `reset` method is called.
272   CSVIter.reset()
273 
274   // Batch read from the above iterator in the second pass is as follows:
275   [[ 3.  4.  5.]
276   [ 4.  5.  6.]
277   [ 1.  2.  3.]]
278 
279   // Creates a `CSVIter` with `round_batch`=False.
280   CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
281   batch_size = 3, round_batch=False)
282 
283   // Contents of two batches read from the above iterator in both passes, after calling
284   // `reset` method before second pass, is as follows:
285   [[1.  2.  3.]
286   [2.  3.  4.]
287   [3.  4.  5.]]
288 
289   [[4.  5.  6.]
290   [2.  3.  4.]
291   [3.  4.  5.]]
292 
293   // Creates a 'CSVIter' with `dtype`='int32'
294   CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
295   batch_size = 3, round_batch=False, dtype='int32')
296 
297   // Contents of two batches read from the above iterator in both passes, after calling
298   // `reset` method before second pass, is as follows:
299   [[1  2  3]
300   [2  3  4]
301   [3  4  5]]
302 
303   [[4  5  6]
304   [2  3  4]
305   [3  4  5]]
306 
307 )code" ADD_FILELINE)
308 .add_arguments(CSVIterParam::__FIELDS__())
309 .add_arguments(BatchParam::__FIELDS__())
310 .add_arguments(PrefetcherParam::__FIELDS__())
__anon1b42fbd90102() 311 .set_body([]() {
312     return new PrefetcherIter(
313         new BatchLoader(
314             new CSVIter()));
315   });
316 
317 }  // namespace io
318 }  // namespace mxnet
319