1 #include <string>
2 #include <vector>
3 #include <cstdlib>
4 #include <cstring>
5 #include <dmlc/io.h>
6 #include <dmlc/recordio.h>
7
main(int argc,char * argv[])8 int main(int argc, char *argv[]) {
9 if (argc < 5) {
10 printf("Usage: <filename> partid npart nmax\n");
11 return 0;
12 }
13 using namespace dmlc;
14 dmlc::InputSplit *in = dmlc::InputSplit::
15 Create(argv[1],
16 atoi(argv[2]),
17 atoi(argv[3]),
18 "text");
19 size_t nmax = static_cast<size_t>(atol(argv[4]));
20 size_t lcnt = 0;
21 InputSplit::Blob rec;
22 std::vector<std::string> data;
23 while (in->NextRecord(&rec)) {
24 data.push_back(std::string((char*)rec.dptr, rec.size));
25 ++lcnt;
26 if (lcnt == nmax) {
27 LOG(INFO) << "finish loading " << lcnt << " lines";
28 break;
29 }
30 }
31 LOG(INFO) << "Call BeforeFirst when lcnt="
32 << lcnt << " nmax=" << nmax;
33 in->BeforeFirst();
34 lcnt = 0;
35 while (in->NextRecord(&rec)) {
36 std::string dat = std::string((char*)rec.dptr, rec.size);
37 if (lcnt < nmax) {
38 CHECK(rec.size == data[lcnt].length());
39 CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
40 } else {
41 data.push_back(dat);
42 }
43 ++lcnt;
44 }
45 LOG(INFO) << "Call BeforeFirst again";
46 in->BeforeFirst();
47 lcnt = 0;
48 while (in->NextRecord(&rec)) {
49 std::string dat = std::string((char*)rec.dptr, rec.size);
50 CHECK(lcnt < data.size());
51 CHECK(rec.size == data[lcnt].length());
52 CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
53 ++lcnt;
54 }
55 delete in;
56 LOG(INFO) << "All tests passed";
57 return 0;
58 }
59