1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
/*!
 * \file iter_csv.cc
 * \brief Defines CSVIter, an iterator that reads rows of CSV files into arrays.
 */
24 #include <mxnet/io.h>
25 #include <dmlc/base.h>
26 #include <dmlc/logging.h>
27 #include <dmlc/parameter.h>
28 #include <dmlc/data.h>
29 #include "./iter_prefetcher.h"
30 #include "./iter_batchloader.h"
31
32 namespace mxnet {
33 namespace io {
34 // CSV parameters
35 struct CSVIterParam : public dmlc::Parameter<CSVIterParam> {
36 /*! \brief path to data csv file */
37 std::string data_csv;
38 /*! \brief data shape */
39 mxnet::TShape data_shape;
40 /*! \brief path to label csv file */
41 std::string label_csv;
42 /*! \brief label shape */
43 mxnet::TShape label_shape;
44 // declare parameters
DMLC_DECLARE_PARAMETERmxnet::io::CSVIterParam45 DMLC_DECLARE_PARAMETER(CSVIterParam) {
46 DMLC_DECLARE_FIELD(data_csv)
47 .describe("The input CSV file or a directory path.");
48 DMLC_DECLARE_FIELD(data_shape)
49 .describe("The shape of one example.");
50 DMLC_DECLARE_FIELD(label_csv).set_default("NULL")
51 .describe("The input CSV file or a directory path. "
52 "If NULL, all labels will be returned as 0.");
53 index_t shape1[] = {1};
54 DMLC_DECLARE_FIELD(label_shape).set_default(mxnet::TShape(shape1, shape1 + 1))
55 .describe("The shape of one label.");
56 }
57 };
58
59 class CSVIterBase: public IIterator<DataInst> {
60 public:
CSVIterBase()61 CSVIterBase() {
62 out_.data.resize(2);
63 }
~CSVIterBase()64 virtual ~CSVIterBase() {}
65
66 // initialize iterator loads data in
67 virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) = 0;
68 /*! \brief reset the iterator */
69 virtual void BeforeFirst(void) = 0;
70 /*! \brief move to next item */
71 virtual bool Next(void) = 0;
72 /*! \brief get current data */
Value(void) const73 virtual const DataInst &Value(void) const {
74 return out_;
75 }
76
77 protected:
78 CSVIterParam param_;
79
80 DataInst out_;
81
82 // internal instance counter
83 unsigned inst_counter_{0};
84 // at end
85 bool end_{false};
86
87 // label parser
88 size_t label_ptr_{0}, label_size_{0};
89 size_t data_ptr_{0}, data_size_{0};
90 };
91
92 template <typename DType>
93 class CSVIterTyped: public CSVIterBase {
94 public:
~CSVIterTyped()95 virtual ~CSVIterTyped() {}
96 // intialize iterator loads data in
Init(const std::vector<std::pair<std::string,std::string>> & kwargs)97 virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
98 param_.InitAllowUnknown(kwargs);
99 data_parser_.reset(dmlc::Parser<uint32_t, DType>::Create(param_.data_csv.c_str(), 0, 1, "csv"));
100 if (param_.label_csv != "NULL") {
101 label_parser_.reset(
102 dmlc::Parser<uint32_t, DType>::Create(param_.label_csv.c_str(), 0, 1, "csv"));
103 } else {
104 dummy_label.set_pad(false);
105 dummy_label.Resize(mshadow::Shape1(1));
106 dummy_label = 0;
107 }
108 }
109
BeforeFirst()110 virtual void BeforeFirst() {
111 data_parser_->BeforeFirst();
112 if (label_parser_.get() != nullptr) {
113 label_parser_->BeforeFirst();
114 }
115 data_ptr_ = label_ptr_ = 0;
116 data_size_ = label_size_ = 0;
117 inst_counter_ = 0;
118 end_ = false;
119 }
120
Next()121 virtual bool Next() {
122 if (end_) return false;
123 while (data_ptr_ >= data_size_) {
124 if (!data_parser_->Next()) {
125 end_ = true; return false;
126 }
127 data_ptr_ = 0;
128 data_size_ = data_parser_->Value().size;
129 }
130 out_.index = inst_counter_++;
131 CHECK_LT(data_ptr_, data_size_);
132 out_.data[0] = AsTBlob(data_parser_->Value()[data_ptr_++], param_.data_shape);
133
134 if (label_parser_.get() != nullptr) {
135 while (label_ptr_ >= label_size_) {
136 CHECK(label_parser_->Next())
137 << "Data CSV's row is smaller than the number of rows in label_csv";
138 label_ptr_ = 0;
139 label_size_ = label_parser_->Value().size;
140 }
141 CHECK_LT(label_ptr_, label_size_);
142 out_.data[1] = AsTBlob(label_parser_->Value()[label_ptr_++], param_.label_shape);
143 } else {
144 out_.data[1] = dummy_label;
145 }
146 return true;
147 }
148
149 private:
AsTBlob(const dmlc::Row<uint32_t,DType> & row,const mxnet::TShape & shape)150 inline TBlob AsTBlob(const dmlc::Row<uint32_t, DType>& row, const mxnet::TShape& shape) {
151 CHECK_EQ(row.length, shape.Size())
152 << "The data size in CSV do not match size of shape: "
153 << "specified shape=" << shape << ", the csv row-length=" << row.length;
154 const DType* ptr = row.value;
155 return TBlob((DType*)ptr, shape, cpu::kDevMask, 0); // NOLINT(*)
156 }
157 // dummy label
158 mshadow::TensorContainer<cpu, 1, DType> dummy_label;
159 std::unique_ptr<dmlc::Parser<uint32_t, DType> > label_parser_;
160 std::unique_ptr<dmlc::Parser<uint32_t, DType> > data_parser_;
161 };
162
163 class CSVIter: public IIterator<DataInst> {
164 public:
CSVIter()165 CSVIter() {}
~CSVIter()166 virtual ~CSVIter() {}
167
168 // intialize iterator loads data in
Init(const std::vector<std::pair<std::string,std::string>> & kwargs)169 virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
170 param_.InitAllowUnknown(kwargs);
171 bool dtype_has_value = false;
172 int target_dtype = -1;
173 for (const auto& arg : kwargs) {
174 if (arg.first == "dtype") {
175 dtype_has_value = true;
176 if (arg.second == "int32") {
177 target_dtype = mshadow::kInt32;
178 } else if (arg.second == "int64") {
179 target_dtype = mshadow::kInt64;
180 } else if (arg.second == "float32") {
181 target_dtype = mshadow::kFloat32;
182 } else {
183 CHECK(false) << arg.second << " is not supported for CSVIter";
184 }
185 }
186 }
187 if (dtype_has_value && target_dtype == mshadow::kInt32) {
188 iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int32_t>()));
189 } else if (dtype_has_value && target_dtype == mshadow::kInt64) {
190 iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<int64_t>()));
191 } else if (!dtype_has_value || target_dtype == mshadow::kFloat32) {
192 iterator_.reset(reinterpret_cast<CSVIterBase*>(new CSVIterTyped<float>()));
193 }
194 iterator_->Init(kwargs);
195 }
196
BeforeFirst()197 virtual void BeforeFirst() {
198 iterator_->BeforeFirst();
199 }
200
Next()201 virtual bool Next() {
202 return iterator_->Next();
203 }
204
Value(void) const205 virtual const DataInst &Value(void) const {
206 return iterator_->Value();
207 }
208
209 private:
210 CSVIterParam param_;
211 std::unique_ptr<CSVIterBase> iterator_;
212 };
213
214
215 DMLC_REGISTER_PARAMETER(CSVIterParam);
216
217 MXNET_REGISTER_IO_ITER(CSVIter)
218 .describe(R"code(Returns the CSV file iterator.
219
220 In this function, the `data_shape` parameter is used to set the shape of each line of the input data.
221 If a row in an input file is `1,2,3,4,5,6`` and `data_shape` is (3,2), that row
222 will be reshaped, yielding the array [[1,2],[3,4],[5,6]] of shape (3,2).
223
224 By default, the `CSVIter` has `round_batch` parameter set to ``True``. So, if `batch_size`
225 is 3 and there are 4 total rows in CSV file, 2 more examples
226 are consumed at the first round. If `reset` function is called after first round,
227 the call is ignored and remaining examples are returned in the second round.
228
229 If one wants all the instances in the second round after calling `reset`, make sure
230 to set `round_batch` to False.
231
232 If ``data_csv = 'data/'`` is set, then all the files in this directory will be read.
233
234 ``reset()`` is expected to be called only after a complete pass of data.
235
236 By default, the CSVIter parses all entries in the data file as float32 data type,
237 if `dtype` argument is set to be 'int32' or 'int64' then CSVIter will parse all entries in the file
238 as int32 or int64 data type accordingly.
239
240 Examples::
241
242 // Contents of CSV file ``data/data.csv``.
243 1,2,3
244 2,3,4
245 3,4,5
246 4,5,6
247
248 // Creates a `CSVIter` with `batch_size`=2 and default `round_batch`=True.
249 CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
250 batch_size = 2)
251
252 // Two batches read from the above iterator are as follows:
253 [[ 1. 2. 3.]
254 [ 2. 3. 4.]]
255 [[ 3. 4. 5.]
256 [ 4. 5. 6.]]
257
258 // Creates a `CSVIter` with default `round_batch` set to True.
259 CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
260 batch_size = 3)
261
262 // Two batches read from the above iterator in the first pass are as follows:
263 [[1. 2. 3.]
264 [2. 3. 4.]
265 [3. 4. 5.]]
266
267 [[4. 5. 6.]
268 [1. 2. 3.]
269 [2. 3. 4.]]
270
271 // Now, `reset` method is called.
272 CSVIter.reset()
273
274 // Batch read from the above iterator in the second pass is as follows:
275 [[ 3. 4. 5.]
276 [ 4. 5. 6.]
277 [ 1. 2. 3.]]
278
279 // Creates a `CSVIter` with `round_batch`=False.
280 CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
281 batch_size = 3, round_batch=False)
282
283 // Contents of two batches read from the above iterator in both passes, after calling
284 // `reset` method before second pass, is as follows:
285 [[1. 2. 3.]
286 [2. 3. 4.]
287 [3. 4. 5.]]
288
289 [[4. 5. 6.]
290 [2. 3. 4.]
291 [3. 4. 5.]]
292
293 // Creates a 'CSVIter' with `dtype`='int32'
294 CSVIter = mx.io.CSVIter(data_csv = 'data/data.csv', data_shape = (3,),
295 batch_size = 3, round_batch=False, dtype='int32')
296
297 // Contents of two batches read from the above iterator in both passes, after calling
298 // `reset` method before second pass, is as follows:
299 [[1 2 3]
300 [2 3 4]
301 [3 4 5]]
302
303 [[4 5 6]
304 [2 3 4]
305 [3 4 5]]
306
307 )code" ADD_FILELINE)
308 .add_arguments(CSVIterParam::__FIELDS__())
309 .add_arguments(BatchParam::__FIELDS__())
310 .add_arguments(PrefetcherParam::__FIELDS__())
__anon1b42fbd90102() 311 .set_body([]() {
312 return new PrefetcherIter(
313 new BatchLoader(
314 new CSVIter()));
315 });
316
317 } // namespace io
318 } // namespace mxnet
319