1 /*!
2  *  Copyright (c) 2017 by Contributors
3  * \file libfm_parser.h
4  * \brief iterator parser to parse libfm format
5  * \author formath
6  */
7 #ifndef DMLC_DATA_LIBFM_PARSER_H_
8 #define DMLC_DATA_LIBFM_PARSER_H_
9 
10 #include <dmlc/data.h>
11 #include <dmlc/strtonum.h>
12 #include <dmlc/parameter.h>
13 #include <map>
14 #include <string>
15 #include <limits>
16 #include <algorithm>
17 #include <cstring>
18 #include "./row_block.h"
19 #include "./text_parser.h"
20 
21 namespace dmlc {
22 namespace data {
23 
24 struct LibFMParserParam : public Parameter<LibFMParserParam> {
25   std::string format;
26   int indexing_mode;
27   // declare parameters
DMLC_DECLARE_PARAMETERLibFMParserParam28   DMLC_DECLARE_PARAMETER(LibFMParserParam) {
29     DMLC_DECLARE_FIELD(format).set_default("libfm")
30         .describe("File format");
31     DMLC_DECLARE_FIELD(indexing_mode).set_default(0)
32         .describe(
33           "If >0, treat all field and feature indices as 1-based. "
34           "If =0, treat all field and feature indices as 0-based. "
35           "If <0, use heuristic to automatically detect mode of indexing. "
36           "See https://en.wikipedia.org/wiki/Array_data_type#Index_origin "
37           "for more details on indexing modes.");
38   }
39 };
40 
41 /*!
42  * \brief Text parser that parses the input lines
43  * and returns rows in input data
44  */
45 template <typename IndexType, typename DType = real_t>
46 class LibFMParser : public TextParserBase<IndexType, DType> {
47  public:
LibFMParser(InputSplit * source,int nthread)48   explicit LibFMParser(InputSplit *source, int nthread)
49       : LibFMParser(source, std::map<std::string, std::string>(), nthread) {}
LibFMParser(InputSplit * source,const std::map<std::string,std::string> & args,int nthread)50   explicit LibFMParser(InputSplit *source,
51                        const std::map<std::string, std::string>& args,
52                        int nthread)
53       : TextParserBase<IndexType>(source, nthread) {
54     param_.Init(args);
55     CHECK_EQ(param_.format, "libfm");
56   }
57 
58  protected:
59   virtual void ParseBlock(const char *begin,
60                           const char *end,
61                           RowBlockContainer<IndexType, DType> *out);
62 
63  private:
64   LibFMParserParam param_;
65 };
66 
67 template <typename IndexType, typename DType>
68 void LibFMParser<IndexType, DType>::
ParseBlock(const char * begin,const char * end,RowBlockContainer<IndexType,DType> * out)69 ParseBlock(const char *begin,
70            const char *end,
71            RowBlockContainer<IndexType, DType> *out) {
72   out->Clear();
73   const char * lbegin = begin;
74   const char * lend = lbegin;
75   IndexType min_field_id = std::numeric_limits<IndexType>::max();
76   IndexType min_feat_id = std::numeric_limits<IndexType>::max();
77   while (lbegin != end) {
78     // get line end
79     lend = lbegin + 1;
80     while (lend != end && *lend != '\n' && *lend != '\r') ++lend;
81     // parse label[:weight]
82     const char * p = lbegin;
83     const char * q = NULL;
84     real_t label;
85     real_t weight;
86     int r = ParsePair<real_t, real_t>(p, lend, &q, label, weight);
87     if (r < 1) {
88       // empty line
89       lbegin = lend;
90       continue;
91     }
92     if (r == 2) {
93       // has weight
94       out->weight.push_back(weight);
95     }
96     if (out->label.size() != 0) {
97       out->offset.push_back(out->index.size());
98     }
99     out->label.push_back(label);
100     // parse fieldid:feature:value
101     p = q;
102     while (p != lend) {
103       IndexType fieldId;
104       IndexType featureId;
105       real_t value;
106       int r = ParseTriple<IndexType, IndexType, real_t>(p, lend, &q, fieldId, featureId, value);
107       if (r <= 1) {
108         p = q;
109         continue;
110       }
111       out->field.push_back(fieldId);
112       out->index.push_back(featureId);
113       min_field_id = std::min(fieldId, min_field_id);
114       min_feat_id = std::min(featureId, min_feat_id);
115       if (r == 3) {
116         // has value
117         out->value.push_back(value);
118       }
119       p = q;
120     }
121     // next line
122     lbegin = lend;
123   }
124   if (out->label.size() != 0) {
125     out->offset.push_back(out->index.size());
126   }
127   CHECK(out->field.size() == out->index.size());
128   CHECK(out->label.size() + 1 == out->offset.size());
129 
130   // detect indexing mode
131   // heuristic adopted from sklearn.datasets.load_svmlight_file
132   // If all feature and field id's exceed 0, then detect 1-based indexing
133   if (param_.indexing_mode > 0
134       || (param_.indexing_mode < 0 && !out->index.empty() && min_feat_id > 0
135           && !out->field.empty() && min_field_id > 0) ) {
136     // convert from 1-based to 0-based indexing
137     for (IndexType& e : out->index) {
138       --e;
139     }
140     for (IndexType& e : out->field) {
141       --e;
142     }
143   }
144 }
145 
146 }  // namespace data
147 }  // namespace dmlc
148 #endif  // DMLC_DATA_LIBFM_PARSER_H_
149