1 // Copyright by Contributors
2 #include <dmlc/recordio.h>
3 #include <dmlc/logging.h>
4 #include <algorithm>
5 #include "./recordio_split.h"
6 
7 namespace dmlc {
8 namespace io {
SeekRecordBegin(Stream * fi)9 size_t RecordIOSplitter::SeekRecordBegin(Stream *fi) {
10   size_t nstep = 0;
11   uint32_t v, lrec;
12   while (true) {
13     if (fi->Read(&v, sizeof(v)) == 0) return nstep;
14     nstep += sizeof(v);
15     if (v == RecordIOWriter::kMagic) {
16       CHECK(fi->Read(&lrec, sizeof(lrec)) != 0)
17             << "invalid record io format";
18       nstep += sizeof(lrec);
19       uint32_t cflag = RecordIOWriter::DecodeFlag(lrec);
20       if (cflag == 0 || cflag == 1) break;
21     }
22   }
23   // should point at head of record
24   return nstep - 2 * sizeof(uint32_t);
25 }
FindLastRecordBegin(const char * begin,const char * end)26 const char* RecordIOSplitter::FindLastRecordBegin(const char *begin,
27                                                   const char *end) {
28   CHECK_EQ((reinterpret_cast<size_t>(begin) & 3UL), 0U);
29   CHECK_EQ((reinterpret_cast<size_t>(end) & 3UL), 0U);
30   const uint32_t *pbegin = reinterpret_cast<const uint32_t *>(begin);
31   const uint32_t *p = reinterpret_cast<const uint32_t *>(end);
32   CHECK(p >= pbegin + 2);
33   for (p = p - 2; p != pbegin; --p) {
34     if (p[0] == RecordIOWriter::kMagic) {
35       uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
36       if (cflag == 0 || cflag == 1) {
37         return reinterpret_cast<const char*>(p);
38       }
39     }
40   }
41   return begin;
42 }
43 
ExtractNextRecord(Blob * out_rec,Chunk * chunk)44 bool RecordIOSplitter::ExtractNextRecord(Blob *out_rec, Chunk *chunk) {
45   if (chunk->begin == chunk->end) return false;
46   CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end)
47       << "Invalid RecordIO Format";
48   CHECK_EQ((reinterpret_cast<size_t>(chunk->begin) & 3UL), 0U);
49   CHECK_EQ((reinterpret_cast<size_t>(chunk->end) & 3UL), 0U);
50   uint32_t *p = reinterpret_cast<uint32_t *>(chunk->begin);
51   uint32_t cflag = RecordIOWriter::DecodeFlag(p[1]);
52   uint32_t clen = RecordIOWriter::DecodeLength(p[1]);
53   // skip header
54   out_rec->dptr = chunk->begin + 2 * sizeof(uint32_t);
55   // move pbegin
56   chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
57   CHECK(chunk->begin <= chunk->end) << "Invalid RecordIO Format";
58   out_rec->size = clen;
59   if (cflag == 0) return true;
60   const uint32_t kMagic = RecordIOWriter::kMagic;
61   // abnormal path, move data around to make a full part
62   CHECK(cflag == 1U) << "Invalid RecordIO Format";
63   while (cflag != 3U) {
64     CHECK(chunk->begin + 2 * sizeof(uint32_t) <= chunk->end);
65     p = reinterpret_cast<uint32_t *>(chunk->begin);
66     CHECK(p[0] == RecordIOWriter::kMagic);
67     cflag = RecordIOWriter::DecodeFlag(p[1]);
68     clen = RecordIOWriter::DecodeLength(p[1]);
69     // pad kmagic in between
70     std::memcpy(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
71                 &kMagic, sizeof(kMagic));
72     out_rec->size += sizeof(kMagic);
73     // move the rest of the blobs
74     if (clen != 0) {
75       std::memmove(reinterpret_cast<char*>(out_rec->dptr) + out_rec->size,
76                    chunk->begin + 2 * sizeof(uint32_t), clen);
77       out_rec->size += clen;
78     }
79     chunk->begin += 2 * sizeof(uint32_t) + (((clen + 3U) >> 2U) << 2U);
80   }
81   return true;
82 }
83 }  // namespace io
84 }  // namespace dmlc
85