1 // connect.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Classes and functions to remove unsuccessful paths from an Fst.
20 
21 #ifndef FST_LIB_CONNECT_H__
22 #define FST_LIB_CONNECT_H__
23 
24 #include <vector>
25 using std::vector;
26 
27 #include <fst/dfs-visit.h>
28 #include <fst/union-find.h>
29 #include <fst/mutable-fst.h>
30 
31 
32 namespace fst {
33 
34 // Finds and returns connected components. Use with Visit().
35 template <class A>
36 class CcVisitor {
37  public:
38   typedef A Arc;
39   typedef typename Arc::Weight Weight;
40   typedef typename A::StateId StateId;
41 
42   // cc[i]: connected component number for state i.
CcVisitor(vector<StateId> * cc)43   CcVisitor(vector<StateId> *cc)
44       : comps_(new UnionFind<StateId>(0, kNoStateId)),
45         cc_(cc),
46         nstates_(0) { }
47 
48   // comps: connected components equiv classes.
CcVisitor(UnionFind<StateId> * comps)49   CcVisitor(UnionFind<StateId> *comps)
50       : comps_(comps),
51         cc_(0),
52         nstates_(0) { }
53 
~CcVisitor()54   ~CcVisitor() {
55     if (cc_)  // own comps_?
56       delete comps_;
57   }
58 
InitVisit(const Fst<A> & fst)59   void InitVisit(const Fst<A> &fst) { }
60 
InitState(StateId s,StateId root)61   bool InitState(StateId s, StateId root) {
62     ++nstates_;
63     if (comps_->FindSet(s) == kNoStateId)
64       comps_->MakeSet(s);
65     return true;
66   }
67 
WhiteArc(StateId s,const A & arc)68   bool WhiteArc(StateId s, const A &arc) {
69     comps_->MakeSet(arc.nextstate);
70     comps_->Union(s, arc.nextstate);
71     return true;
72   }
73 
GreyArc(StateId s,const A & arc)74   bool GreyArc(StateId s, const A &arc) {
75     comps_->Union(s, arc.nextstate);
76     return true;
77   }
78 
BlackArc(StateId s,const A & arc)79   bool BlackArc(StateId s, const A &arc) {
80     comps_->Union(s, arc.nextstate);
81     return true;
82   }
83 
FinishState(StateId s)84   void FinishState(StateId s) { }
85 
FinishVisit()86   void FinishVisit() {
87     if (cc_)
88       GetCcVector(cc_);
89   }
90 
91   // cc[i]: connected component number for state i.
92   // Returns number of components.
GetCcVector(vector<StateId> * cc)93   int GetCcVector(vector<StateId> *cc) {
94     cc->clear();
95     cc->resize(nstates_, kNoStateId);
96     StateId ncomp = 0;
97     for (StateId i = 0; i < nstates_; ++i) {
98       StateId rep = comps_->FindSet(i);
99       StateId &comp = (*cc)[rep];
100       if (comp == kNoStateId) {
101         comp = ncomp;
102         ++ncomp;
103       }
104       (*cc)[i] = comp;
105     }
106     return ncomp;
107   }
108 
109  private:
110   UnionFind<StateId> *comps_;   // Components
111   vector<StateId> *cc_;         // State's cc number
112   StateId nstates_;             // State count
113 };
114 
115 
116 // Finds and returns strongly-connected components, accessible and
117 // coaccessible states and related properties. Uses Tarjan's single
118 // DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer
119 // Algorithms", 189pp). Use with DfsVisit();
120 template <class A>
121 class SccVisitor {
122  public:
123   typedef A Arc;
124   typedef typename A::Weight Weight;
125   typedef typename A::StateId StateId;
126 
127   // scc[i]: strongly-connected component number for state i.
128   //   SCC numbers will be in topological order for acyclic input.
129   // access[i]: accessibility of state i.
130   // coaccess[i]: coaccessibility of state i.
131   // Any of above can be NULL.
132   // props: related property bits (cyclicity, initial cyclicity,
133   //   accessibility, coaccessibility) set/cleared (o.w. unchanged).
SccVisitor(vector<StateId> * scc,vector<bool> * access,vector<bool> * coaccess,uint64 * props)134   SccVisitor(vector<StateId> *scc, vector<bool> *access,
135              vector<bool> *coaccess, uint64 *props)
136       : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {}
SccVisitor(uint64 * props)137   SccVisitor(uint64 *props)
138       : scc_(0), access_(0), coaccess_(0), props_(props) {}
139 
140   void InitVisit(const Fst<A> &fst);
141 
142   bool InitState(StateId s, StateId root);
143 
TreeArc(StateId s,const A & arc)144   bool TreeArc(StateId s, const A &arc) { return true; }
145 
BackArc(StateId s,const A & arc)146   bool BackArc(StateId s, const A &arc) {
147     StateId t = arc.nextstate;
148     if ((*dfnumber_)[t] < (*lowlink_)[s])
149       (*lowlink_)[s] = (*dfnumber_)[t];
150     if ((*coaccess_)[t])
151       (*coaccess_)[s] = true;
152     *props_ |= kCyclic;
153     *props_ &= ~kAcyclic;
154     if (arc.nextstate == start_) {
155       *props_ |= kInitialCyclic;
156       *props_ &= ~kInitialAcyclic;
157     }
158     return true;
159   }
160 
ForwardOrCrossArc(StateId s,const A & arc)161   bool ForwardOrCrossArc(StateId s, const A &arc) {
162     StateId t = arc.nextstate;
163     if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ &&
164         (*onstack_)[t] && (*dfnumber_)[t] < (*lowlink_)[s])
165       (*lowlink_)[s] = (*dfnumber_)[t];
166     if ((*coaccess_)[t])
167       (*coaccess_)[s] = true;
168     return true;
169   }
170 
171   void FinishState(StateId s, StateId p, const A *);
172 
FinishVisit()173   void FinishVisit() {
174     // Numbers SCC's in topological order when acyclic.
175     if (scc_)
176       for (StateId i = 0; i < scc_->size(); ++i)
177         (*scc_)[i] = nscc_ - 1 - (*scc_)[i];
178     if (coaccess_internal_)
179       delete coaccess_;
180     delete dfnumber_;
181     delete lowlink_;
182     delete onstack_;
183     delete scc_stack_;
184   }
185 
186  private:
187   vector<StateId> *scc_;        // State's scc number
188   vector<bool> *access_;        // State's accessibility
189   vector<bool> *coaccess_;      // State's coaccessibility
190   uint64 *props_;
191   const Fst<A> *fst_;
192   StateId start_;
193   StateId nstates_;             // State count
194   StateId nscc_;                // SCC count
195   bool coaccess_internal_;
196   vector<StateId> *dfnumber_;   // state discovery times
197   vector<StateId> *lowlink_;    // lowlink[s] == dfnumber[s] => SCC root
198   vector<bool> *onstack_;       // is a state on the SCC stack
199   vector<StateId> *scc_stack_;  // SCC stack (w/ random access)
200 };
201 
202 template <class A> inline
InitVisit(const Fst<A> & fst)203 void SccVisitor<A>::InitVisit(const Fst<A> &fst) {
204   if (scc_)
205     scc_->clear();
206   if (access_)
207     access_->clear();
208   if (coaccess_) {
209     coaccess_->clear();
210     coaccess_internal_ = false;
211   } else {
212     coaccess_ = new vector<bool>;
213     coaccess_internal_ = true;
214   }
215   *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible;
216   *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible);
217   fst_ = &fst;
218   start_ = fst.Start();
219   nstates_ = 0;
220   nscc_ = 0;
221   dfnumber_ = new vector<StateId>;
222   lowlink_ = new vector<StateId>;
223   onstack_ = new vector<bool>;
224   scc_stack_ = new vector<StateId>;
225 }
226 
227 template <class A> inline
InitState(StateId s,StateId root)228 bool SccVisitor<A>::InitState(StateId s, StateId root) {
229   scc_stack_->push_back(s);
230   while (dfnumber_->size() <= s) {
231     if (scc_)
232       scc_->push_back(-1);
233     if (access_)
234       access_->push_back(false);
235     coaccess_->push_back(false);
236     dfnumber_->push_back(-1);
237     lowlink_->push_back(-1);
238     onstack_->push_back(false);
239   }
240   (*dfnumber_)[s] = nstates_;
241   (*lowlink_)[s] = nstates_;
242   (*onstack_)[s] = true;
243   if (root == start_) {
244     if (access_)
245       (*access_)[s] = true;
246   } else {
247     if (access_)
248       (*access_)[s] = false;
249     *props_ |= kNotAccessible;
250     *props_ &= ~kAccessible;
251   }
252   ++nstates_;
253   return true;
254 }
255 
256 template <class A> inline
FinishState(StateId s,StateId p,const A *)257 void SccVisitor<A>::FinishState(StateId s, StateId p, const A *) {
258   if (fst_->Final(s) != Weight::Zero())
259     (*coaccess_)[s] = true;
260   if ((*dfnumber_)[s] == (*lowlink_)[s]) {  // root of new SCC
261     bool scc_coaccess = false;
262     size_t i = scc_stack_->size();
263     StateId t;
264     do {
265       t = (*scc_stack_)[--i];
266       if ((*coaccess_)[t])
267         scc_coaccess = true;
268     } while (s != t);
269     do {
270       t = scc_stack_->back();
271       if (scc_)
272         (*scc_)[t] = nscc_;
273       if (scc_coaccess)
274         (*coaccess_)[t] = true;
275       (*onstack_)[t] = false;
276       scc_stack_->pop_back();
277     } while (s != t);
278     if (!scc_coaccess) {
279       *props_ |= kNotCoAccessible;
280       *props_ &= ~kCoAccessible;
281     }
282     ++nscc_;
283   }
284   if (p != kNoStateId) {
285     if ((*coaccess_)[s])
286       (*coaccess_)[p] = true;
287     if ((*lowlink_)[s] < (*lowlink_)[p])
288       (*lowlink_)[p] = (*lowlink_)[s];
289   }
290 }
291 
292 
293 // Trims an FST, removing states and arcs that are not on successful
294 // paths. This version modifies its input.
295 //
296 // Complexity:
297 // - Time:  O(V + E)
298 // - Space: O(V + E)
299 // where V = # of states and E = # of arcs.
300 template<class Arc>
Connect(MutableFst<Arc> * fst)301 void Connect(MutableFst<Arc> *fst) {
302   typedef typename Arc::StateId StateId;
303 
304   vector<bool> access;
305   vector<bool> coaccess;
306   uint64 props = 0;
307   SccVisitor<Arc> scc_visitor(0, &access, &coaccess, &props);
308   DfsVisit(*fst, &scc_visitor);
309   vector<StateId> dstates;
310   for (StateId s = 0; s < access.size(); ++s)
311     if (!access[s] || !coaccess[s])
312       dstates.push_back(s);
313   fst->DeleteStates(dstates);
314   fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible);
315 }
316 
317 
318 // Returns an acyclic FST where each SCC in the input FST has been
319 // condensed to a single state with transitions between SCCs retained
320 // and within SCCs dropped.  Also returns the mapping from an input
321 // state 's' to an output state 'scc[s]'.
322 template<class Arc>
Condense(const Fst<Arc> & ifst,MutableFst<Arc> * ofst,vector<typename Arc::StateId> * scc)323 void Condense(const Fst<Arc> &ifst, MutableFst<Arc> *ofst,
324               vector<typename Arc::StateId> *scc) {
325   typedef typename Arc::StateId StateId;
326   typedef typename Arc::Weight Weight;
327 
328   ofst->DeleteStates();
329   uint64 props = 0;
330   SccVisitor<Arc> scc_visitor(scc, 0, 0, &props);
331   DfsVisit(ifst, &scc_visitor);
332   for (StateId s = 0; s < scc->size(); ++s) {
333     StateId c = (*scc)[s];
334     while (c >= ofst->NumStates())
335       ofst->AddState();
336     if (s == ifst.Start())
337       ofst->SetStart(c);
338     Weight final = ifst.Final(s);
339     if (final != Weight::Zero())
340       ofst->SetFinal(c, Plus(ofst->Final(c), final));
341     for (ArcIterator< Fst<Arc> > aiter(ifst, s);
342          !aiter.Done();
343          aiter.Next()) {
344       Arc arc = aiter.Value();
345       StateId nextc = (*scc)[arc.nextstate];
346       if (nextc != c) {
347         while (nextc >= ofst->NumStates())
348           ofst->AddState();
349         arc.nextstate = nextc;
350         ofst->AddArc(c, arc);
351       }
352     }
353   }
354   ofst->SetProperties(kAcyclic | kInitialAcyclic, kAcyclic | kInitialAcyclic);
355 }
356 
357 }  // namespace fst
358 
359 #endif  // FST_LIB_CONNECT_H__
360