1 // paren.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // Common classes for PDT parentheses
19 
20 // \file
21 
22 #ifndef FST_EXTENSIONS_PDT_PAREN_H_
23 #define FST_EXTENSIONS_PDT_PAREN_H_
24 
25 #include <algorithm>
26 #include <unordered_map>
27 using std::unordered_map;
28 using std::unordered_multimap;
29 #include <unordered_set>
30 using std::unordered_set;
31 using std::unordered_multiset;
32 #include <set>
33 
34 #include <fst/extensions/pdt/pdt.h>
35 #include <fst/extensions/pdt/collection.h>
36 #include <fst/fst.h>
37 #include <fst/dfs-visit.h>
38 
39 
40 namespace fst {
41 
42 //
43 // ParenState: Pair of an open (close) parenthesis and
44 // its destination (source) state.
45 //
46 
47 template <class A>
48 class ParenState {
49  public:
50   typedef typename A::Label Label;
51   typedef typename A::StateId StateId;
52 
53   struct Hash {
operatorHash54     size_t operator()(const ParenState<A> &p) const {
55       return p.paren_id + p.state_id * kPrime;
56     }
57   };
58 
59   Label paren_id;     // ID of open (close) paren
60   StateId state_id;   // destination (source) state of open (close) paren
61 
ParenState()62   ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}
63 
ParenState(Label p,StateId s)64   ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}
65 
66   bool operator==(const ParenState<A> &p) const {
67     if (&p == this)
68       return true;
69     return p.paren_id == this->paren_id && p.state_id == this->state_id;
70   }
71 
72   bool operator!=(const ParenState<A> &p) const { return !(p == *this); }
73 
74   bool operator<(const ParenState<A> &p) const {
75     return paren_id < this->paren.id ||
76         (p.paren_id == this->paren.id && p.state_id < this->state_id);
77   }
78 
79  private:
80   static const size_t kPrime;
81 };
82 
83 template <class A>
84 const size_t ParenState<A>::kPrime = 7853;
85 
86 
// Wraps an STL (multi)map and a starting position in an FST-style iterator
// (Done()/Value()/Next()/Reset()).
//
// Iteration begins at the position handed to the constructor and is Done()
// as soon as the end of the map is reached or an entry whose key differs
// from the key at the starting position is encountered. Value() yields a
// copy of the mapped value of the current entry. The map itself is held by
// reference only.
template <class M>
class MapIterator {
 public:
  typedef typename M::const_iterator StlIterator;
  typedef typename M::value_type PairType;
  typedef typename PairType::second_type ValueType;

  MapIterator(const M &m, StlIterator iter)
      : container_(m), first_(iter), cursor_(iter) {}

  // True once the cursor has left the run of entries sharing the first key.
  bool Done() const {
    if (cursor_ == container_.end())
      return true;
    return cursor_->first != first_->first;
  }

  // Mapped value at the cursor; only meaningful while !Done().
  ValueType Value() const { return (*cursor_).second; }

  // Advances the cursor by one entry.
  void Next() { ++cursor_; }

  // Moves the cursor back to the starting position.
  void Reset() { cursor_ = first_; }

 private:
  const M &container_;    // underlying map (not owned)
  StlIterator first_;     // starting position; its key defines the run
  StlIterator cursor_;    // current position
};
111 
112 //
113 // PdtParenReachable: Provides various parenthesis reachability information
114 // on a PDT.
115 //
116 
117 template <class A>
118 class PdtParenReachable {
119  public:
120   typedef typename A::StateId StateId;
121   typedef typename A::Label Label;
122  public:
123   // Maps from state ID to reachable paren IDs from (to) that state.
124   typedef unordered_multimap<StateId, Label> ParenMultiMap;
125 
126   // Maps from paren ID and state ID to reachable state set ID
127   typedef unordered_map<ParenState<A>, ssize_t,
128                    typename ParenState<A>::Hash> StateSetMap;
129 
130   // Maps from paren ID and state ID to arcs exiting that state with that
131   // Label.
132   typedef unordered_multimap<ParenState<A>, A,
133                         typename ParenState<A>::Hash> ParenArcMultiMap;
134 
135   typedef MapIterator<ParenMultiMap> ParenIterator;
136 
137   typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
138 
139   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
140 
141   // Computes close (open) parenthesis reachabilty information for
142   // a PDT with bounded stack.
PdtParenReachable(const Fst<A> & fst,const vector<pair<Label,Label>> & parens,bool close)143   PdtParenReachable(const Fst<A> &fst,
144                     const vector<pair<Label, Label> > &parens, bool close)
145       : fst_(fst),
146         parens_(parens),
147         close_(close),
148         error_(false) {
149     for (Label i = 0; i < parens.size(); ++i) {
150       const pair<Label, Label>  &p = parens[i];
151       paren_id_map_[p.first] = i;
152       paren_id_map_[p.second] = i;
153     }
154 
155     if (close_) {
156       StateId start = fst.Start();
157       if (start == kNoStateId)
158         return;
159       if (!DFSearch(start)) {
160         FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
161         error_ = true;
162       }
163     } else {
164       FSTERROR() << "PdtParenReachable: open paren info not implemented";
165       error_ = true;
166     }
167   }
168 
Error()169   bool const Error() { return error_; }
170 
171   // Given a state ID, returns an iterator over paren IDs
172   // for close (open) parens reachable from that state along balanced
173   // paths.
FindParens(StateId s)174   ParenIterator FindParens(StateId s) const {
175     return ParenIterator(paren_multimap_, paren_multimap_.find(s));
176   }
177 
178   // Given a paren ID and a state ID s, returns an iterator over
179   // states that can be reached along balanced paths from (to) s that
180   // have have close (open) parentheses matching the paren ID exiting
181   // (entering) those states.
FindStates(Label paren_id,StateId s)182   SetIterator FindStates(Label paren_id, StateId s) const {
183     ParenState<A> paren_state(paren_id, s);
184     typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
185     if (id_it == set_map_.end()) {
186       return state_sets_.FindSet(-1);
187     } else {
188       return state_sets_.FindSet(id_it->second);
189     }
190   }
191 
192   // Given a paren Id and a state ID s, return an iterator over
193   // arcs that exit (enter) s and are labeled with a close (open)
194   // parenthesis matching the paren ID.
FindParenArcs(Label paren_id,StateId s)195   ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
196     ParenState<A> paren_state(paren_id, s);
197     return ParenArcIterator(paren_arc_multimap_,
198                             paren_arc_multimap_.find(paren_state));
199   }
200 
201  private:
202   // DFS that gathers paren and state set information.
203   // Bool returns false when cycle detected.
204   bool DFSearch(StateId s);
205 
206   // Unions state sets together gathered by the DFS.
207   void ComputeStateSet(StateId s);
208 
209   // Gather state set(s) from state 'nexts'.
210   void UpdateStateSet(StateId nexts, set<Label> *paren_set,
211                       vector< set<StateId> > *state_sets) const;
212 
213   const Fst<A> &fst_;
214   const vector<pair<Label, Label> > &parens_;         // Paren ID -> Labels
215   bool close_;                                        // Close/open paren info?
216   unordered_map<Label, Label> paren_id_map_;               // Paren labels -> ID
217   ParenMultiMap paren_multimap_;                      // Paren reachability
218   ParenArcMultiMap paren_arc_multimap_;               // Paren Arcs
219   vector<char> state_color_;                          // DFS state
220   mutable Collection<ssize_t, StateId> state_sets_;   // Reachable states -> ID
221   StateSetMap set_map_;                               // ID -> Reachable states
222   bool error_;
223   DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
224 };
225 
226 // DFS that gathers paren and state set information.
227 template <class A>
DFSearch(StateId s)228 bool PdtParenReachable<A>::DFSearch(StateId s) {
229   if (s >= state_color_.size())
230     state_color_.resize(s + 1, kDfsWhite);
231 
232   if (state_color_[s] == kDfsBlack)
233     return true;
234 
235   if (state_color_[s] == kDfsGrey)
236     return false;
237 
238   state_color_[s] = kDfsGrey;
239 
240   for (ArcIterator<Fst<A> > aiter(fst_, s);
241        !aiter.Done();
242        aiter.Next()) {
243     const A &arc = aiter.Value();
244 
245     typename unordered_map<Label, Label>::const_iterator pit
246         = paren_id_map_.find(arc.ilabel);
247     if (pit != paren_id_map_.end()) {               // paren?
248       Label paren_id = pit->second;
249       if (arc.ilabel == parens_[paren_id].first) {  // open paren
250         if (!DFSearch(arc.nextstate))
251           return false;
252         for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
253              !set_iter.Done(); set_iter.Next()) {
254           for (ParenArcIterator paren_arc_iter =
255                    FindParenArcs(paren_id, set_iter.Element());
256                !paren_arc_iter.Done();
257                paren_arc_iter.Next()) {
258             const A &cparc = paren_arc_iter.Value();
259             if (!DFSearch(cparc.nextstate))
260               return false;
261           }
262         }
263       }
264     } else {                                       // non-paren
265       if(!DFSearch(arc.nextstate))
266         return false;
267     }
268   }
269   ComputeStateSet(s);
270   state_color_[s] = kDfsBlack;
271   return true;
272 }
273 
274 // Unions state sets together gathered by the DFS.
275 template <class A>
ComputeStateSet(StateId s)276 void PdtParenReachable<A>::ComputeStateSet(StateId s) {
277   set<Label> paren_set;
278   vector< set<StateId> > state_sets(parens_.size());
279   for (ArcIterator< Fst<A> > aiter(fst_, s);
280        !aiter.Done();
281        aiter.Next()) {
282     const A &arc = aiter.Value();
283 
284     typename unordered_map<Label, Label>::const_iterator pit
285         = paren_id_map_.find(arc.ilabel);
286     if (pit != paren_id_map_.end()) {               // paren?
287       Label paren_id = pit->second;
288       if (arc.ilabel == parens_[paren_id].first) {  // open paren
289         for (SetIterator set_iter =
290                  FindStates(paren_id, arc.nextstate);
291              !set_iter.Done(); set_iter.Next()) {
292           for (ParenArcIterator paren_arc_iter =
293                    FindParenArcs(paren_id, set_iter.Element());
294                !paren_arc_iter.Done();
295                paren_arc_iter.Next()) {
296             const A &cparc = paren_arc_iter.Value();
297             UpdateStateSet(cparc.nextstate, &paren_set, &state_sets);
298           }
299         }
300       } else {                                      // close paren
301         paren_set.insert(paren_id);
302         state_sets[paren_id].insert(s);
303         ParenState<A> paren_state(paren_id, s);
304         paren_arc_multimap_.insert(make_pair(paren_state, arc));
305       }
306     } else {                                        // non-paren
307       UpdateStateSet(arc.nextstate, &paren_set, &state_sets);
308     }
309   }
310 
311   vector<StateId> state_set;
312   for (typename set<Label>::iterator paren_iter = paren_set.begin();
313        paren_iter != paren_set.end(); ++paren_iter) {
314     state_set.clear();
315     Label paren_id = *paren_iter;
316     paren_multimap_.insert(make_pair(s, paren_id));
317     for (typename set<StateId>::iterator state_iter
318              = state_sets[paren_id].begin();
319          state_iter != state_sets[paren_id].end();
320          ++state_iter) {
321       state_set.push_back(*state_iter);
322     }
323     ParenState<A> paren_state(paren_id, s);
324     set_map_[paren_state] = state_sets_.FindId(state_set);
325   }
326 }
327 
328 // Gather state set(s) from state 'nexts'.
329 template <class A>
UpdateStateSet(StateId nexts,set<Label> * paren_set,vector<set<StateId>> * state_sets)330 void PdtParenReachable<A>::UpdateStateSet(
331     StateId nexts, set<Label> *paren_set,
332     vector< set<StateId> > *state_sets) const {
333   for(ParenIterator paren_iter = FindParens(nexts);
334       !paren_iter.Done(); paren_iter.Next()) {
335     Label paren_id = paren_iter.Value();
336     paren_set->insert(paren_id);
337     for (SetIterator set_iter = FindStates(paren_id, nexts);
338          !set_iter.Done(); set_iter.Next()) {
339       (*state_sets)[paren_id].insert(set_iter.Element());
340     }
341   }
342 }
343 
344 
345 // Store balancing parenthesis data for a PDT. Allows on-the-fly
346 // construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
347 template <class A>
348 class PdtBalanceData {
349  public:
350   typedef typename A::StateId StateId;
351   typedef typename A::Label Label;
352 
353   // Hash set for open parens
354   typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
355 
356   // Maps from open paren destination state to parenthesis ID.
357   typedef unordered_multimap<StateId, Label> OpenParenMap;
358 
359   // Maps from open paren state to source states of matching close parens
360   typedef unordered_multimap<ParenState<A>, StateId,
361                         typename ParenState<A>::Hash> CloseParenMap;
362 
363   // Maps from open paren state to close source set ID
364   typedef unordered_map<ParenState<A>, ssize_t,
365                    typename ParenState<A>::Hash> CloseSourceMap;
366 
367   typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
368 
PdtBalanceData()369   PdtBalanceData() {}
370 
Clear()371   void Clear() {
372     open_paren_map_.clear();
373     close_paren_map_.clear();
374   }
375 
376   // Adds an open parenthesis with destination state 'open_dest'.
OpenInsert(Label paren_id,StateId open_dest)377   void OpenInsert(Label paren_id, StateId open_dest) {
378     ParenState<A> key(paren_id, open_dest);
379     if (!open_paren_set_.count(key)) {
380       open_paren_set_.insert(key);
381       open_paren_map_.insert(make_pair(open_dest, paren_id));
382     }
383   }
384 
385   // Adds a matching closing parenthesis with source state
386   // 'close_source' that balances an open_parenthesis with destination
387   // state 'open_dest' if OpenInsert() previously called
388   // (o.w. CloseInsert() does nothing).
CloseInsert(Label paren_id,StateId open_dest,StateId close_source)389   void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
390     ParenState<A> key(paren_id, open_dest);
391     if (open_paren_set_.count(key))
392       close_paren_map_.insert(make_pair(key, close_source));
393   }
394 
395   // Find close paren source states matching an open parenthesis.
396   // Methods that follow, iterate through those matching states.
397   // Should be called only after FinishInsert(open_dest).
Find(Label paren_id,StateId open_dest)398   SetIterator Find(Label paren_id, StateId open_dest) {
399     ParenState<A> close_key(paren_id, open_dest);
400     typename CloseSourceMap::const_iterator id_it =
401         close_source_map_.find(close_key);
402     if (id_it == close_source_map_.end()) {
403       return close_source_sets_.FindSet(-1);
404     } else {
405       return close_source_sets_.FindSet(id_it->second);
406     }
407   }
408 
409   // Call when all open and close parenthesis insertions wrt open
410   // parentheses entering 'open_dest' are finished. Must be called
411   // before Find(open_dest). Stores close paren source state sets
412   // efficiently.
FinishInsert(StateId open_dest)413   void FinishInsert(StateId open_dest) {
414     vector<StateId> close_sources;
415     for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
416          oit != open_paren_map_.end() && oit->first == open_dest;) {
417       Label paren_id = oit->second;
418       close_sources.clear();
419       ParenState<A> okey(paren_id, open_dest);
420       open_paren_set_.erase(open_paren_set_.find(okey));
421       for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
422            cit != close_paren_map_.end() && cit->first == okey;) {
423         close_sources.push_back(cit->second);
424         close_paren_map_.erase(cit++);
425       }
426       sort(close_sources.begin(), close_sources.end());
427       typename vector<StateId>::iterator unique_end =
428           unique(close_sources.begin(), close_sources.end());
429       close_sources.resize(unique_end - close_sources.begin());
430 
431       if (!close_sources.empty())
432         close_source_map_[okey] = close_source_sets_.FindId(close_sources);
433       open_paren_map_.erase(oit++);
434     }
435   }
436 
437   // Return a new balance data object representing the reversed balance
438   // information.
439   PdtBalanceData<A> *Reverse(StateId num_states,
440                                StateId num_split,
441                                StateId state_id_shift) const;
442 
443  private:
444   OpenParenSet open_paren_set_;                      // open par. at dest?
445 
446   OpenParenMap open_paren_map_;                      // open parens per state
447   ParenState<A> open_dest_;                          // cur open dest. state
448   typename OpenParenMap::const_iterator open_iter_;  // cur open parens/state
449 
450   CloseParenMap close_paren_map_;                    // close states/open
451                                                      //  paren and state
452 
453   CloseSourceMap close_source_map_;                  // paren, state to set ID
454   mutable Collection<ssize_t, StateId> close_source_sets_;
455 };
456 
457 // Return a new balance data object representing the reversed balance
458 // information.
459 template <class A>
Reverse(StateId num_states,StateId num_split,StateId state_id_shift)460 PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
461     StateId num_states,
462     StateId num_split,
463     StateId state_id_shift) const {
464   PdtBalanceData<A> *bd = new PdtBalanceData<A>;
465   unordered_set<StateId> close_sources;
466   StateId split_size = num_states / num_split;
467 
468   for (StateId i = 0; i < num_states; i+= split_size) {
469     close_sources.clear();
470 
471     for (typename CloseSourceMap::const_iterator
472              sit = close_source_map_.begin();
473          sit != close_source_map_.end();
474          ++sit) {
475       ParenState<A> okey = sit->first;
476       StateId open_dest = okey.state_id;
477       Label paren_id = okey.paren_id;
478       for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
479            !set_iter.Done(); set_iter.Next()) {
480         StateId close_source = set_iter.Element();
481         if ((close_source < i) || (close_source >= i + split_size))
482           continue;
483         close_sources.insert(close_source + state_id_shift);
484         bd->OpenInsert(paren_id, close_source + state_id_shift);
485         bd->CloseInsert(paren_id, close_source + state_id_shift,
486                         open_dest + state_id_shift);
487       }
488     }
489 
490     for (typename unordered_set<StateId>::const_iterator it
491              = close_sources.begin();
492          it != close_sources.end();
493          ++it) {
494       bd->FinishInsert(*it);
495     }
496 
497   }
498   return bd;
499 }
500 
501 
502 }  // namespace fst
503 
504 #endif  // FST_EXTENSIONS_PDT_PAREN_H_
505