1 // Copyright by Contributors
2
3 #include <dmlc/base.h>
4 #include <dmlc/recordio.h>
5 #include <dmlc/logging.h>
6 #include <algorithm>
7
8
9 namespace dmlc {
10 // implementation
WriteRecord(const void * buf,size_t size)11 void RecordIOWriter::WriteRecord(const void *buf, size_t size) {
12 CHECK(size < (1 << 29U))
13 << "RecordIO only accept record less than 2^29 bytes";
14 const uint32_t umagic = kMagic;
15 // initialize the magic number, in stack
16 const char *magic = reinterpret_cast<const char*>(&umagic);
17 const char *bhead = reinterpret_cast<const char*>(buf);
18 uint32_t len = static_cast<uint32_t>(size);
19 uint32_t lower_align = (len >> 2U) << 2U;
20 uint32_t upper_align = ((len + 3U) >> 2U) << 2U;
21 uint32_t dptr = 0;
22 for (uint32_t i = 0; i < lower_align ; i += 4) {
23 // use char check for alignment safety reason
24 if (bhead[i] == magic[0] &&
25 bhead[i + 1] == magic[1] &&
26 bhead[i + 2] == magic[2] &&
27 bhead[i + 3] == magic[3]) {
28 uint32_t lrec = EncodeLRec(dptr == 0 ? 1U : 2U,
29 i - dptr);
30 stream_->Write(magic, 4);
31 stream_->Write(&lrec, sizeof(lrec));
32 if (i != dptr) {
33 stream_->Write(bhead + dptr, i - dptr);
34 }
35 dptr = i + 4;
36 except_counter_ += 1;
37 }
38 }
39 uint32_t lrec = EncodeLRec(dptr != 0 ? 3U : 0U,
40 len - dptr);
41 stream_->Write(magic, 4);
42 stream_->Write(&lrec, sizeof(lrec));
43 if (len != dptr) {
44 stream_->Write(bhead + dptr, len - dptr);
45 }
46 // write padded bytes
47 uint32_t zero = 0;
48 if (upper_align != len) {
49 stream_->Write(&zero, upper_align - len);
50 }
51 }
52
NextRecord(std::string * out_rec)53 bool RecordIOReader::NextRecord(std::string *out_rec) {
54 if (end_of_stream_) return false;
55 const uint32_t kMagic = RecordIOWriter::kMagic;
56 out_rec->clear();
57 size_t size = 0;
58 while (true) {
59 uint32_t header[2];
60 size_t nread = stream_->Read(header, sizeof(header));
61 if (nread == 0) {
62 end_of_stream_ = true; return false;
63 }
64 CHECK(nread == sizeof(header)) << "Inavlid RecordIO File";
65 CHECK(header[0] == RecordIOWriter::kMagic) << "Invalid RecordIO File";
66 uint32_t cflag = RecordIOWriter::DecodeFlag(header[1]);
67 uint32_t len = RecordIOWriter::DecodeLength(header[1]);
68 uint32_t upper_align = ((len + 3U) >> 2U) << 2U;
69 out_rec->resize(size + upper_align);
70 if (upper_align != 0) {
71 CHECK(stream_->Read(BeginPtr(*out_rec) + size, upper_align) == upper_align)
72 << "Invalid RecordIO File upper_align=" << upper_align;
73 }
74 // squeeze back
75 size += len; out_rec->resize(size);
76 if (cflag == 0U || cflag == 3U) break;
77 out_rec->resize(size + sizeof(kMagic));
78 std::memcpy(BeginPtr(*out_rec) + size, &kMagic, sizeof(kMagic));
79 size += sizeof(kMagic);
80 }
81 return true;
82 }
83
84 // helper function to find next recordio head
FindNextRecordIOHead(char * begin,char * end)85 inline char *FindNextRecordIOHead(char *begin, char *end) {
86 CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U);
87 CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
88 uint32_t *p = reinterpret_cast<uint32_t *>(begin);
89 uint32_t *pend = reinterpret_cast<uint32_t *>(end);
90 for (; p + 1 < pend; ++p) {
91 if (p[0] == RecordIOWriter::kMagic) {
92 uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
93 if (cflag == 0 || cflag == 1) {
94 return reinterpret_cast<char*>(p);
95 }
96 }
97 }
98 return end;
99 }
100
RecordIOChunkReader(InputSplit::Blob chunk,unsigned part_index,unsigned num_parts)101 RecordIOChunkReader::RecordIOChunkReader(InputSplit::Blob chunk,
102 unsigned part_index,
103 unsigned num_parts) {
104 size_t nstep = (chunk.size + num_parts - 1) / num_parts;
105 // align
106 nstep = ((nstep + 3UL) >> 2UL) << 2UL;
107 size_t begin = std::min(chunk.size, nstep * part_index);
108 size_t end = std::min(chunk.size, nstep * (part_index + 1));
109 char *head = reinterpret_cast<char*>(chunk.dptr);
110 pbegin_ = FindNextRecordIOHead(head + begin, head + chunk.size);
111 pend_ = FindNextRecordIOHead(head + end, head + chunk.size);
112 }
113
NextRecord(InputSplit::Blob * out_rec)114 bool RecordIOChunkReader::NextRecord(InputSplit::Blob *out_rec) {
115 if (pbegin_ >= pend_) return false;
116 uint32_t *p = reinterpret_cast<uint32_t *>(pbegin_);
117 CHECK(p[0] == RecordIOWriter::kMagic);
118 uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
119 uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
120 if (cflag == 0) {
121 // skip header
122 out_rec->dptr = pbegin_ + 2 * sizeof(uint32_t);
123 // move pbegin
124 pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
125 CHECK(pbegin_ <= pend_) << "Invalid RecordIO Format";
126 out_rec->size = clen;
127 return true;
128 } else {
129 const uint32_t kMagic = RecordIOWriter::kMagic;
130 // abnormal path, read into string
131 CHECK(cflag == 1U) << "Invalid RecordIO Format";
132 temp_.resize(0);
133 while (true) {
134 CHECK(pbegin_ + 2 * sizeof(uint32_t) <= pend_);
135 p = reinterpret_cast<uint32_t *>(pbegin_);
136 CHECK(p[0] == RecordIOWriter::kMagic);
137 cflag = RecordIOWriter::DecodeFlag(p[1]);
138 clen = RecordIOWriter::DecodeLength(p[1]);
139 size_t tsize = temp_.length();
140 temp_.resize(tsize + clen);
141 if (clen != 0) {
142 std::memcpy(BeginPtr(temp_) + tsize,
143 pbegin_ + 2 * sizeof(uint32_t),
144 clen);
145 tsize += clen;
146 }
147 pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
148 if (cflag == 3U) break;
149 temp_.resize(tsize + sizeof(kMagic));
150 std::memcpy(BeginPtr(temp_) + tsize, &kMagic, sizeof(kMagic));
151 }
152 out_rec->dptr = BeginPtr(temp_);
153 out_rec->size = temp_.length();
154 return true;
155 }
156 }
157 } // namespace dmlc
158