1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Functions to find shortest paths in a PDT.
19
20 #ifndef FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
21 #define FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
22
23 #include <stack>
24 #include <unordered_map>
25 #include <utility>
26 #include <vector>
27
28 #include <fst/types.h>
29 #include <fst/log.h>
30 #include <fst/extensions/pdt/paren.h>
31 #include <fst/extensions/pdt/pdt.h>
32 #include <fst/shortest-path.h>
33 #include <unordered_map>
34
35 namespace fst {
36
// Options for PdtShortestPath. The struct itself does not use its template
// parameters; Queue is the queue discipline of the matching
// PdtShortestPath<Arc, Queue> instantiation.
template <class Arc, class Queue>
struct PdtShortestPathOptions {
  // When false, parenthesis arcs on the output path get label 0 on both
  // sides; when true their original parenthesis labels are kept.
  bool keep_parentheses = false;
  // When true, search data of completed sub-graphs that no final state can
  // reach is garbage-collected.
  bool path_gc = true;

  explicit PdtShortestPathOptions(bool keep_parentheses = false,
                                  bool path_gc = true)
      : keep_parentheses(keep_parentheses), path_gc(path_gc) {}
};
46
47 namespace internal {
48
49 // Flags for shortest path data.
50
51 constexpr uint8 kPdtInited = 0x01;
52 constexpr uint8 kPdtFinal = 0x02;
53 constexpr uint8 kPdtMarked = 0x04;
54
55 // Stores shortest path tree info Distance(), Parent(), and ArcParent()
56 // information keyed on two types:
57 //
58 // 1. SearchState: This is a usual node in a shortest path tree but:
59 // a. is w.r.t a PDT search state (a pair of a PDT state and a "start" state,
60 // either the PDT start state or the destination state of an open
61 // parenthesis).
62 // b. the Distance() is from this "start" state to the search state.
63 // c. Parent().state is kNoLabel for the "start" state.
64 //
65 // 2. ParenSpec: This connects shortest path trees depending on the the
66 // parenthesis taken. Given the parenthesis spec:
67 // a. the Distance() is from the Parent() "start" state to the parenthesis
68 // destination state.
69 // b. The ArcParent() is the parenthesis arc.
70 template <class Arc>
71 class PdtShortestPathData {
72 public:
73 using Label = typename Arc::Label;
74 using StateId = typename Arc::StateId;
75 using Weight = typename Arc::Weight;
76
77 struct SearchState {
78 StateId state; // PDT state.
79 StateId start; // PDT paren "start" state.
80
81 explicit SearchState(StateId s = kNoStateId, StateId t = kNoStateId)
stateSearchState82 : state(s), start(t) {}
83
84 bool operator==(const SearchState &other) const {
85 if (&other == this) return true;
86 return other.state == state && other.start == start;
87 }
88 };
89
90 // Specifies paren ID, source and dest "start" states of a paren. These are
91 // the "start" states of the respective sub-graphs.
92 struct ParenSpec {
93 explicit ParenSpec(Label paren_id = kNoLabel,
94 StateId src_start = kNoStateId,
95 StateId dest_start = kNoStateId)
paren_idParenSpec96 : paren_id(paren_id), src_start(src_start), dest_start(dest_start) {}
97
98 Label paren_id;
99 StateId src_start; // Sub-graph "start" state for paren source.
100 StateId dest_start; // Sub-graph "start" state for paren dest.
101
102 bool operator==(const ParenSpec &other) const {
103 if (&other == this) return true;
104 return (other.paren_id == paren_id &&
105 other.src_start == other.src_start &&
106 other.dest_start == dest_start);
107 }
108 };
109
110 struct SearchData {
SearchDataSearchData111 SearchData()
112 : distance(Weight::Zero()),
113 parent(kNoStateId, kNoStateId),
114 paren_id(kNoLabel),
115 flags(0) {}
116
117 Weight distance; // Distance to this state from PDT "start" state.
118 SearchState parent; // Parent state in shortest path tree.
119 int16 paren_id; // If parent arc has paren, paren ID (or kNoLabel).
120 uint8 flags; // First byte reserved for PdtShortestPathData use.
121 };
122
PdtShortestPathData(bool gc)123 explicit PdtShortestPathData(bool gc)
124 : gc_(gc), nstates_(0), ngc_(0), finished_(false) {}
125
~PdtShortestPathData()126 ~PdtShortestPathData() {
127 VLOG(1) << "opm size: " << paren_map_.size();
128 VLOG(1) << "# of search states: " << nstates_;
129 if (gc_) VLOG(1) << "# of GC'd search states: " << ngc_;
130 }
131
Clear()132 void Clear() {
133 search_map_.clear();
134 search_multimap_.clear();
135 paren_map_.clear();
136 state_ = SearchState(kNoStateId, kNoStateId);
137 nstates_ = 0;
138 ngc_ = 0;
139 }
140
141 // TODO(kbg): Currently copying SearchState and passing a const reference to
142 // ParenSpec. Benchmark to confirm this is the right thing to do.
143
Distance(SearchState s)144 Weight Distance(SearchState s) const { return GetSearchData(s)->distance; }
145
Distance(const ParenSpec & paren)146 Weight Distance(const ParenSpec &paren) const {
147 return GetSearchData(paren)->distance;
148 }
149
Parent(SearchState s)150 SearchState Parent(SearchState s) const { return GetSearchData(s)->parent; }
151
Parent(const ParenSpec & paren)152 SearchState Parent(const ParenSpec &paren) const {
153 return GetSearchData(paren)->parent;
154 }
155
ParenId(SearchState s)156 Label ParenId(SearchState s) const { return GetSearchData(s)->paren_id; }
157
Flags(SearchState s)158 uint8 Flags(SearchState s) const { return GetSearchData(s)->flags; }
159
SetDistance(SearchState s,Weight weight)160 void SetDistance(SearchState s, Weight weight) {
161 GetSearchData(s)->distance = std::move(weight);
162 }
163
SetDistance(const ParenSpec & paren,Weight weight)164 void SetDistance(const ParenSpec &paren, Weight weight) {
165 GetSearchData(paren)->distance = std::move(weight);
166 }
167
SetParent(SearchState s,SearchState p)168 void SetParent(SearchState s, SearchState p) { GetSearchData(s)->parent = p; }
169
SetParent(const ParenSpec & paren,SearchState p)170 void SetParent(const ParenSpec &paren, SearchState p) {
171 GetSearchData(paren)->parent = p;
172 }
173
SetParenId(SearchState s,Label p)174 void SetParenId(SearchState s, Label p) {
175 if (p >= 32768) {
176 FSTERROR() << "PdtShortestPathData: Paren ID does not fit in an int16";
177 }
178 GetSearchData(s)->paren_id = p;
179 }
180
SetFlags(SearchState s,uint8 f,uint8 mask)181 void SetFlags(SearchState s, uint8 f, uint8 mask) {
182 auto *data = GetSearchData(s);
183 data->flags &= ~mask;
184 data->flags |= f & mask;
185 }
186
187 void GC(StateId s);
188
Finish()189 void Finish() { finished_ = true; }
190
191 private:
192 // Hash for search state.
193 struct SearchStateHash {
operatorSearchStateHash194 size_t operator()(const SearchState &s) const {
195 static constexpr auto prime = 7853;
196 return s.state + s.start * prime;
197 }
198 };
199
200 // Hash for paren map.
201 struct ParenHash {
operatorParenHash202 size_t operator()(const ParenSpec &paren) const {
203 static constexpr auto prime0 = 7853;
204 static constexpr auto prime1 = 7867;
205 return paren.paren_id + paren.src_start * prime0 +
206 paren.dest_start * prime1;
207 }
208 };
209
210 using SearchMap =
211 std::unordered_map<SearchState, SearchData, SearchStateHash>;
212
213 using SearchMultimap = std::unordered_multimap<StateId, StateId>;
214
215 // Hash map from paren spec to open paren data.
216 using ParenMap = std::unordered_map<ParenSpec, SearchData, ParenHash>;
217
GetSearchData(SearchState s)218 SearchData *GetSearchData(SearchState s) const {
219 if (s == state_) return state_data_;
220 if (finished_) {
221 auto it = search_map_.find(s);
222 if (it == search_map_.end()) return &null_search_data_;
223 state_ = s;
224 return state_data_ = &(it->second);
225 } else {
226 state_ = s;
227 state_data_ = &search_map_[s];
228 if (!(state_data_->flags & kPdtInited)) {
229 ++nstates_;
230 if (gc_) search_multimap_.insert(std::make_pair(s.start, s.state));
231 state_data_->flags = kPdtInited;
232 }
233 return state_data_;
234 }
235 }
236
GetSearchData(ParenSpec paren)237 SearchData *GetSearchData(ParenSpec paren) const {
238 if (paren == paren_) return paren_data_;
239 if (finished_) {
240 auto it = paren_map_.find(paren);
241 if (it == paren_map_.end()) return &null_search_data_;
242 paren_ = paren;
243 return state_data_ = &(it->second);
244 } else {
245 paren_ = paren;
246 return paren_data_ = &paren_map_[paren];
247 }
248 }
249
250 mutable SearchMap search_map_; // Maps from search state to data.
251 mutable SearchMultimap search_multimap_; // Maps from "start" to subgraph.
252 mutable ParenMap paren_map_; // Maps paren spec to search data.
253 mutable SearchState state_; // Last state accessed.
254 mutable SearchData *state_data_; // Last state data accessed.
255 mutable ParenSpec paren_; // Last paren spec accessed.
256 mutable SearchData *paren_data_; // Last paren data accessed.
257 bool gc_; // Allow GC?
258 mutable size_t nstates_; // Total number of search states.
259 size_t ngc_; // Number of GC'd search states.
260 mutable SearchData null_search_data_; // Null search data.
261 bool finished_; // Read-only access when true.
262
263 PdtShortestPathData(const PdtShortestPathData &) = delete;
264 PdtShortestPathData &operator=(const PdtShortestPathData &) = delete;
265 };
266
267 // Deletes inaccessible search data from a given "start" (open paren dest)
268 // state. Assumes "final" (close paren source or PDT final) states have
269 // been flagged kPdtFinal.
270 template <class Arc>
GC(StateId start)271 void PdtShortestPathData<Arc>::GC(StateId start) {
272 if (!gc_) return;
273 std::vector<StateId> finals;
274 for (auto it = search_multimap_.find(start);
275 it != search_multimap_.end() && it->first == start; ++it) {
276 const SearchState s(it->second, start);
277 if (search_map_[s].flags & kPdtFinal) finals.push_back(s.state);
278 }
279 // Mark phase.
280 for (const auto state : finals) {
281 SearchState ss(state, start);
282 while (ss.state != kNoLabel) {
283 auto &sdata = search_map_[ss];
284 if (sdata.flags & kPdtMarked) break;
285 sdata.flags |= kPdtMarked;
286 const auto p = sdata.parent;
287 if (p.start != start && p.start != kNoLabel) { // Entering sub-subgraph.
288 const ParenSpec paren(sdata.paren_id, ss.start, p.start);
289 ss = paren_map_[paren].parent;
290 } else {
291 ss = p;
292 }
293 }
294 }
295 // Sweep phase.
296 auto it = search_multimap_.find(start);
297 while (it != search_multimap_.end() && it->first == start) {
298 const SearchState s(it->second, start);
299 auto mit = search_map_.find(s);
300 const SearchData &data = mit->second;
301 if (!(data.flags & kPdtMarked)) {
302 search_map_.erase(mit);
303 ++ngc_;
304 }
305 search_multimap_.erase(it++);
306 }
307 }
308
309 } // namespace internal
310
311 // This computes the single source shortest (balanced) path (SSSP) through a
312 // weighted PDT that has a bounded stack (i.e., is expandable as an FST). It is
313 // a generalization of the classic SSSP graph algorithm that removes a state s
314 // from a queue (defined by a user-provided queue type) and relaxes the
315 // destination states of transitions leaving s. In this PDT version, states that
316 // have entering open parentheses are treated as source states for a sub-graph
317 // SSSP problem with the shortest path up to the open parenthesis being first
318 // saved. When a close parenthesis is then encountered any balancing open
319 // parenthesis is examined for this saved information and multiplied back. In
320 // this way, each sub-graph is entered only once rather than repeatedly. If
321 // every state in the input PDT has the property that there is a unique "start"
322 // state for it with entering open parentheses, then this algorithm is quite
323 // straightforward. In general, this will not be the case, so the algorithm
324 // (implicitly) creates a new graph where each state is a pair of an original
325 // state and a possible parenthesis "start" state for that state.
326 template <class Arc, class Queue>
327 class PdtShortestPath {
328 public:
329 using Label = typename Arc::Label;
330 using StateId = typename Arc::StateId;
331 using Weight = typename Arc::Weight;
332
333 using SpData = internal::PdtShortestPathData<Arc>;
334 using SearchState = typename SpData::SearchState;
335 using ParenSpec = typename SpData::ParenSpec;
336 using CloseSourceIterator =
337 typename internal::PdtBalanceData<Arc>::SetIterator;
338
PdtShortestPath(const Fst<Arc> & ifst,const std::vector<std::pair<Label,Label>> & parens,const PdtShortestPathOptions<Arc,Queue> & opts)339 PdtShortestPath(const Fst<Arc> &ifst,
340 const std::vector<std::pair<Label, Label>> &parens,
341 const PdtShortestPathOptions<Arc, Queue> &opts)
342 : ifst_(ifst.Copy()),
343 parens_(parens),
344 keep_parens_(opts.keep_parentheses),
345 start_(ifst.Start()),
346 sp_data_(opts.path_gc),
347 error_(false) {
348 // TODO(kbg): Make this a compile-time static_assert once:
349 // 1) All weight properties are made constexpr for all weight types.
350 // 2) We have a pleasant way to "deregister" this oepration for non-path
351 // semirings so an informative error message is produced. The best
352 // solution will probably involve some kind of SFINAE magic.
353 if ((Weight::Properties() & (kPath | kRightSemiring)) !=
354 (kPath | kRightSemiring)) {
355 FSTERROR() << "PdtShortestPath: Weight needs to have the path"
356 << " property and be right distributive: " << Weight::Type();
357 error_ = true;
358 }
359 for (Label i = 0; i < parens.size(); ++i) {
360 const auto &pair = parens[i];
361 paren_map_[pair.first] = i;
362 paren_map_[pair.second] = i;
363 }
364 }
365
~PdtShortestPath()366 ~PdtShortestPath() {
367 VLOG(1) << "# of input states: " << CountStates(*ifst_);
368 VLOG(1) << "# of enqueued: " << nenqueued_;
369 VLOG(1) << "cpmm size: " << close_paren_multimap_.size();
370 }
371
ShortestPath(MutableFst<Arc> * ofst)372 void ShortestPath(MutableFst<Arc> *ofst) {
373 Init(ofst);
374 GetDistance(start_);
375 GetPath();
376 sp_data_.Finish();
377 if (error_) ofst->SetProperties(kError, kError);
378 }
379
GetShortestPathData()380 const internal::PdtShortestPathData<Arc> &GetShortestPathData() const {
381 return sp_data_;
382 }
383
GetBalanceData()384 internal::PdtBalanceData<Arc> *GetBalanceData() { return &balance_data_; }
385
386 public:
387 // Hash multimap from close paren label to an paren arc.
388 using CloseParenMultimap =
389 std::unordered_multimap<internal::ParenState<Arc>, Arc,
390 typename internal::ParenState<Arc>::Hash>;
391
GetCloseParenMultimap()392 const CloseParenMultimap &GetCloseParenMultimap() const {
393 return close_paren_multimap_;
394 }
395
396 private:
397 void Init(MutableFst<Arc> *ofst);
398
399 void GetDistance(StateId start);
400
401 void ProcFinal(SearchState s);
402
403 void ProcArcs(SearchState s);
404
405 void ProcOpenParen(Label paren_id, SearchState s, StateId nexstate,
406 const Weight &weight);
407
408 void ProcCloseParen(Label paren_id, SearchState s, const Weight &weight);
409
410 void ProcNonParen(SearchState s, StateId nextstate, const Weight &weight);
411
412 void Relax(SearchState s, SearchState t, StateId nextstate,
413 const Weight &weight, Label paren_id);
414
415 void Enqueue(SearchState d);
416
417 void GetPath();
418
419 Arc GetPathArc(SearchState s, SearchState p, Label paren_id, bool open);
420
421 std::unique_ptr<Fst<Arc>> ifst_;
422 MutableFst<Arc> *ofst_;
423 const std::vector<std::pair<Label, Label>> &parens_;
424 bool keep_parens_;
425 Queue *state_queue_;
426 StateId start_;
427 Weight fdistance_;
428 SearchState f_parent_;
429 SpData sp_data_;
430 std::unordered_map<Label, Label> paren_map_;
431 CloseParenMultimap close_paren_multimap_;
432 internal::PdtBalanceData<Arc> balance_data_;
433 ssize_t nenqueued_;
434 bool error_;
435
436 static constexpr uint8 kEnqueued = 0x10;
437 static constexpr uint8 kExpanded = 0x20;
438 static constexpr uint8 kFinished = 0x40;
439
440 static const Arc kNoArc;
441 };
442
443 template <class Arc, class Queue>
Init(MutableFst<Arc> * ofst)444 void PdtShortestPath<Arc, Queue>::Init(MutableFst<Arc> *ofst) {
445 ofst_ = ofst;
446 ofst->DeleteStates();
447 ofst->SetInputSymbols(ifst_->InputSymbols());
448 ofst->SetOutputSymbols(ifst_->OutputSymbols());
449 if (ifst_->Start() == kNoStateId) return;
450 fdistance_ = Weight::Zero();
451 f_parent_ = SearchState(kNoStateId, kNoStateId);
452 sp_data_.Clear();
453 close_paren_multimap_.clear();
454 balance_data_.Clear();
455 nenqueued_ = 0;
456 // Finds open parens per destination state and close parens per source state.
457 for (StateIterator<Fst<Arc>> siter(*ifst_); !siter.Done(); siter.Next()) {
458 const auto s = siter.Value();
459 for (ArcIterator<Fst<Arc>> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) {
460 const auto &arc = aiter.Value();
461 const auto it = paren_map_.find(arc.ilabel);
462 if (it != paren_map_.end()) { // Is a paren?
463 const auto paren_id = it->second;
464 if (arc.ilabel == parens_[paren_id].first) { // Open paren.
465 balance_data_.OpenInsert(paren_id, arc.nextstate);
466 } else { // Close paren.
467 const internal::ParenState<Arc> paren_state(paren_id, s);
468 close_paren_multimap_.emplace(paren_state, arc);
469 }
470 }
471 }
472 }
473 }
474
475 // Computes the shortest distance stored in a recursive way. Each sub-graph
476 // (i.e., different paren "start" state) begins with weight One().
477 template <class Arc, class Queue>
GetDistance(StateId start)478 void PdtShortestPath<Arc, Queue>::GetDistance(StateId start) {
479 if (start == kNoStateId) return;
480 Queue state_queue;
481 state_queue_ = &state_queue;
482 const SearchState q(start, start);
483 Enqueue(q);
484 sp_data_.SetDistance(q, Weight::One());
485 while (!state_queue_->Empty()) {
486 const auto state = state_queue_->Head();
487 state_queue_->Dequeue();
488 const SearchState s(state, start);
489 sp_data_.SetFlags(s, 0, kEnqueued);
490 ProcFinal(s);
491 ProcArcs(s);
492 sp_data_.SetFlags(s, kExpanded, kExpanded);
493 }
494 sp_data_.SetFlags(q, kFinished, kFinished);
495 balance_data_.FinishInsert(start);
496 sp_data_.GC(start);
497 }
498
499 // Updates best complete path.
500 template <class Arc, class Queue>
ProcFinal(SearchState s)501 void PdtShortestPath<Arc, Queue>::ProcFinal(SearchState s) {
502 if (ifst_->Final(s.state) != Weight::Zero() && s.start == start_) {
503 const auto weight = Times(sp_data_.Distance(s), ifst_->Final(s.state));
504 if (fdistance_ != Plus(fdistance_, weight)) {
505 if (f_parent_.state != kNoStateId) {
506 sp_data_.SetFlags(f_parent_, 0, internal::kPdtFinal);
507 }
508 sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
509 fdistance_ = Plus(fdistance_, weight);
510 f_parent_ = s;
511 }
512 }
513 }
514
515 // Processes all arcs leaving the state s.
516 template <class Arc, class Queue>
ProcArcs(SearchState s)517 void PdtShortestPath<Arc, Queue>::ProcArcs(SearchState s) {
518 for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
519 aiter.Next()) {
520 const auto &arc = aiter.Value();
521 const auto weight = Times(sp_data_.Distance(s), arc.weight);
522 const auto it = paren_map_.find(arc.ilabel);
523 if (it != paren_map_.end()) { // Is a paren?
524 const auto paren_id = it->second;
525 if (arc.ilabel == parens_[paren_id].first) {
526 ProcOpenParen(paren_id, s, arc.nextstate, weight);
527 } else {
528 ProcCloseParen(paren_id, s, weight);
529 }
530 } else {
531 ProcNonParen(s, arc.nextstate, weight);
532 }
533 }
534 }
535
536 // Saves the shortest path info for reaching this parenthesis and starts a new
537 // SSSP in the sub-graph pointed to by the parenthesis if previously unvisited.
538 // Otherwise it finds any previously encountered closing parentheses and relaxes
539 // them using the recursively stored shortest distance to them.
540 template <class Arc, class Queue>
ProcOpenParen(Label paren_id,SearchState s,StateId nextstate,const Weight & weight)541 inline void PdtShortestPath<Arc, Queue>::ProcOpenParen(Label paren_id,
542 SearchState s,
543 StateId nextstate,
544 const Weight &weight) {
545 const SearchState d(nextstate, nextstate);
546 const ParenSpec paren(paren_id, s.start, d.start);
547 const auto pdist = sp_data_.Distance(paren);
548 if (pdist != Plus(pdist, weight)) {
549 sp_data_.SetDistance(paren, weight);
550 sp_data_.SetParent(paren, s);
551 const auto dist = sp_data_.Distance(d);
552 if (dist == Weight::Zero()) {
553 auto *state_queue = state_queue_;
554 GetDistance(d.start);
555 state_queue_ = state_queue;
556 } else if (!(sp_data_.Flags(d) & kFinished)) {
557 FSTERROR()
558 << "PdtShortestPath: open parenthesis recursion: not bounded stack";
559 error_ = true;
560 }
561 for (auto set_iter = balance_data_.Find(paren_id, nextstate);
562 !set_iter.Done(); set_iter.Next()) {
563 const SearchState cpstate(set_iter.Element(), d.start);
564 const internal::ParenState<Arc> paren_state(paren_id, cpstate.state);
565 for (auto cpit = close_paren_multimap_.find(paren_state);
566 cpit != close_paren_multimap_.end() && paren_state == cpit->first;
567 ++cpit) {
568 const auto &cparc = cpit->second;
569 const auto cpw =
570 Times(weight, Times(sp_data_.Distance(cpstate), cparc.weight));
571 Relax(cpstate, s, cparc.nextstate, cpw, paren_id);
572 }
573 }
574 }
575 }
576
577 // Saves the correspondence between each closing parenthesis and its balancing
578 // open parenthesis info. Relaxes any close parenthesis destination state that
579 // has a balancing previously encountered open parenthesis.
580 template <class Arc, class Queue>
ProcCloseParen(Label paren_id,SearchState s,const Weight & weight)581 inline void PdtShortestPath<Arc, Queue>::ProcCloseParen(Label paren_id,
582 SearchState s,
583 const Weight &weight) {
584 const internal::ParenState<Arc> paren_state(paren_id, s.start);
585 if (!(sp_data_.Flags(s) & kExpanded)) {
586 balance_data_.CloseInsert(paren_id, s.start, s.state);
587 sp_data_.SetFlags(s, internal::kPdtFinal, internal::kPdtFinal);
588 }
589 }
590
591 // Classical relaxation for non-parentheses.
592 template <class Arc, class Queue>
ProcNonParen(SearchState s,StateId nextstate,const Weight & weight)593 inline void PdtShortestPath<Arc, Queue>::ProcNonParen(SearchState s,
594 StateId nextstate,
595 const Weight &weight) {
596 Relax(s, s, nextstate, weight, kNoLabel);
597 }
598
599 // Classical relaxation on the search graph for an arc with destination state
600 // nexstate from state s. State t is in the same sub-graph as nextstate (i.e.,
601 // has the same paren "start").
602 template <class Arc, class Queue>
Relax(SearchState s,SearchState t,StateId nextstate,const Weight & weight,Label paren_id)603 inline void PdtShortestPath<Arc, Queue>::Relax(SearchState s, SearchState t,
604 StateId nextstate,
605 const Weight &weight,
606 Label paren_id) {
607 const SearchState d(nextstate, t.start);
608 Weight dist = sp_data_.Distance(d);
609 if (dist != Plus(dist, weight)) {
610 sp_data_.SetParent(d, s);
611 sp_data_.SetParenId(d, paren_id);
612 sp_data_.SetDistance(d, Plus(dist, weight));
613 Enqueue(d);
614 }
615 }
616
617 template <class Arc, class Queue>
Enqueue(SearchState s)618 inline void PdtShortestPath<Arc, Queue>::Enqueue(SearchState s) {
619 if (!(sp_data_.Flags(s) & kEnqueued)) {
620 state_queue_->Enqueue(s.state);
621 sp_data_.SetFlags(s, kEnqueued, kEnqueued);
622 ++nenqueued_;
623 } else {
624 state_queue_->Update(s.state);
625 }
626 }
627
// Follows parent pointers to find the shortest path. A stack is used since the
// shortest distance is stored recursively.
//
// The path is rebuilt backward, starting at f_parent_ (the best final search
// state): each iteration adds one output state and, except for the first, an
// arc from it to the previously added output state. Walking backward, a close
// paren pushes its ParenSpec. Reaching a sub-graph "start" (no parent) with a
// non-empty stack continues from the parent of the top ParenSpec, which is the
// open-paren source. The open-paren arc then pops the stack. The last output
// state added becomes the start state. If no final state was reached
// (f_parent_.state == kNoStateId), no states are added and the start is
// kNoStateId.
template <class Arc, class Queue>
void PdtShortestPath<Arc, Queue>::GetPath() {
  SearchState s = f_parent_;
  SearchState d = SearchState(kNoStateId, kNoStateId);
  StateId s_p = kNoStateId;  // Output state for s (added this iteration).
  StateId d_p = kNoStateId;  // Output state for d (added last iteration).
  auto arc = kNoArc;         // Input arc from s to d.
  Label paren_id = kNoLabel;
  std::stack<ParenSpec> paren_stack;
  while (s.state != kNoStateId) {
    d_p = s_p;
    s_p = ofst_->AddState();
    if (d.state == kNoStateId) {
      // First output state added: the end of the path, so it is final.
      ofst_->SetFinal(s_p, ifst_->Final(f_parent_.state));
    } else {
      if (paren_id != kNoLabel) {                     // Paren?
        if (arc.ilabel == parens_[paren_id].first) {  // Open paren?
          paren_stack.pop();
        } else {  // Close paren?
          const ParenSpec paren(paren_id, d.start, s.start);
          paren_stack.push(paren);
        }
        if (!keep_parens_) arc.ilabel = arc.olabel = 0;
      }
      arc.nextstate = d_p;
      ofst_->AddArc(s_p, arc);
    }
    d = s;
    s = sp_data_.Parent(d);
    paren_id = sp_data_.ParenId(d);
    if (s.state != kNoStateId) {
      // Ordinary or close-paren step within the shortest path tree.
      arc = GetPathArc(s, d, paren_id, false);
    } else if (!paren_stack.empty()) {
      // d is a sub-graph "start": leave it through the open paren that
      // entered it.
      const ParenSpec paren = paren_stack.top();
      s = sp_data_.Parent(paren);
      paren_id = paren.paren_id;
      arc = GetPathArc(s, d, paren_id, true);
    }
  }
  ofst_->SetStart(s_p);
  ofst_->SetProperties(
      ShortestPathProperties(ofst_->Properties(kFstProperties, false)),
      kFstProperties);
}
674
675 // Finds transition with least weight between two states with label matching
676 // paren_id and open/close paren type or a non-paren if kNoLabel.
677 template <class Arc, class Queue>
GetPathArc(SearchState s,SearchState d,Label paren_id,bool open_paren)678 Arc PdtShortestPath<Arc, Queue>::GetPathArc(SearchState s, SearchState d,
679 Label paren_id, bool open_paren) {
680 auto path_arc = kNoArc;
681 for (ArcIterator<Fst<Arc>> aiter(*ifst_, s.state); !aiter.Done();
682 aiter.Next()) {
683 const auto &arc = aiter.Value();
684 if (arc.nextstate != d.state) continue;
685 Label arc_paren_id = kNoLabel;
686 const auto it = paren_map_.find(arc.ilabel);
687 if (it != paren_map_.end()) {
688 arc_paren_id = it->second;
689 bool arc_open_paren = (arc.ilabel == parens_[arc_paren_id].first);
690 if (arc_open_paren != open_paren) continue;
691 }
692 if (arc_paren_id != paren_id) continue;
693 if (arc.weight == Plus(arc.weight, path_arc.weight)) path_arc = arc;
694 }
695 if (path_arc.nextstate == kNoStateId) {
696 FSTERROR() << "PdtShortestPath::GetPathArc: Failed to find arc";
697 error_ = true;
698 }
699 return path_arc;
700 }
701
702 template <class Arc, class Queue>
703 const Arc PdtShortestPath<Arc, Queue>::kNoArc = Arc(kNoLabel, kNoLabel,
704 Weight::Zero(), kNoStateId);
705
706 // Functional variants.
707
708 template <class Arc, class Queue>
ShortestPath(const Fst<Arc> & ifst,const std::vector<std::pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst,const PdtShortestPathOptions<Arc,Queue> & opts)709 void ShortestPath(
710 const Fst<Arc> &ifst,
711 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
712 &parens,
713 MutableFst<Arc> *ofst, const PdtShortestPathOptions<Arc, Queue> &opts) {
714 PdtShortestPath<Arc, Queue> psp(ifst, parens, opts);
715 psp.ShortestPath(ofst);
716 }
717
718 template <class Arc>
ShortestPath(const Fst<Arc> & ifst,const std::vector<std::pair<typename Arc::Label,typename Arc::Label>> & parens,MutableFst<Arc> * ofst)719 void ShortestPath(
720 const Fst<Arc> &ifst,
721 const std::vector<std::pair<typename Arc::Label, typename Arc::Label>>
722 &parens,
723 MutableFst<Arc> *ofst) {
724 using Q = FifoQueue<typename Arc::StateId>;
725 const PdtShortestPathOptions<Arc, Q> opts;
726 PdtShortestPath<Arc, Q> psp(ifst, parens, opts);
727 psp.ShortestPath(ofst);
728 }
729
730 } // namespace fst
731
732 #endif // FST_EXTENSIONS_PDT_SHORTEST_PATH_H_
733