#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include <dmlc/io.h>
#include <dmlc/recordio.h>
7 
main(int argc,char * argv[])8 int main(int argc, char *argv[]) {
9   if (argc < 5) {
10     printf("Usage: <filename> partid npart nmax\n");
11     return 0;
12   }
13   using namespace dmlc;
14   dmlc::InputSplit *in = dmlc::InputSplit::
15       Create(argv[1],
16              atoi(argv[2]),
17              atoi(argv[3]),
18              "text");
19   size_t nmax = static_cast<size_t>(atol(argv[4]));
20   size_t lcnt = 0;
21   InputSplit::Blob rec;
22   std::vector<std::string> data;
23   while (in->NextRecord(&rec)) {
24     data.push_back(std::string((char*)rec.dptr, rec.size));
25     ++lcnt;
26     if (lcnt == nmax) {
27       LOG(INFO) << "finish loading " << lcnt << " lines";
28       break;
29     }
30   }
31   LOG(INFO) << "Call BeforeFirst when lcnt="
32             << lcnt << " nmax=" << nmax;
33   in->BeforeFirst();
34   lcnt = 0;
35   while (in->NextRecord(&rec)) {
36     std::string dat = std::string((char*)rec.dptr, rec.size);
37     if (lcnt < nmax) {
38       CHECK(rec.size == data[lcnt].length());
39       CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
40     } else {
41       data.push_back(dat);
42     }
43     ++lcnt;
44   }
45   LOG(INFO) << "Call BeforeFirst again";
46   in->BeforeFirst();
47   lcnt = 0;
48   while (in->NextRecord(&rec)) {
49     std::string dat = std::string((char*)rec.dptr, rec.size);
50     CHECK(lcnt < data.size());
51     CHECK(rec.size == data[lcnt].length());
52     CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
53     ++lcnt;
54   }
55   delete in;
56   LOG(INFO) << "All tests passed";
57   return 0;
58 }
59