1 // sparse-tuple-weight.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: krr@google.com (Kasturi Rangan Raghavan)
17 // Inspiration: allauzen@google.com (Cyril Allauzen)
18 // \file
// Sparse version of tuple-weight, based on tuple-weight.h.
//   Internally stores sparse (key, value) pairs in a linked list.
//   The default value is the assumed value of every unset key.
//   The first (key, value) pair is stored directly in a member variable
//   rather than in the list, which avoids a heap allocation in the common
//   case of a weight with a single pair.
// Use SparseTupleWeightIterator to iterate through the (key, value) pairs.
// Note: this does NOT iterate through the default value.
27 //
28 // Sparse tuple weight set operation definitions.
29 
30 #ifndef FST_LIB_SPARSE_TUPLE_WEIGHT_H__
31 #define FST_LIB_SPARSE_TUPLE_WEIGHT_H__
32 
33 #include<string>
34 #include<list>
35 #include<stack>
36 #include<unordered_map>
37 using std::unordered_map;
38 using std::unordered_multimap;
39 
40 #include <fst/weight.h>
41 
42 
43 DECLARE_string(fst_weight_parentheses);
44 DECLARE_string(fst_weight_separator);
45 
46 namespace fst {
47 
48 template <class W, class K> class SparseTupleWeight;
49 
50 template<class W, class K>
51 class SparseTupleWeightIterator;
52 
53 template <class W, class K>
54 istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w);
55 
56 // Arbitrary dimension tuple weight, stored as a sorted linked-list
57 // W is any weight class,
58 // K is the key value type. kNoKey(-1) is reserved for internal use
59 template <class W, class K = int>
60 class SparseTupleWeight {
61  public:
62   typedef pair<K, W> Pair;
63   typedef SparseTupleWeight<typename W::ReverseWeight, K> ReverseWeight;
64 
65   const static K kNoKey = -1;
SparseTupleWeight()66   SparseTupleWeight() {
67     Init();
68   }
69 
70   template <class Iterator>
SparseTupleWeight(Iterator begin,Iterator end)71   SparseTupleWeight(Iterator begin, Iterator end) {
72     Init();
73     // Assumes input iterator is sorted
74     for (Iterator it = begin; it != end; ++it)
75       Push(*it);
76   }
77 
78 
SparseTupleWeight(const K & key,const W & w)79   SparseTupleWeight(const K& key, const W &w) {
80     Init();
81     Push(key, w);
82   }
83 
SparseTupleWeight(const W & w)84   SparseTupleWeight(const W &w) {
85     Init(w);
86   }
87 
SparseTupleWeight(const SparseTupleWeight<W,K> & w)88   SparseTupleWeight(const SparseTupleWeight<W, K> &w) {
89     Init(w.DefaultValue());
90     SetDefaultValue(w.DefaultValue());
91     for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
92       Push(it.Value());
93     }
94   }
95 
Zero()96   static const SparseTupleWeight<W, K> &Zero() {
97     static SparseTupleWeight<W, K> zero;
98     return zero;
99   }
100 
One()101   static const SparseTupleWeight<W, K> &One() {
102     static SparseTupleWeight<W, K> one(W::One());
103     return one;
104   }
105 
NoWeight()106   static const SparseTupleWeight<W, K> &NoWeight() {
107     static SparseTupleWeight<W, K> no_weight(W::NoWeight());
108     return no_weight;
109   }
110 
Read(istream & strm)111   istream &Read(istream &strm) {
112     ReadType(strm, &default_);
113     ReadType(strm, &first_);
114     return ReadType(strm, &rest_);
115   }
116 
Write(ostream & strm)117   ostream &Write(ostream &strm) const {
118     WriteType(strm, default_);
119     WriteType(strm, first_);
120     return WriteType(strm, rest_);
121   }
122 
123   SparseTupleWeight<W, K> &operator=(const SparseTupleWeight<W, K> &w) {
124     if (this == &w) return *this; // check for w = w
125     Init(w.DefaultValue());
126     for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
127       Push(it.Value());
128     }
129     return *this;
130   }
131 
Member()132   bool Member() const {
133     if (!DefaultValue().Member()) return false;
134     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
135       if (!it.Value().second.Member()) return false;
136     }
137     return true;
138   }
139 
140   // Assumes H() function exists for the hash of the key value
Hash()141   size_t Hash() const {
142     uint64 h = 0;
143     std::hash<K> H;
144     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
145       h = 5 * h + H(it.Value().first);
146       h = 13 * h + it.Value().second.Hash();
147     }
148     return size_t(h);
149   }
150 
151   SparseTupleWeight<W, K> Quantize(float delta = kDelta) const {
152     SparseTupleWeight<W, K> w;
153     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
154       w.Push(it.Value().first, it.Value().second.Quantize(delta));
155     }
156     return w;
157   }
158 
Reverse()159   ReverseWeight Reverse() const {
160     SparseTupleWeight<W, K> w;
161     for (SparseTupleWeightIterator<W, K> it(*this); !it.Done(); it.Next()) {
162       w.Push(it.Value().first, it.Value().second.Reverse());
163     }
164     return w;
165   }
166 
167   // Common initializer among constructors.
Init()168   void Init() {
169     Init(W::Zero());
170   }
171 
Init(const W & default_value)172   void Init(const W& default_value) {
173     first_.first = kNoKey;
174     /* initialized to the reserved key value */
175     default_ = default_value;
176     rest_.clear();
177   }
178 
Size()179   size_t Size() const {
180     if (first_.first == kNoKey)
181       return 0;
182     else
183       return  rest_.size() + 1;
184   }
185 
186   inline void Push(const K &k, const W &w, bool default_value_check = true) {
187     Push(make_pair(k, w), default_value_check);
188   }
189 
190   inline void Push(const Pair &p, bool default_value_check = true) {
191     if (default_value_check && p.second == default_) return;
192     if (first_.first == kNoKey) {
193       first_ = p;
194     } else {
195       rest_.push_back(p);
196     }
197   }
198 
SetDefaultValue(const W & val)199   void SetDefaultValue(const W& val) { default_ = val; }
200 
DefaultValue()201   const W& DefaultValue() const { return default_; }
202 
203  protected:
204   static istream& ReadNoParen(
205     istream&, SparseTupleWeight<W, K>&, char separator);
206 
207   static istream& ReadWithParen(
208     istream&, SparseTupleWeight<W, K>&,
209     char separator, char open_paren, char close_paren);
210 
211  private:
212   // Assumed default value of uninitialized keys, by default W::Zero()
213   W default_;
214 
215   // Key values pairs are first stored in first_, then fill rest_
216   // this way we can avoid dynamic allocation in the common case
217   // where the weight is a single key,val pair.
218   Pair first_;
219   list<Pair> rest_;
220 
221   friend istream &operator>><W, K>(istream&, SparseTupleWeight<W, K>&);
222   friend class SparseTupleWeightIterator<W, K>;
223 };
224 
225 template<class W, class K>
226 class SparseTupleWeightIterator {
227  public:
228   typedef typename SparseTupleWeight<W, K>::Pair Pair;
229   typedef typename list<Pair>::const_iterator const_iterator;
230   typedef typename list<Pair>::iterator iterator;
231 
SparseTupleWeightIterator(const SparseTupleWeight<W,K> & w)232   explicit SparseTupleWeightIterator(const SparseTupleWeight<W, K>& w)
233     : first_(w.first_), rest_(w.rest_), init_(true),
234       iter_(rest_.begin()) {}
235 
Done()236   bool Done() const {
237     if (init_)
238       return first_.first == SparseTupleWeight<W, K>::kNoKey;
239     else
240       return iter_ == rest_.end();
241   }
242 
Value()243   const Pair& Value() const { return init_ ? first_ : *iter_; }
244 
Next()245   void Next() {
246     if (init_)
247       init_ = false;
248     else
249       ++iter_;
250   }
251 
Reset()252   void Reset() {
253     init_ = true;
254     iter_ = rest_.begin();
255   }
256 
257  private:
258   const Pair &first_;
259   const list<Pair> & rest_;
260   bool init_;  // in the initialized state?
261   typename list<Pair>::const_iterator iter_;
262 
263   DISALLOW_COPY_AND_ASSIGN(SparseTupleWeightIterator);
264 };
265 
266 template<class W, class K, class M>
SparseTupleWeightMap(SparseTupleWeight<W,K> * ret,const SparseTupleWeight<W,K> & w1,const SparseTupleWeight<W,K> & w2,const M & operator_mapper)267 inline void SparseTupleWeightMap(
268   SparseTupleWeight<W, K>* ret,
269   const SparseTupleWeight<W, K>& w1,
270   const SparseTupleWeight<W, K>& w2,
271   const M& operator_mapper) {
272   SparseTupleWeightIterator<W, K> w1_it(w1);
273   SparseTupleWeightIterator<W, K> w2_it(w2);
274   const W& v1_def = w1.DefaultValue();
275   const W& v2_def = w2.DefaultValue();
276   ret->SetDefaultValue(operator_mapper.Map(0, v1_def, v2_def));
277   while (!w1_it.Done() || !w2_it.Done()) {
278     const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
279     const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
280     const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
281     const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
282     if (k1 == k2) {
283       ret->Push(k1, operator_mapper.Map(k1, v1, v2));
284       if (!w1_it.Done()) w1_it.Next();
285       if (!w2_it.Done()) w2_it.Next();
286     } else if (k1 < k2) {
287       ret->Push(k1, operator_mapper.Map(k1, v1, v2_def));
288       w1_it.Next();
289     } else {
290       ret->Push(k2, operator_mapper.Map(k2, v1_def, v2));
291       w2_it.Next();
292     }
293   }
294 }
295 
296 template <class W, class K>
297 inline bool operator==(const SparseTupleWeight<W, K> &w1,
298                        const SparseTupleWeight<W, K> &w2) {
299   const W& v1_def = w1.DefaultValue();
300   const W& v2_def = w2.DefaultValue();
301   if (v1_def != v2_def) return false;
302 
303   SparseTupleWeightIterator<W, K> w1_it(w1);
304   SparseTupleWeightIterator<W, K> w2_it(w2);
305   while (!w1_it.Done() || !w2_it.Done()) {
306     const K& k1 = (w1_it.Done()) ? w2_it.Value().first : w1_it.Value().first;
307     const K& k2 = (w2_it.Done()) ? w1_it.Value().first : w2_it.Value().first;
308     const W& v1 = (w1_it.Done()) ? v1_def : w1_it.Value().second;
309     const W& v2 = (w2_it.Done()) ? v2_def : w2_it.Value().second;
310     if (k1 == k2) {
311       if (v1 != v2) return false;
312       if (!w1_it.Done()) w1_it.Next();
313       if (!w2_it.Done()) w2_it.Next();
314     } else if (k1 < k2) {
315       if (v1 != v2_def) return false;
316       w1_it.Next();
317     } else {
318       if (v1_def != v2) return false;
319       w2_it.Next();
320     }
321   }
322   return true;
323 }
324 
325 template <class W, class K>
326 inline bool operator!=(const SparseTupleWeight<W, K> &w1,
327                        const SparseTupleWeight<W, K> &w2) {
328   return !(w1 == w2);
329 }
330 
331 template <class W, class K>
332 inline ostream &operator<<(ostream &strm, const SparseTupleWeight<W, K> &w) {
333   if(FLAGS_fst_weight_separator.size() != 1) {
334     FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
335     strm.clear(std::ios::badbit);
336     return strm;
337   }
338   char separator = FLAGS_fst_weight_separator[0];
339   bool write_parens = false;
340   if (!FLAGS_fst_weight_parentheses.empty()) {
341     if (FLAGS_fst_weight_parentheses.size() != 2) {
342       FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
343       strm.clear(std::ios::badbit);
344       return strm;
345     }
346     write_parens = true;
347   }
348 
349   if (write_parens)
350     strm << FLAGS_fst_weight_parentheses[0];
351 
352   strm << w.DefaultValue();
353   strm << separator;
354 
355   size_t n = w.Size();
356   strm << n;
357   strm << separator;
358 
359   for (SparseTupleWeightIterator<W, K> it(w); !it.Done(); it.Next()) {
360       strm << it.Value().first;
361       strm << separator;
362       strm << it.Value().second;
363       strm << separator;
364   }
365 
366   if (write_parens)
367     strm << FLAGS_fst_weight_parentheses[1];
368 
369   return strm;
370 }
371 
372 template <class W, class K>
373 inline istream &operator>>(istream &strm, SparseTupleWeight<W, K> &w) {
374   if(FLAGS_fst_weight_separator.size() != 1) {
375     FSTERROR() << "FLAGS_fst_weight_separator.size() is not equal to 1";
376     strm.clear(std::ios::badbit);
377     return strm;
378   }
379   char separator = FLAGS_fst_weight_separator[0];
380 
381   if (!FLAGS_fst_weight_parentheses.empty()) {
382     if (FLAGS_fst_weight_parentheses.size() != 2) {
383       FSTERROR() << "FLAGS_fst_weight_parentheses.size() is not equal to 2";
384       strm.clear(std::ios::badbit);
385       return strm;
386     }
387     return SparseTupleWeight<W, K>::ReadWithParen(
388         strm, w, separator, FLAGS_fst_weight_parentheses[0],
389         FLAGS_fst_weight_parentheses[1]);
390   } else {
391     return SparseTupleWeight<W, K>::ReadNoParen(strm, w, separator);
392   }
393 }
394 
395 // Reads SparseTupleWeight when there are no parentheses around tuple terms
396 template <class W, class K>
ReadNoParen(istream & strm,SparseTupleWeight<W,K> & w,char separator)397 inline istream& SparseTupleWeight<W, K>::ReadNoParen(
398     istream &strm,
399     SparseTupleWeight<W, K> &w,
400     char separator) {
401   int c;
402   size_t n;
403 
404   do {
405     c = strm.get();
406   } while (isspace(c));
407 
408 
409   { // Read default weight
410     W default_value;
411     string s;
412     while (c != separator) {
413       if (c == EOF) {
414         strm.clear(std::ios::badbit);
415         return strm;
416       }
417       s += c;
418       c = strm.get();
419     }
420     istringstream sstrm(s);
421     sstrm >> default_value;
422     w.SetDefaultValue(default_value);
423   }
424 
425   c = strm.get();
426 
427   { // Read n
428     string s;
429     while (c != separator) {
430       if (c == EOF) {
431         strm.clear(std::ios::badbit);
432         return strm;
433       }
434       s += c;
435       c = strm.get();
436     }
437     istringstream sstrm(s);
438     sstrm >> n;
439   }
440 
441   // Read n elements
442   for (size_t i = 0; i < n; ++i) {
443     // discard separator
444     c = strm.get();
445     K p;
446     W r;
447 
448     { // read key
449       string s;
450       while (c != separator) {
451         if (c == EOF) {
452           strm.clear(std::ios::badbit);
453           return strm;
454         }
455         s += c;
456         c = strm.get();
457       }
458       istringstream sstrm(s);
459       sstrm >> p;
460     }
461 
462     c = strm.get();
463 
464     { // read weight
465       string s;
466       while (c != separator) {
467         if (c == EOF) {
468           strm.clear(std::ios::badbit);
469           return strm;
470         }
471         s += c;
472         c = strm.get();
473       }
474       istringstream sstrm(s);
475       sstrm >> r;
476     }
477 
478     w.Push(p, r);
479   }
480 
481   c = strm.get();
482   if (c != separator) {
483     strm.clear(std::ios::badbit);
484   }
485 
486   return strm;
487 }
488 
489 // Reads SparseTupleWeight when there are parentheses around tuple terms
490 template <class W, class K>
ReadWithParen(istream & strm,SparseTupleWeight<W,K> & w,char separator,char open_paren,char close_paren)491 inline istream& SparseTupleWeight<W, K>::ReadWithParen(
492     istream &strm,
493     SparseTupleWeight<W, K> &w,
494     char separator,
495     char open_paren,
496     char close_paren) {
497   int c;
498   size_t n;
499 
500   do {
501     c = strm.get();
502   } while (isspace(c));
503 
504   if (c != open_paren) {
505     FSTERROR() << "is fst_weight_parentheses flag set correcty? ";
506     strm.clear(std::ios::badbit);
507     return strm;
508   }
509 
510   c = strm.get();
511 
512   { // Read weight
513     W default_value;
514     stack<int> parens;
515     string s;
516     while (c != separator || !parens.empty()) {
517       if (c == EOF) {
518         strm.clear(std::ios::badbit);
519         return strm;
520       }
521       s += c;
522       // If parens encountered before separator, they must be matched
523       if (c == open_paren) {
524         parens.push(1);
525       } else if (c == close_paren) {
526         // Fail for mismatched parens
527         if (parens.empty()) {
528           strm.clear(std::ios::failbit);
529           return strm;
530         }
531         parens.pop();
532       }
533       c = strm.get();
534     }
535     istringstream sstrm(s);
536     sstrm >> default_value;
537     w.SetDefaultValue(default_value);
538   }
539 
540   c = strm.get();
541 
542   { // Read n
543     string s;
544     while (c != separator) {
545       if (c == EOF) {
546         strm.clear(std::ios::badbit);
547         return strm;
548       }
549       s += c;
550       c = strm.get();
551     }
552     istringstream sstrm(s);
553     sstrm >> n;
554   }
555 
556   // Read n elements
557   for (size_t i = 0; i < n; ++i) {
558     // discard separator
559     c = strm.get();
560     K p;
561     W r;
562 
563     { // Read key
564       stack<int> parens;
565       string s;
566       while (c != separator || !parens.empty()) {
567         if (c == EOF) {
568           strm.clear(std::ios::badbit);
569           return strm;
570         }
571         s += c;
572         // If parens encountered before separator, they must be matched
573         if (c == open_paren) {
574           parens.push(1);
575         } else if (c == close_paren) {
576           // Fail for mismatched parens
577           if (parens.empty()) {
578             strm.clear(std::ios::failbit);
579             return strm;
580           }
581           parens.pop();
582         }
583         c = strm.get();
584       }
585       istringstream sstrm(s);
586       sstrm >> p;
587     }
588 
589     c = strm.get();
590 
591     { // Read weight
592       stack<int> parens;
593       string s;
594       while (c != separator || !parens.empty()) {
595         if (c == EOF) {
596           strm.clear(std::ios::badbit);
597           return strm;
598         }
599         s += c;
600         // If parens encountered before separator, they must be matched
601         if (c == open_paren) {
602           parens.push(1);
603         } else if (c == close_paren) {
604           // Fail for mismatched parens
605           if (parens.empty()) {
606             strm.clear(std::ios::failbit);
607             return strm;
608           }
609           parens.pop();
610         }
611         c = strm.get();
612       }
613       istringstream sstrm(s);
614       sstrm >> r;
615     }
616 
617     w.Push(p, r);
618   }
619 
620   if (c != separator) {
621     FSTERROR() << " separator expected, not found! ";
622     strm.clear(std::ios::badbit);
623     return strm;
624   }
625 
626   c = strm.get();
627   if (c != close_paren) {
628     FSTERROR() << " is fst_weight_parentheses flag set correcty? ";
629     strm.clear(std::ios::badbit);
630     return strm;
631   }
632 
633   return strm;
634 }
635 
636 
637 
638 }  // namespace fst
639 
640 #endif  // FST_LIB_SPARSE_TUPLE_WEIGHT_H__
641