#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include <dmlc/io.h>
#include <dmlc/recordio.h>
5 
main(int argc,char * argv[])6 int main(int argc, char *argv[]) {
7   if (argc < 4) {
8     printf("Usage: <filename> ndata dlen [nsplit]\n");
9     return 0;
10   }
11   using namespace dmlc;
12   int nsplit = 4;
13   if (argc > 4) nsplit = atoi(argv[4]);
14   LOG(INFO) << "generate the test-cases into";
15   int ndata = atoi(argv[2]);
16   int dlen = atoi(argv[3]);
17   std::vector<std::string> data;
18   const unsigned kMagic = dmlc::RecordIOWriter::kMagic;
19   for (int i = 0; i < ndata; ++i) {
20     std::string s;
21     s.resize(rand() % dlen);
22     // generate random string
23     for (size_t j = 0; j < s.length(); ++j) {
24       s[j] = static_cast<char>(rand() & 255);
25     }
26     int rnd = rand() % 4;
27     if (rnd == 4) {
28       size_t n = s.length();
29       s.resize(s.length() + 4);
30       std::memcpy(BeginPtr(s) + n, &kMagic, sizeof(kMagic));
31     } else if (rnd == 3) {
32       s.resize(std::max(s.length(), 4UL));
33       std::memcpy(BeginPtr(s), &kMagic, sizeof(kMagic));
34     } else if (rnd == 2) {
35       for (size_t k = 0; k + 4 <= s.length(); k += 4) {
36         if (rand() % 2) {
37           std::memcpy(BeginPtr(s) + 4, &kMagic, sizeof(kMagic));
38         }
39       }
40     } else if (rnd == 1) {
41       for (size_t k = 0; k + 4 <= s.length(); k += 4) {
42         if (rand() % 10) {
43           std::memcpy(BeginPtr(s) + 4, &kMagic, sizeof(kMagic));
44         }
45       }
46     }
47     data.push_back(s);
48   }
49   LOG(INFO) << "generate the test-cases into" << argv[1];
50   {// output
51     dmlc::Stream *fs = dmlc::Stream::Create(argv[1], "wb");
52     dmlc::RecordIOWriter writer(fs);
53     for (size_t i = 0; i < data.size(); ++i) {
54       writer.WriteRecord(data[i]);
55     }
56     delete fs;
57     printf("finish writing with %lu exceptions\n", writer.except_counter());
58   }
59   {// input
60     LOG(INFO) << "Test RecordIOReader..";
61     dmlc::Stream *fi = dmlc::Stream::Create(argv[1], "r");
62     dmlc::RecordIOReader reader(fi);
63     std::string temp;
64     size_t lcnt = 0;
65     while (reader.NextRecord(&temp)) {
66       CHECK(lcnt < data.size());
67       CHECK(temp.length() == data[lcnt].length());
68       if (temp.length() != 0) {
69         CHECK(!memcmp(BeginPtr(temp), BeginPtr(data[lcnt]), temp.length()));
70       }
71       ++lcnt;
72     }
73     delete fi;
74     LOG(INFO) << "Test RecordIOReader.. Pass";
75   }
76   {// InputSplit::RecordiO
77     LOG(INFO) << "Test InputSplit for RecordIO..";
78     size_t lcnt = 0;
79     for (int i = 0; i < nsplit; ++i) {
80       InputSplit::Blob rec;
81       dmlc::InputSplit *split = InputSplit::Create(argv[1], i, nsplit, "recordio");
82       while (split->NextRecord(&rec)) {
83         CHECK(lcnt < data.size());
84         CHECK(rec.size == data[lcnt].length());
85         if (rec.size != 0) {
86           CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
87         }
88         ++lcnt;
89       }
90       delete split;
91     }
92     LOG(INFO) << "Test InputSplit for RecordIO.. Pass";
93   }
94   {// InputSplit::RecordIO Chunk Read
95     LOG(INFO) << "Test InputSplit for RecordIO.. ChunkReader";
96     size_t lcnt = 0;
97     InputSplit::Blob chunk;
98     dmlc::InputSplit *split = InputSplit::Create(argv[1], 0, 1, "recordio");
99     while (split->NextChunk(&chunk)) {
100       for (int i = 0; i < nsplit; ++i) {
101         InputSplit::Blob rec;
102         dmlc::RecordIOChunkReader reader(chunk, i, nsplit);
103         while (reader.NextRecord(&rec)) {
104           CHECK(lcnt < data.size());
105           CHECK(rec.size == data[lcnt].length());
106           if (rec.size != 0) {
107             CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
108           }
109           ++lcnt;
110         }
111       }
112     }
113     delete split;
114     LOG(INFO) << "Test InputSplit for RecordIO.. ChunkReader Pass";
115   }
116   return 0;
117 }
118