1 // shortest-distance.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Functions and classes to find shortest distance in an FST.
20 
21 #ifndef FST_LIB_SHORTEST_DISTANCE_H__
22 #define FST_LIB_SHORTEST_DISTANCE_H__
23 
24 #include <deque>
25 using std::deque;
26 #include <vector>
27 using std::vector;
28 
29 #include <fst/arcfilter.h>
30 #include <fst/cache.h>
31 #include <fst/queue.h>
32 #include <fst/reverse.h>
33 #include <fst/test-properties.h>
34 
35 
36 namespace fst {
37 
38 template <class Arc, class Queue, class ArcFilter>
39 struct ShortestDistanceOptions {
40   typedef typename Arc::StateId StateId;
41 
42   Queue *state_queue;    // Queue discipline used; owned by caller
43   ArcFilter arc_filter;  // Arc filter (e.g., limit to only epsilon graph)
44   StateId source;        // If kNoStateId, use the Fst's initial state
45   float delta;           // Determines the degree of convergence required
46   bool first_path;       // For a semiring with the path property (o.w.
47                          // undefined), compute the shortest-distances along
48                          // along the first path to a final state found
49                          // by the algorithm. That path is the shortest-path
50                          // only if the FST has a unique final state (or all
51                          // the final states have the same final weight), the
52                          // queue discipline is shortest-first and all the
53                          // weights in the FST are between One() and Zero()
54                          // according to NaturalLess.
55 
56   ShortestDistanceOptions(Queue *q, ArcFilter filt, StateId src = kNoStateId,
57                           float d = kDelta)
state_queueShortestDistanceOptions58       : state_queue(q), arc_filter(filt), source(src), delta(d),
59         first_path(false) {}
60 };
61 
62 
63 // Computation state of the shortest-distance algorithm. Reusable
64 // information is maintained across calls to member function
65 // ShortestDistance(source) when 'retain' is true for improved
66 // efficiency when calling multiple times from different source states
67 // (e.g., in epsilon removal). Contrary to usual conventions, 'fst'
68 // may not be freed before this class. Vector 'distance' should not be
69 // modified by the user between these calls.
70 // The Error() method returns true if an error was encountered.
71 template<class Arc, class Queue, class ArcFilter>
72 class ShortestDistanceState {
73  public:
74   typedef typename Arc::StateId StateId;
75   typedef typename Arc::Weight Weight;
76 
ShortestDistanceState(const Fst<Arc> & fst,vector<Weight> * distance,const ShortestDistanceOptions<Arc,Queue,ArcFilter> & opts,bool retain)77   ShortestDistanceState(
78       const Fst<Arc> &fst,
79       vector<Weight> *distance,
80       const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts,
81       bool retain)
82       : fst_(fst), distance_(distance), state_queue_(opts.state_queue),
83         arc_filter_(opts.arc_filter), delta_(opts.delta),
84         first_path_(opts.first_path), retain_(retain), source_id_(0),
85         error_(false) {
86     distance_->clear();
87   }
88 
~ShortestDistanceState()89   ~ShortestDistanceState() {}
90 
91   void ShortestDistance(StateId source);
92 
Error()93   bool Error() const { return error_; }
94 
95  private:
96   const Fst<Arc> &fst_;
97   vector<Weight> *distance_;
98   Queue *state_queue_;
99   ArcFilter arc_filter_;
100   float delta_;
101   bool first_path_;
102   bool retain_;               // Retain and reuse information across calls
103 
104   vector<Weight> rdistance_;  // Relaxation distance.
105   vector<bool> enqueued_;     // Is state enqueued?
106   vector<StateId> sources_;   // Source ID for ith state in 'distance_',
107                               //  'rdistance_', and 'enqueued_' if retained.
108   StateId source_id_;         // Unique ID characterizing each call to SD
109 
110   bool error_;
111 };
112 
113 // Compute the shortest distance. If 'source' is kNoStateId, use
114 // the initial state of the Fst.
115 template <class Arc, class Queue, class ArcFilter>
ShortestDistance(StateId source)116 void ShortestDistanceState<Arc, Queue, ArcFilter>::ShortestDistance(
117     StateId source) {
118   if (fst_.Start() == kNoStateId) {
119     if (fst_.Properties(kError, false)) error_ = true;
120     return;
121   }
122 
123   if (!(Weight::Properties() & kRightSemiring)) {
124     FSTERROR() << "ShortestDistance: Weight needs to be right distributive: "
125                << Weight::Type();
126     error_ = true;
127     return;
128   }
129 
130   if (first_path_ && !(Weight::Properties() & kPath)) {
131     FSTERROR() << "ShortestDistance: first_path option disallowed when "
132                << "Weight does not have the path property: "
133                << Weight::Type();
134     error_ = true;
135     return;
136   }
137 
138   state_queue_->Clear();
139 
140   if (!retain_) {
141     distance_->clear();
142     rdistance_.clear();
143     enqueued_.clear();
144   }
145 
146   if (source == kNoStateId)
147     source = fst_.Start();
148 
149   while (distance_->size() <= source) {
150     distance_->push_back(Weight::Zero());
151     rdistance_.push_back(Weight::Zero());
152     enqueued_.push_back(false);
153   }
154   if (retain_) {
155     while (sources_.size() <= source)
156       sources_.push_back(kNoStateId);
157     sources_[source] = source_id_;
158   }
159   (*distance_)[source] = Weight::One();
160   rdistance_[source] = Weight::One();
161   enqueued_[source] = true;
162 
163   state_queue_->Enqueue(source);
164 
165   while (!state_queue_->Empty()) {
166     StateId s = state_queue_->Head();
167     state_queue_->Dequeue();
168     while (distance_->size() <= s) {
169       distance_->push_back(Weight::Zero());
170       rdistance_.push_back(Weight::Zero());
171       enqueued_.push_back(false);
172     }
173     if (first_path_ && (fst_.Final(s) != Weight::Zero()))
174       break;
175     enqueued_[s] = false;
176     Weight r = rdistance_[s];
177     rdistance_[s] = Weight::Zero();
178     for (ArcIterator< Fst<Arc> > aiter(fst_, s);
179          !aiter.Done();
180          aiter.Next()) {
181       const Arc &arc = aiter.Value();
182       if (!arc_filter_(arc))
183         continue;
184       while (distance_->size() <= arc.nextstate) {
185         distance_->push_back(Weight::Zero());
186         rdistance_.push_back(Weight::Zero());
187         enqueued_.push_back(false);
188       }
189       if (retain_) {
190         while (sources_.size() <= arc.nextstate)
191           sources_.push_back(kNoStateId);
192         if (sources_[arc.nextstate] != source_id_) {
193           (*distance_)[arc.nextstate] = Weight::Zero();
194           rdistance_[arc.nextstate] = Weight::Zero();
195           enqueued_[arc.nextstate] = false;
196           sources_[arc.nextstate] = source_id_;
197         }
198       }
199       Weight &nd = (*distance_)[arc.nextstate];
200       Weight &nr = rdistance_[arc.nextstate];
201       Weight w = Times(r, arc.weight);
202       if (!ApproxEqual(nd, Plus(nd, w), delta_)) {
203         nd = Plus(nd, w);
204         nr = Plus(nr, w);
205         if (!nd.Member() || !nr.Member()) {
206           error_ = true;
207           return;
208         }
209         if (!enqueued_[arc.nextstate]) {
210           state_queue_->Enqueue(arc.nextstate);
211           enqueued_[arc.nextstate] = true;
212         } else {
213           state_queue_->Update(arc.nextstate);
214         }
215       }
216     }
217   }
218   ++source_id_;
219   if (fst_.Properties(kError, false)) error_ = true;
220 }
221 
222 
223 // Shortest-distance algorithm: this version allows fine control
224 // via the options argument. See below for a simpler interface.
225 //
226 // This computes the shortest distance from the 'opts.source' state to
227 // each visited state S and stores the value in the 'distance' vector.
228 // An unvisited state S has distance Zero(), which will be stored in
229 // the 'distance' vector if S is less than the maximum visited state.
230 // The state queue discipline, arc filter, and convergence delta are
231 // taken in the options argument.
232 // The 'distance' vector will contain a unique element for which
233 // Member() is false if an error was encountered.
234 //
235 // The weights must must be right distributive and k-closed (i.e., 1 +
236 // x + x^2 + ... + x^(k +1) = 1 + x + x^2 + ... + x^k).
237 //
238 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
239 // Shortest-Distance Problems", Journal of Automata, Languages and
240 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
241 // depends on the properties of the semiring and the queue discipline
242 // used. Refer to the paper for more details.
243 template<class Arc, class Queue, class ArcFilter>
ShortestDistance(const Fst<Arc> & fst,vector<typename Arc::Weight> * distance,const ShortestDistanceOptions<Arc,Queue,ArcFilter> & opts)244 void ShortestDistance(
245     const Fst<Arc> &fst,
246     vector<typename Arc::Weight> *distance,
247     const ShortestDistanceOptions<Arc, Queue, ArcFilter> &opts) {
248 
249   ShortestDistanceState<Arc, Queue, ArcFilter>
250     sd_state(fst, distance, opts, false);
251   sd_state.ShortestDistance(opts.source);
252   if (sd_state.Error()) {
253     distance->clear();
254     distance->resize(1, Arc::Weight::NoWeight());
255   }
256 }
257 
258 // Shortest-distance algorithm: simplified interface. See above for a
259 // version that allows finer control.
260 //
261 // If 'reverse' is false, this computes the shortest distance from the
262 // initial state to each state S and stores the value in the
263 // 'distance' vector. If 'reverse' is true, this computes the shortest
264 // distance from each state to the final states.  An unvisited state S
265 // has distance Zero(), which will be stored in the 'distance' vector
266 // if S is less than the maximum visited state.  The state queue
267 // discipline is automatically-selected.
268 // The 'distance' vector will contain a unique element for which
269 // Member() is false if an error was encountered.
270 //
271 // The weights must must be right (left) distributive if reverse is
272 // false (true) and k-closed (i.e., 1 + x + x^2 + ... + x^(k +1) = 1 +
273 // x + x^2 + ... + x^k).
274 //
275 // Arc weights must satisfy the property that the sum of the weights of one or
276 // more paths from some state S to T is never Zero(). In particular, arc weights
277 // are never Zero().
278 //
279 // The algorithm is from Mohri, "Semiring Framweork and Algorithms for
280 // Shortest-Distance Problems", Journal of Automata, Languages and
281 // Combinatorics 7(3):321-350, 2002. The complexity of algorithm
282 // depends on the properties of the semiring and the queue discipline
283 // used. Refer to the paper for more details.
284 template<class Arc>
285 void ShortestDistance(const Fst<Arc> &fst,
286                       vector<typename Arc::Weight> *distance,
287                       bool reverse = false,
288                       float delta = kDelta) {
289   typedef typename Arc::StateId StateId;
290   typedef typename Arc::Weight Weight;
291 
292   if (!reverse) {
293     AnyArcFilter<Arc> arc_filter;
294     AutoQueue<StateId> state_queue(fst, distance, arc_filter);
295     ShortestDistanceOptions< Arc, AutoQueue<StateId>, AnyArcFilter<Arc> >
296       opts(&state_queue, arc_filter);
297     opts.delta = delta;
298     ShortestDistance(fst, distance, opts);
299   } else {
300     typedef ReverseArc<Arc> ReverseArc;
301     typedef typename ReverseArc::Weight ReverseWeight;
302     AnyArcFilter<ReverseArc> rarc_filter;
303     VectorFst<ReverseArc> rfst;
304     Reverse(fst, &rfst);
305     vector<ReverseWeight> rdistance;
306     AutoQueue<StateId> state_queue(rfst, &rdistance, rarc_filter);
307     ShortestDistanceOptions< ReverseArc, AutoQueue<StateId>,
308       AnyArcFilter<ReverseArc> >
309       ropts(&state_queue, rarc_filter);
310     ropts.delta = delta;
311     ShortestDistance(rfst, &rdistance, ropts);
312     distance->clear();
313     if (rdistance.size() == 1 && !rdistance[0].Member()) {
314       distance->resize(1, Arc::Weight::NoWeight());
315       return;
316     }
317     while (distance->size() < rdistance.size() - 1)
318       distance->push_back(rdistance[distance->size() + 1].Reverse());
319   }
320 }
321 
322 
323 // Return the sum of the weight of all successful paths in an FST, i.e.,
324 // the shortest-distance from the initial state to the final states.
325 // Returns a weight such that Member() is false if an error was encountered.
326 template <class Arc>
327 typename Arc::Weight ShortestDistance(const Fst<Arc> &fst, float delta = kDelta) {
328   typedef typename Arc::Weight Weight;
329   typedef typename Arc::StateId StateId;
330   vector<Weight> distance;
331   if (Weight::Properties() & kRightSemiring) {
332     ShortestDistance(fst, &distance, false, delta);
333     if (distance.size() == 1 && !distance[0].Member())
334       return Arc::Weight::NoWeight();
335     Weight sum = Weight::Zero();
336     for (StateId s = 0; s < distance.size(); ++s)
337       sum = Plus(sum, Times(distance[s], fst.Final(s)));
338     return sum;
339   } else {
340     ShortestDistance(fst, &distance, true, delta);
341     StateId s = fst.Start();
342     if (distance.size() == 1 && !distance[0].Member())
343       return Arc::Weight::NoWeight();
344     return s != kNoStateId && s < distance.size() ?
345         distance[s] : Weight::Zero();
346   }
347 }
348 
349 
350 }  // namespace fst
351 
352 #endif  // FST_LIB_SHORTEST_DISTANCE_H__
353