1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions to find shortest paths in a PDT.
19 
20 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
21 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
22 
23 #include <stack>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27 
28 #include <fst/types.h>
29 #include <fst/log.h>
30 #include <fst/extensions/pdt/paren.h>
31 #include <fst/extensions/pdt/pdt.h>
32 #include <fst/shortest-path.h>
33 #include <unordered_map>
34 
35 namespace fst {
36 
template <class Arc, class Queue>
struct PdtShortestPathOptions {
  // If true, parenthesis labels are kept on the output path's arcs; if false,
  // they are replaced by epsilon (0).
  bool keep_parentheses;
  // If true, search data not reachable from any "final" search state is
  // garbage-collected once each sub-graph search finishes.
  bool path_gc;

  explicit PdtShortestPathOptions(bool keep_parentheses = false,
                                  bool path_gc = true)
      : keep_parentheses{keep_parentheses}, path_gc{path_gc} {}
};
46 
47 namespace internal {
48 
49 // Flags for shortest path data.
50 
51 constexpr uint8 kPdtInited = 0x01;
52 constexpr uint8 kPdtFinal = 0x02;
53 constexpr uint8 kPdtMarked = 0x04;
54 
55 // Stores shortest path tree info Distance(), Parent(), and ArcParent()
56 // information keyed on two types:
57 //
58 // 1. SearchState: This is a usual node in a shortest path tree but:
59 //    a. is w.r.t a PDT search state (a pair of a PDT state and a "start" state,
60 //    either the PDT start state or the destination state of an open
61 //    parenthesis).
62 //    b. the Distance() is from this "start" state to the search state.
63 //    c. Parent().state is kNoLabel for the "start" state.
64 //
65 // 2. ParenSpec: This connects shortest path trees depending on the the
66 // parenthesis taken. Given the parenthesis spec:
67 //    a. the Distance() is from the Parent() "start" state to the parenthesis
68 //    destination state.
69 //    b. The ArcParent() is the parenthesis arc.
70 template <class Arc>
71 class PdtShortestPathData {
72  public:
73   using Label = typename Arc::Label;
74   using StateId = typename Arc::StateId;
75   using Weight = typename Arc::Weight;
76 
77   struct SearchState {
78     StateId state;  // PDT state.
79     StateId start;  // PDT paren "start" state.
80 
81     explicit SearchState(StateId s = kNoStateId, StateId t = kNoStateId)
stateSearchState82         : state(s), start(t) {}
83 
84     bool operator==(const SearchState &other) const {
85       if (&other == this) return true;
86       return other.state == state && other.start == start;
87     }
88   };
89 
90   // Specifies paren ID, source and dest "start" states of a paren. These are
91   // the "start" states of the respective sub-graphs.
92   struct ParenSpec {
93     explicit ParenSpec(Label paren_id = kNoLabel,
94                        StateId src_start = kNoStateId,
95                        StateId dest_start = kNoStateId)
paren_idParenSpec96         : paren_id(paren_id), src_start(src_start), dest_start(dest_start) {}
97 
98     Label paren_id;
99     StateId src_start;   // Sub-graph "start" state for paren source.
100     StateId dest_start;  // Sub-graph "start" state for paren dest.
101 
102     bool operator==(const ParenSpec &other) const {
103       if (&other == this) return true;
104       return (other.paren_id == paren_id &&
105               other.src_start == other.src_start &&
106               other.dest_start == dest_start);
107     }
108   };
109 
110   struct SearchData {
SearchDataSearchData111     SearchData()
112         : distance(Weight::Zero()),
113           parent(kNoStateId, kNoStateId),
114           paren_id(kNoLabel),
115           flags(0) {}
116 
117     Weight distance;     // Distance to this state from PDT "start" state.
118     SearchState parent;  // Parent state in shortest path tree.
119     int16 paren_id;      // If parent arc has paren, paren ID (or kNoLabel).
120     uint8 flags;         // First byte reserved for PdtShortestPathData use.
121   };
122 
PdtShortestPathData(bool gc)123   explicit PdtShortestPathData(bool gc)
124       : gc_(gc), nstates_(0), ngc_(0), finished_(false) {}
125 
~PdtShortestPathData()126   ~PdtShortestPathData() {
127     VLOG(1) << "opm size: " << paren_map_.size();
128     VLOG(1) << "# of search states: " << nstates_;
129     if (gc_) VLOG(1) << "# of GC'd search states: " << ngc_;
130   }
131 
Clear()132   void Clear() {
133     search_map_.clear();
134     search_multimap_.clear();
135     paren_map_.clear();
136     state_ = SearchState(kNoStateId, kNoStateId);
137     nstates_ = 0;
138     ngc_ = 0;
139   }
140 
141   // TODO(kbg): Currently copying SearchState and passing a const reference to
142   // ParenSpec. Benchmark to confirm this is the right thing to do.
143 
Distance(SearchState s)144   Weight Distance(SearchState s) const { return GetSearchData(s)->distance; }
145 
Distance(const ParenSpec & paren)146   Weight Distance(const ParenSpec &paren) const {
147     return GetSearchData(paren)->distance;
148   }
149 
Parent(SearchState s)150   SearchState Parent(SearchState s) const { return GetSearchData(s)->parent; }
151 
Parent(const ParenSpec & paren)152   SearchState Parent(const ParenSpec &paren) const {
153     return GetSearchData(paren)->parent;
154   }
155 
ParenId(SearchState s)156   Label ParenId(SearchState s) const { return GetSearchData(s)->paren_id; }
157 
Flags(SearchState s)158   uint8 Flags(SearchState s) const { return GetSearchData(s)->flags; }
159 
SetDistance(SearchState s,Weight weight)160   void SetDistance(SearchState s, Weight weight) {
161     GetSearchData(s)->distance = std::move(weight);
162   }
163 
SetDistance(const ParenSpec & paren,Weight weight)164   void SetDistance(const ParenSpec &paren, Weight weight) {
165     GetSearchData(paren)->distance = std::move(weight);
166   }
167 
SetParent(SearchState s,SearchState p)168   void SetParent(SearchState s, SearchState p) { GetSearchData(s)->parent = p; }
169 
SetParent(const ParenSpec & paren,SearchState p)170   void SetParent(const ParenSpec &paren, SearchState p) {
171     GetSearchData(paren)->parent = p;
172   }
173 
SetParenId(SearchState s,Label p)174   void SetParenId(SearchState s, Label p) {
175     if (p >= 32768) {
176       FSTERROR() << "PdtShortestPathData: Paren ID does not fit in an int16";
177     }
178     GetSearchData(s)->paren_id = p;
179   }
180 
SetFlags(SearchState s,uint8 f,uint8 mask)181   void SetFlags(SearchState s, uint8 f, uint8 mask) {
182     auto *data = GetSearchData(s);
183     data->flags &= ~mask;
184     data->flags |= f & mask;
185   }
186 
187   void GC(StateId s);
188 
Finish()189   void Finish() { finished_ = true; }
190 
191  private:
192   // Hash for search state.
193   struct SearchStateHash {
operatorSearchStateHash194     size_t operator()(const SearchState &s) const {
195       static constexpr auto prime = 7853;
196       return s.state + s.start * prime;
197     }
198   };
199 
200   // Hash for paren map.
201   struct ParenHash {
operatorParenHash202     size_t operator()(const ParenSpec &paren) const {
203       static constexpr auto prime0 = 7853;
204       static constexpr auto prime1 = 7867;
205       return paren.paren_id + paren.src_start * prime0 +
206              paren.dest_start * prime1;
207     }
208   };
209 
210   using SearchMap =
211       std::unordered_map<SearchState, SearchData, SearchStateHash>;
212 
213   using SearchMultimap = std::unordered_multimap<StateId, StateId>;
214 
215   // Hash map from paren spec to open paren data.
216   using ParenMap = std::unordered_map<ParenSpec, SearchData, ParenHash>;
217 
GetSearchData(SearchState s)218   SearchData *GetSearchData(SearchState s) const {
219     if (s == state_) return state_data_;
220     if (finished_) {
221       auto it = search_map_.find(s);
222       if (it == search_map_.end()) return &null_search_data_;
223       state_ = s;
224       return state_data_ = &(it->second);
225     } else {
226       state_ = s;
227       state_data_ = &search_map_[s];
228       if (!(state_data_->flags & kPdtInited)) {
229         ++nstates_;
230         if (gc_) search_multimap_.insert(std::make_pair(s.start, s.state));
231         state_data_->flags = kPdtInited;
232       }
233       return state_data_;
234     }
235   }
236 
GetSearchData(ParenSpec paren)237   SearchData *GetSearchData(ParenSpec paren) const {
238     if (paren == paren_) return paren_data_;
239     if (finished_) {
240       auto it = paren_map_.find(paren);
241       if (it == paren_map_.end()) return &null_search_data_;
242       paren_ = paren;
243       return state_data_ = &(it->second);
244     } else {
245       paren_ = paren;
246       return paren_data_ = &paren_map_[paren];
247     }
248   }
249 
250   mutable SearchMap search_map_;            // Maps from search state to data.
251   mutable SearchMultimap search_multimap_;  // Maps from "start" to subgraph.
252   mutable ParenMap paren_map_;              // Maps paren spec to search data.
253   mutable SearchState state_;               // Last state accessed.
254   mutable SearchData *state_data_;          // Last state data accessed.
255   mutable ParenSpec paren_;                 // Last paren spec accessed.
256   mutable SearchData *paren_data_;          // Last paren data accessed.
257   bool gc_;                                 // Allow GC?
258   mutable size_t nstates_;                  // Total number of search states.
259   size_t ngc_;                              // Number of GC'd search states.
260   mutable SearchData null_search_data_;     // Null search data.
261   bool finished_;                           // Read-only access when true.
262 
263   PdtShortestPathData(const PdtShortestPathData &) = delete;
264   PdtShortestPathData &operator=(const PdtShortestPathData &) = delete;
265 };
266 
267 // Deletes inaccessible search data from a given "start" (open paren dest)
268 // state. Assumes "final" (close paren source or PDT final) states have
269 // been flagged kPdtFinal.
270 template <class Arc>
GC(StateId start)271 void PdtShortestPathData<Arc>::GC(StateId start) {
272   if (!gc_) return;
273   std::vector<StateId> finals;
274   for (auto it = search_multimap_.find(start);
275        it != search_multimap_.end() && it->first == start; ++it) {
276     const SearchState s(it->second, start);
277     if (search_map_[s].flags & kPdtFinal) finals.push_back(s.state);
278   }
279   // Mark phase.
280   for (const auto state : finals) {
281     SearchState ss(state, start);
282     while (ss.state != kNoLabel) {
283       auto &sdata = search_map_[ss];
284       if (sdata.flags & kPdtMarked) break;
285       sdata.flags |= kPdtMarked;
286       const auto p = sdata.parent;
287       if (p.start != start && p.start != kNoLabel) {  // Entering sub-subgraph.
288         const ParenSpec paren(sdata.paren_id, ss.start, p.start);
289         ss = paren_map_[paren].parent;
290       } else {
291         ss = p;
292       }
293     }
294   }
295   // Sweep phase.
296   auto it = search_multimap_.find(start);
297   while (it != search_multimap_.end() && it->first == start) {
298     const SearchState s(it->second, start);
299     auto mit = search_map_.find(s);
300     const SearchData &data = mit->second;
301     if (!(data.flags & kPdtMarked)) {
302       search_map_.erase(mit);
303       ++ngc_;
304     }
305     search_multimap_.erase(it++);
306   }
307 }
308 
309 }  // namespace internal
310 
311 // This computes the single source shortest (balanced) path (SSSP) through a
312 // weighted PDT that has a bounded stack (i.e., is expandable as an FST). It is
313 // a generalization of the classic SSSP graph algorithm that removes a state s
314 // from a queue (defined by a user-provided queue type) and relaxes the
315 // destination states of transitions leaving s. In this PDT version, states that
316 // have entering open parentheses are treated as source states for a sub-graph
317 // SSSP problem with the shortest path up to the open parenthesis being first
318 // saved. When a close parenthesis is then encountered any balancing open
319 // parenthesis is examined for this saved information and multiplied back. In
320 // this way, each sub-graph is entered only once rather than repeatedly. If
321 // every state in the input PDT has the property that there is a unique "start"
322 // state for it with entering open parentheses, then this algorithm is quite
323 // straightforward. In general, this will not be the case, so the algorithm
324 // (implicitly) creates a new graph where each state is a pair of an original
325 // state and a possible parenthesis "start" state for that state.
326 template <class Arc, class Queue>
327 class PdtShortestPath {
328  public:
329   using Label = typename Arc::Label;
330   using StateId = typename Arc::StateId;
331   using Weight = typename Arc::Weight;
332 
333   using SpData = internal::PdtShortestPathData<Arc>;
334   using SearchState = typename SpData::SearchState;
335   using ParenSpec = typename SpData::ParenSpec;
336   using CloseSourceIterator =
337       typename internal::PdtBalanceData<Arc>::SetIterator;
338 
PdtShortestPath(const Fst<Arc> & ifst,const std::vector<std::pair<Label,Label>> & parens,const PdtShortestPathOptions<Arc,Queue> & opts)339   PdtShortestPath(const Fst<Arc> &ifst,
340                   const std::vector<std::pair<Label, Label>> &parens,
341                   const PdtShortestPathOptions<Arc, Queue> &opts)
342       : ifst_(ifst.Copy()),
343         parens_(parens),
344         keep_parens_(opts.keep_parentheses),
345         start_(ifst.Start()),
346         sp_data_(opts.path_gc),
347         error_(false) {
348     // TODO(kbg): Make this a compile-time static_assert once:
349     // 1) All weight properties are made constexpr for all weight types.
350     // 2) We have a pleasant way to "deregister" this oepration for non-path
351     //    semirings so an informative error message is produced. The best
352     //    solution will probably involve some kind of SFINAE magic.
353     if ((Weight::Properties() & (kPath | kRightSemiring)) !=
354         (kPath | kRightSemiring)) {
355       FSTERROR() << "PdtShortestPath: Weight needs to have the path"
356                  << " property and be right distributive: " << Weight::Type();
357       error_ = true;
358     }
359     for (Label i = 0; i < parens.size(); ++i) {
360       const auto &pair = parens[i];
361       paren_map_[pair.first] = i;
362       paren_map_[pair.second] = i;
363     }
364   }
365 
~PdtShortestPath()366   ~PdtShortestPath() {
367     VLOG(1) << "# of input states: " << CountStates(*ifst_);
368     VLOG(1) << "# of enqueued: " << nenqueued_;
369     VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
370   }
371 
ShortestPath(MutableFst<Arc> * ofst)372   void ShortestPath(MutableFst<Arc> *ofst) {
373     Init(ofst);
374     GetDistance(start_);
375     GetPath();
376     sp_data_.Finish();
377     if (error_) ofst->SetProperties(kError, kError);
378   }
379 
GetShortestPathData()380   const internal::PdtShortestPathData<Arc> &GetShortestPathData() const {
381     return sp_data_;
382   }
383 
GetBalanceData()384   internal::PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
385 
386  public:
387   // Hash multimap from close paren label to an paren arc.
388   using CloseParenMultimap =
389       std::unordered_multimap<internal::ParenState<Arc>, Arc,
390                               typename internal::ParenState<Arc>::Hash>;
391 
GetCloseParenMultimap()392   const CloseParenMultimap &GetCloseParenMultimap() const {
393     return close_paren_multimap_;
394   }
395 
396  private:
397   void Init(MutableFst<Arc> *ofst);
398 
399   void GetDistance(StateId start);
400 
401   void ProcFinal(SearchState s);
402 
403   void ProcArcs(SearchState s);
404 
405   void ProcOpenParen(Label paren_id, SearchState s, StateId nexstate,
406                      const Weight &weight);
407 
408   void ProcCloseParen(Label paren_id, SearchState s, const Weight &weight);
409 
410   void ProcNonParen(SearchState s, StateId nextstate, const Weight &weight);
411 
412   void Relax(SearchState s, SearchState t, StateId nextstate,
413              const Weight &weight, Label paren_id);
414 
415   void Enqueue(SearchState d);
416 
417   void GetPath();
418 
419   Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
420 
421   std::unique_ptr<Fst<Arc>> ifst_;
422   MutableFst<Arc> *ofst_;
423   const std::vector<std::pair<Label, Label>> &parens_;
424   bool keep_parens_;
425   Queue *state_queue_;
426   StateId start_;
427   Weight fdistance_;
428   SearchState f_parent_;
429   SpData sp_data_;
430   std::unordered_map<Label, Label> paren_map_;
431   CloseParenMultimap close_paren_multimap_;
432   internal::PdtBalanceData<Arc> balance_data_;
433   ssize_t nenqueued_;
434   bool error_;
435 
436   static constexpr uint8 kEnqueued = 0x10;
437   static constexpr uint8 kExpanded = 0x20;
438   static constexpr uint8 kFinished = 0x40;
439 
440   static const Arc kNoArc;
441 };
442 
443 template <class Arc, class Queue>
Init(MutableFst<Arc> * ofst)444 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
445   ofst_ = ofst;
446   ofst->DeleteStates();
447   ofst->SetInputSymbols(ifst_->InputSymbols());
448   ofst->SetOutputSymbols(ifst_->OutputSymbols());
449   if (ifst_->Start() == kNoStateId) return;
450   fdistance_ = Weight::Zero();
451   f_parent_ = SearchState(kNoStateId, kNoStateId);
452   sp_data_.Clear();
453   close_paren_multimap_.clear();
454   balance_data_.Clear();
455   nenqueued_ = 0;
456   // Finds open parens per destination state and close parens per source state.
457   for (StateIterator<Fst<Arc>> siter(*ifst_); !siter.Done(); siter.Next()) {
458     const auto s = siter.Value();
459     for (ArcIterator<Fst<Arc>> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) {
460       const auto &arc = aiter.Value();
461       const auto it = paren_map_.find(arc.ilabel);
462       if (it != paren_map_.end()) {  // Is a paren?
463         const auto paren_id = it->second;
464         if (arc.ilabel == parens_[paren_id].first) {  // Open paren.
465           balance_data_.OpenInsert(paren_id, arc.nextstate);
466         } else {  // Close paren.
467           const internal::ParenState<Arc> paren_state(paren_id, s);
468           close_paren_multimap_.emplace(paren_state, arc);
469         }
470       }
471     }
472   }
473 }
474 
475 // Computes the shortest distance stored in a recursive way. Each sub-graph
476 // (i.e., different paren "start" state) begins with weight One().
477 template <class Arc, class Queue>
GetDistance(StateId start)478 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
479   if (start == kNoStateId) return;
480   Queue state_queue;
481   state_queue_ = &state_queue;
482   const SearchState q(start, start);
483   Enqueue(q);
484   sp_data_.SetDistance(q, Weight::One());
485   while (!state_queue_->Empty()) {
486     const auto state = state_queue_->Head();
487     state_queue_->Dequeue();
488     const SearchState s(state, start);
489     sp_data_.SetFlags(s, 0, kEnqueued);
490     ProcFinal(s);
491     ProcArcs(s);
492     sp_data_.SetFlags(s, kExpanded, kExpanded);
493   }
494   sp_data_.SetFlags(q, kFinished, kFinished);
495   balance_data_.FinishInsert(start);
496   sp_data_.GC(start);
497 }
498 
499 // Updates best complete path.
500 template <class Arc, class Queue>
ProcFinal(SearchState s)501 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
502   if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
503     const auto weight = Times(sp_data_.Distance(s), ifst_->Final(s.state));
504     if (fdistance_ != Plus(fdistance_, weight)) {
505       if (f_parent_.state != kNoStateId) {
506         sp_data_.SetFlags(f_parent_, 0, internal::kPdtFinal);
507       }
508       sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
509       fdistance_ = Plus(fdistance_, weight);
510       f_parent_ = s;
511     }
512   }
513 }
514 
515 // Processes all arcs leaving the state s.
516 template <class Arc, class Queue>
ProcArcs(SearchState s)517 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
518   for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
519        aiter.Next()) {
520     const auto &arc = aiter.Value();
521     const auto weight = Times(sp_data_.Distance(s), arc.weight);
522     const auto it = paren_map_.find(arc.ilabel);
523     if (it != paren_map_.end()) {  // Is a paren?
524       const auto paren_id = it->second;
525       if (arc.ilabel == parens_[paren_id].first) {
526         ProcOpenParen(paren_id, s, arc.nextstate, weight);
527       } else {
528         ProcCloseParen(paren_id, s, weight);
529       }
530     } else {
531       ProcNonParen(s, arc.nextstate, weight);
532     }
533   }
534 }
535 
// Handles an open parenthesis arc from search state s to `nextstate`, where
// `weight` is the distance to s extended by the arc weight. If this improves
// the distance stored for the paren spec (paren_id, s.start, nextstate), that
// distance and its parent s are saved. A new SSSP is then started in the
// sub-graph rooted at nextstate if that sub-graph has not been searched yet.
// In either case, every close-paren source previously recorded for this paren
// and sub-graph then has its close paren destination relaxed in s's
// sub-graph, using the sub-graph distance stored for that source.
template <class Arc, class Queue>
inline void PdtShortestPath<Arc, Queue>::ProcOpenParen(Label paren_id,
                                                       SearchState s,
                                                       StateId nextstate,
                                                       const Weight &weight) {
  // The sub-graph entered by this paren is rooted at nextstate.
  const SearchState d(nextstate, nextstate);
  const ParenSpec paren(paren_id, s.start, d.start);
  const auto pdist = sp_data_.Distance(paren);
  if (pdist != Plus(pdist, weight)) {
    // `weight` improves on pdist. With the path property checked in the
    // constructor, Plus(pdist, weight) is then `weight` itself.
    sp_data_.SetDistance(paren, weight);
    sp_data_.SetParent(paren, s);
    const auto dist = sp_data_.Distance(d);
    if (dist == Weight::Zero()) {
      // Unsearched sub-graph: search it now. GetDistance() repoints
      // state_queue_ at its own local queue, so save and restore ours.
      auto *state_queue = state_queue_;
      GetDistance(d.start);
      state_queue_ = state_queue;
    } else if (!(sp_data_.Flags(d) & kFinished)) {
      // The sub-graph root has a distance but its search has not finished,
      // i.e. we re-entered a sub-graph that is still being searched.
      FSTERROR()
          << "PdtShortestPath: open parenthesis recursion: not bounded stack";
      error_ = true;
    }
    // Close-paren source states for (paren_id, nextstate); presumably those
    // recorded via CloseInsert() in ProcCloseParen during the sub-graph search.
    for (auto set_iter = balance_data_.Find(paren_id, nextstate);
         !set_iter.Done(); set_iter.Next()) {
      const SearchState cpstate(set_iter.Element(), d.start);
      const internal::ParenState<Arc> paren_state(paren_id, cpstate.state);
      for (auto cpit = close_paren_multimap_.find(paren_state);
           cpit != close_paren_multimap_.end() && paren_state == cpit->first;
           ++cpit) {
        const auto &cparc = cpit->second;
        // Path to the open paren, then the sub-graph path to the close paren
        // source, then the close paren arc itself.
        const auto cpw =
            Times(weight, Times(sp_data_.Distance(cpstate), cparc.weight));
        Relax(cpstate, s, cparc.nextstate, cpw, paren_id);
      }
    }
  }
}
576 
577 // Saves the correspondence between each closing parenthesis and its balancing
578 // open parenthesis info. Relaxes any close parenthesis destination state that
579 // has a balancing previously encountered open parenthesis.
580 template <class Arc, class Queue>
ProcCloseParen(Label paren_id,SearchState s,const Weight & weight)581 inline void PdtShortestPath<Arc, Queue>::ProcCloseParen(Label paren_id,
582                                                         SearchState s,
583                                                         const Weight &weight) {
584   const internal::ParenState<Arc> paren_state(paren_id, s.start);
585   if (!(sp_data_.Flags(s) & kExpanded)) {
586     balance_data_.CloseInsert(paren_id, s.start, s.state);
587     sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
588   }
589 }
590 
591 // Classical relaxation for non-parentheses.
592 template <class Arc, class Queue>
ProcNonParen(SearchState s,StateId nextstate,const Weight & weight)593 inline void PdtShortestPath<Arc, Queue>::ProcNonParen(SearchState s,
594                                                       StateId nextstate,
595                                                       const Weight &weight) {
596   Relax(s, s, nextstate, weight, kNoLabel);
597 }
598 
599 // Classical relaxation on the search graph for an arc with destination state
600 // nexstate from state s. State t is in the same sub-graph as nextstate (i.e.,
601 // has the same paren "start").
602 template <class Arc, class Queue>
Relax(SearchState s,SearchState t,StateId nextstate,const Weight & weight,Label paren_id)603 inline void PdtShortestPath<Arc, Queue>::Relax(SearchState s, SearchState t,
604                                                StateId nextstate,
605                                                const Weight &weight,
606                                                Label paren_id) {
607   const SearchState d(nextstate, t.start);
608   Weight dist = sp_data_.Distance(d);
609   if (dist != Plus(dist, weight)) {
610     sp_data_.SetParent(d, s);
611     sp_data_.SetParenId(d, paren_id);
612     sp_data_.SetDistance(d, Plus(dist, weight));
613     Enqueue(d);
614   }
615 }
616 
617 template <class Arc, class Queue>
Enqueue(SearchState s)618 inline void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
619   if (!(sp_data_.Flags(s) & kEnqueued)) {
620     state_queue_->Enqueue(s.state);
621     sp_data_.SetFlags(s, kEnqueued, kEnqueued);
622     ++nenqueued_;
623   } else {
624     state_queue_->Update(s.state);
625   }
626 }
627 
// Follows parent pointers to find the shortest path. A stack is used since the
// shortest distance is stored recursively.
//
// The walk starts at the best final search state (f_parent_) and moves
// backwards toward the PDT start, adding one output state per visited search
// state. Each new output state s_p gets an arc to the state added just before
// it (d_p), so the last state added becomes the output start state. Within a
// sub-graph, Parent() links are followed. Emitting a close paren pushes its
// ParenSpec. On reaching a sub-graph "start" state (no parent), the walk
// resumes at the open paren source stored for the ParenSpec on top of the
// stack, and that ParenSpec is popped once the open paren arc is emitted.
template <class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::GetPath() {
  SearchState s = f_parent_;
  SearchState d = SearchState(kNoStateId, kNoStateId);
  StateId s_p = kNoStateId;  // Output state for s.
  StateId d_p = kNoStateId;  // Output state for d (added in the prior step).
  auto arc = kNoArc;         // Input arc from s to d, looked up last step.
  Label paren_id = kNoLabel;
  std::stack<ParenSpec> paren_stack;  // Close parens awaiting their open paren.
  while (s.state != kNoStateId) {
    d_p = s_p;
    s_p = ofst_->AddState();
    if (d.state == kNoStateId) {
      // First visited state: the best final state.
      ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
    } else {
      if (paren_id != kNoLabel) {                     // Paren?
        if (arc.ilabel == parens_[paren_id].first) {  // Open paren?
          paren_stack.pop();
        } else {  // Close paren?
          const ParenSpec paren(paren_id, d.start, s.start);
          paren_stack.push(paren);
        }
        if (!keep_parens_) arc.ilabel = arc.olabel = 0;
      }
      arc.nextstate = d_p;
      ofst_->AddArc(s_p, arc);
    }
    d = s;
    s = sp_data_.Parent(d);
    paren_id = sp_data_.ParenId(d);
    if (s.state != kNoStateId) {
      // Ordinary parent link. A paren_id other than kNoLabel means the parent
      // is a close paren source in a nested sub-graph.
      arc = GetPathArc(s, d, paren_id, false);
    } else if (!paren_stack.empty()) {
      // d is a sub-graph "start" state: step back over the open paren that
      // entered it, to its source state in the enclosing sub-graph.
      const ParenSpec paren = paren_stack.top();
      s = sp_data_.Parent(paren);
      paren_id = paren.paren_id;
      arc = GetPathArc(s, d, paren_id, true);
    }
  }
  ofst_->SetStart(s_p);
  ofst_->SetProperties(
      ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
      kFstProperties);
}
674 
675 // Finds transition with least weight between two states with label matching
676 // paren_id and open/close paren type or a non-paren if kNoLabel.
677 template <class Arc, class Queue>
GetPathArc(SearchState s,SearchState d,Label paren_id,bool open_paren)678 Arc PdtShortestPath<Arc, Queue>::GetPathArc(SearchState s, SearchState d,
679                                             Label paren_id, bool open_paren) {
680   auto path_arc = kNoArc;
681   for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
682        aiter.Next()) {
683     const auto &arc = aiter.Value();
684     if (arc.nextstate != d.state) continue;
685     Label arc_paren_id = kNoLabel;
686     const auto it = paren_map_.find(arc.ilabel);
687     if (it != paren_map_.end()) {
688       arc_paren_id = it->second;
689       bool arc_open_paren = (arc.ilabel == parens_[arc_paren_id].first);
690       if (arc_open_paren != open_paren) continue;
691     }
692     if (arc_paren_id != paren_id) continue;
693     if (arc.weight == Plus(arc.weight, path_arc.weight)) path_arc = arc;
694   }
695   if (path_arc.nextstate == kNoStateId) {
696     FSTERROR() << "PdtShortestPath::GetPathArc: Failed to find arc";
697     error_ = true;
698   }
699   return path_arc;
700 }
701 
702 template <class Arc, class Queue>
703 const Arc PdtShortestPath<Arc, Queue>::kNoArc = Arc(kNoLabel, kNoLabel,
704                                                     Weight::Zero(), kNoStateId);
705 
706 // Functional variants.
707 
708 template <class Arc, class Queue>
ShortestPath(const Fst<Arc> & ifst,const std::vector<std::pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst,const PdtShortestPathOptions<Arc,Queue> & opts)709 void ShortestPath(
710     const Fst<Arc> &ifst,
711     const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
712         &parens,
713     MutableFst<Arc> *ofst, const PdtShortestPathOptions<Arc, Queue> &opts) {
714   PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
715   psp.ShortestPath(ofst);
716 }
717 
718 template <class Arc>
ShortestPath(const Fst<Arc> & ifst,const std::vector<std::pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst)719 void ShortestPath(
720     const Fst<Arc> &ifst,
721     const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
722         &parens,
723     MutableFst<Arc> *ofst) {
724   using Q = FifoQueue<typename Arc::StateId>;
725   const PdtShortestPathOptions<Arc, Q> opts;
726   PdtShortestPath<Arc, Q> psp(ifst, parens, opts);
727   psp.ShortestPath(ofst);
728 }
729 
730 }  // namespace fst
731 
732 #endif  // FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
733