1 // Copyright by Contributors
2 
3 #include <dmlc/base.h>
4 #include <dmlc/recordio.h>
5 #include <dmlc/logging.h>
6 #include <algorithm>
7 
8 
9 namespace dmlc {
10 // implementation
WriteRecord(const void * buf,size_t size)11 void RecordIOWriter::WriteRecord(const void *buf, size_t size) {
12   CHECK(size < (1 << 29U))
13       << "RecordIO only accept record less than 2^29 bytes";
14   const uint32_t umagic = kMagic;
15   // initialize the magic number, in stack
16   const char *magic = reinterpret_cast<const char*>(&umagic);
17   const char *bhead = reinterpret_cast<const char*>(buf);
18   uint32_t len = static_cast<uint32_t>(size);
19   uint32_t lower_align = (len >> 2U) << 2U;
20   uint32_t upper_align = ((len + 3U) >> 2U) << 2U;
21   uint32_t dptr = 0;
22   for (uint32_t i = 0; i < lower_align ; i += 4) {
23     // use char check for alignment safety reason
24     if (bhead[i] == magic[0] &&
25         bhead[i + 1] == magic[1] &&
26         bhead[i + 2] == magic[2] &&
27         bhead[i + 3] == magic[3]) {
28       uint32_t lrec = EncodeLRec(dptr == 0 ? 1U : 2U,
29                                  i - dptr);
30       stream_->Write(magic, 4);
31       stream_->Write(&lrec, sizeof(lrec));
32       if (i != dptr) {
33         stream_->Write(bhead + dptr, i - dptr);
34       }
35       dptr = i + 4;
36       except_counter_ += 1;
37     }
38   }
39   uint32_t lrec = EncodeLRec(dptr != 0 ? 3U : 0U,
40                              len - dptr);
41   stream_->Write(magic, 4);
42   stream_->Write(&lrec, sizeof(lrec));
43   if (len != dptr) {
44     stream_->Write(bhead + dptr, len - dptr);
45   }
46   // write padded bytes
47   uint32_t zero = 0;
48   if (upper_align != len) {
49     stream_->Write(&zero, upper_align - len);
50   }
51 }
52 
NextRecord(std::string * out_rec)53 bool RecordIOReader::NextRecord(std::string *out_rec) {
54   if (end_of_stream_) return false;
55   const uint32_t kMagic = RecordIOWriter::kMagic;
56   out_rec->clear();
57   size_t size = 0;
58   while (true) {
59     uint32_t header[2];
60     size_t nread = stream_->Read(header, sizeof(header));
61     if (nread == 0) {
62       end_of_stream_ = true; return false;
63     }
64     CHECK(nread == sizeof(header)) << "Inavlid RecordIO File";
65     CHECK(header[0] == RecordIOWriter::kMagic) << "Invalid RecordIO File";
66     uint32_t cflag = RecordIOWriter::DecodeFlag(header[1]);
67     uint32_t len = RecordIOWriter::DecodeLength(header[1]);
68     uint32_t upper_align = ((len + 3U) >> 2U) << 2U;
69     out_rec->resize(size + upper_align);
70     if (upper_align != 0) {
71       CHECK(stream_->Read(BeginPtr(*out_rec) + size, upper_align) == upper_align)
72           << "Invalid RecordIO File upper_align=" << upper_align;
73     }
74     // squeeze back
75     size += len; out_rec->resize(size);
76     if (cflag == 0U || cflag == 3U) break;
77     out_rec->resize(size + sizeof(kMagic));
78     std::memcpy(BeginPtr(*out_rec) + size, &kMagic, sizeof(kMagic));
79     size += sizeof(kMagic);
80   }
81   return true;
82 }
83 
84 // helper function to find next recordio head
FindNextRecordIOHead(char * begin,char * end)85 inline char *FindNextRecordIOHead(char *begin, char *end) {
86   CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL),  0U);
87   CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
88   uint32_t *p = reinterpret_cast<uint32_t *>(begin);
89   uint32_t *pend = reinterpret_cast<uint32_t *>(end);
90   for (; p + 1 < pend; ++p) {
91     if (p[0] == RecordIOWriter::kMagic) {
92       uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
93       if (cflag == 0 || cflag == 1) {
94         return reinterpret_cast<char*>(p);
95       }
96     }
97   }
98   return end;
99 }
100 
RecordIOChunkReader(InputSplit::Blob chunk,unsigned part_index,unsigned num_parts)101 RecordIOChunkReader::RecordIOChunkReader(InputSplit::Blob chunk,
102                                          unsigned part_index,
103                                          unsigned num_parts) {
104   size_t nstep = (chunk.size + num_parts - 1) / num_parts;
105   // align
106   nstep = ((nstep + 3UL) >> 2UL) << 2UL;
107   size_t begin = std::min(chunk.size, nstep * part_index);
108   size_t end = std::min(chunk.size, nstep * (part_index + 1));
109   char *head = reinterpret_cast<char*>(chunk.dptr);
110   pbegin_ = FindNextRecordIOHead(head + begin, head + chunk.size);
111   pend_ = FindNextRecordIOHead(head + end, head + chunk.size);
112 }
113 
NextRecord(InputSplit::Blob * out_rec)114 bool RecordIOChunkReader::NextRecord(InputSplit::Blob *out_rec) {
115   if (pbegin_ >= pend_) return false;
116   uint32_t *p = reinterpret_cast<uint32_t *>(pbegin_);
117   CHECK(p[0] == RecordIOWriter::kMagic);
118   uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
119   uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
120   if (cflag == 0) {
121     // skip header
122     out_rec->dptr = pbegin_ + 2 * sizeof(uint32_t);
123     // move pbegin
124     pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
125     CHECK(pbegin_ <= pend_) << "Invalid RecordIO Format";
126     out_rec->size = clen;
127     return true;
128   } else {
129     const uint32_t kMagic = RecordIOWriter::kMagic;
130     // abnormal path, read into string
131     CHECK(cflag == 1U) << "Invalid RecordIO Format";
132     temp_.resize(0);
133     while (true) {
134       CHECK(pbegin_ + 2 * sizeof(uint32_t) <= pend_);
135       p = reinterpret_cast<uint32_t *>(pbegin_);
136       CHECK(p[0] == RecordIOWriter::kMagic);
137       cflag = RecordIOWriter::DecodeFlag(p[1]);
138       clen = RecordIOWriter::DecodeLength(p[1]);
139       size_t tsize = temp_.length();
140       temp_.resize(tsize + clen);
141       if (clen != 0) {
142         std::memcpy(BeginPtr(temp_) + tsize,
143                     pbegin_ + 2 * sizeof(uint32_t),
144                     clen);
145         tsize += clen;
146       }
147       pbegin_ += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
148       if (cflag == 3U) break;
149       temp_.resize(tsize + sizeof(kMagic));
150       std::memcpy(BeginPtr(temp_) + tsize, &kMagic, sizeof(kMagic));
151     }
152     out_rec->dptr = BeginPtr(temp_);
153     out_rec->size = temp_.length();
154     return true;
155   }
156 }
157 }  // namespace dmlc
158