#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include <dmlc/io.h>
#include <dmlc/recordio.h>
5
main(int argc,char * argv[])6 int main(int argc, char *argv[]) {
7 if (argc < 4) {
8 printf("Usage: <filename> ndata dlen [nsplit]\n");
9 return 0;
10 }
11 using namespace dmlc;
12 int nsplit = 4;
13 if (argc > 4) nsplit = atoi(argv[4]);
14 LOG(INFO) << "generate the test-cases into";
15 int ndata = atoi(argv[2]);
16 int dlen = atoi(argv[3]);
17 std::vector<std::string> data;
18 const unsigned kMagic = dmlc::RecordIOWriter::kMagic;
19 for (int i = 0; i < ndata; ++i) {
20 std::string s;
21 s.resize(rand() % dlen);
22 // generate random string
23 for (size_t j = 0; j < s.length(); ++j) {
24 s[j] = static_cast<char>(rand() & 255);
25 }
26 int rnd = rand() % 4;
27 if (rnd == 4) {
28 size_t n = s.length();
29 s.resize(s.length() + 4);
30 std::memcpy(BeginPtr(s) + n, &kMagic, sizeof(kMagic));
31 } else if (rnd == 3) {
32 s.resize(std::max(s.length(), 4UL));
33 std::memcpy(BeginPtr(s), &kMagic, sizeof(kMagic));
34 } else if (rnd == 2) {
35 for (size_t k = 0; k + 4 <= s.length(); k += 4) {
36 if (rand() % 2) {
37 std::memcpy(BeginPtr(s) + 4, &kMagic, sizeof(kMagic));
38 }
39 }
40 } else if (rnd == 1) {
41 for (size_t k = 0; k + 4 <= s.length(); k += 4) {
42 if (rand() % 10) {
43 std::memcpy(BeginPtr(s) + 4, &kMagic, sizeof(kMagic));
44 }
45 }
46 }
47 data.push_back(s);
48 }
49 LOG(INFO) << "generate the test-cases into" << argv[1];
50 {// output
51 dmlc::Stream *fs = dmlc::Stream::Create(argv[1], "wb");
52 dmlc::RecordIOWriter writer(fs);
53 for (size_t i = 0; i < data.size(); ++i) {
54 writer.WriteRecord(data[i]);
55 }
56 delete fs;
57 printf("finish writing with %lu exceptions\n", writer.except_counter());
58 }
59 {// input
60 LOG(INFO) << "Test RecordIOReader..";
61 dmlc::Stream *fi = dmlc::Stream::Create(argv[1], "r");
62 dmlc::RecordIOReader reader(fi);
63 std::string temp;
64 size_t lcnt = 0;
65 while (reader.NextRecord(&temp)) {
66 CHECK(lcnt < data.size());
67 CHECK(temp.length() == data[lcnt].length());
68 if (temp.length() != 0) {
69 CHECK(!memcmp(BeginPtr(temp), BeginPtr(data[lcnt]), temp.length()));
70 }
71 ++lcnt;
72 }
73 delete fi;
74 LOG(INFO) << "Test RecordIOReader.. Pass";
75 }
76 {// InputSplit::RecordiO
77 LOG(INFO) << "Test InputSplit for RecordIO..";
78 size_t lcnt = 0;
79 for (int i = 0; i < nsplit; ++i) {
80 InputSplit::Blob rec;
81 dmlc::InputSplit *split = InputSplit::Create(argv[1], i, nsplit, "recordio");
82 while (split->NextRecord(&rec)) {
83 CHECK(lcnt < data.size());
84 CHECK(rec.size == data[lcnt].length());
85 if (rec.size != 0) {
86 CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
87 }
88 ++lcnt;
89 }
90 delete split;
91 }
92 LOG(INFO) << "Test InputSplit for RecordIO.. Pass";
93 }
94 {// InputSplit::RecordIO Chunk Read
95 LOG(INFO) << "Test InputSplit for RecordIO.. ChunkReader";
96 size_t lcnt = 0;
97 InputSplit::Blob chunk;
98 dmlc::InputSplit *split = InputSplit::Create(argv[1], 0, 1, "recordio");
99 while (split->NextChunk(&chunk)) {
100 for (int i = 0; i < nsplit; ++i) {
101 InputSplit::Blob rec;
102 dmlc::RecordIOChunkReader reader(chunk, i, nsplit);
103 while (reader.NextRecord(&rec)) {
104 CHECK(lcnt < data.size());
105 CHECK(rec.size == data[lcnt].length());
106 if (rec.size != 0) {
107 CHECK(!memcmp(rec.dptr, BeginPtr(data[lcnt]), rec.size));
108 }
109 ++lcnt;
110 }
111 }
112 }
113 delete split;
114 LOG(INFO) << "Test InputSplit for RecordIO.. ChunkReader Pass";
115 }
116 return 0;
117 }
118