1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to allow matching labels leaving FST states.
19 
20 #ifndef FST_MATCHER_H_
21 #define FST_MATCHER_H_
22 
23 #include <algorithm>
24 #include <memory>
25 #include <utility>
26 
27 #include <fst/types.h>
28 #include <fst/log.h>
29 
30 #include <fst/mutable-fst.h>  // for all internal FST accessors.
31 
32 #include <unordered_map>
33 
34 namespace fst {
35 
36 // Matchers find and iterate through requested labels at FST states. In the
37 // simplest form, these are just some associative map or search keyed on labels.
38 // More generally, they may implement matching special labels that represent
39 // sets of labels such as sigma (all), rho (rest), or phi (fail). The Matcher
40 // interface is:
41 //
42 // template <class F>
43 // class Matcher {
44 //  public:
45 //   using FST = F;
46 //   using Arc = typename FST::Arc;
47 //   using Label = typename Arc::Label;
48 //   using StateId = typename Arc::StateId;
49 //   using Weight = typename Arc::Weight;
50 //
51 //   // Required constructors. Note:
52 //   // -- the constructors that copy the FST arg are useful for
53 //   // letting the matcher manage the FST through copies
54 //   // (esp with 'safe' copies); e.g. ComposeFst depends on this.
55 //   // -- the constructor that does not copy is useful when
56 //   // the FST is mutated during the lifetime of the matcher
57 //   // (o.w. the matcher would have its own unmutated deep copy).
58 //
59 //   // This makes a copy of the FST.
60 //   Matcher(const FST &fst, MatchType type);
61 //   // This doesn't copy the FST.
62 //   Matcher(const FST *fst, MatchType type);
63 //   // This makes a copy of the FST.
64 //   // See Copy() below.
65 //   Matcher(const Matcher &matcher, bool safe = false);
66 //
67 //   // If safe = true, the copy is thread-safe. See Fst<>::Copy() for
68 //   // further doc.
69 //   Matcher *Copy(bool safe = false) const override;
70 //
71 //   // Returns the match type that can be provided (depending on compatibility
72 //   // of the input FST). It is either the requested match type, MATCH_NONE,
73 //   // or MATCH_UNKNOWN. If test is false, a costly testing is avoided, but
74 //   // MATCH_UNKNOWN may be returned. If test is true, a definite answer is
75 //   // returned, but may involve more costly computation (e.g., visiting
76 //   // the FST).
77 //   MatchType Type(bool test) const override;
78 //
79 //   // Specifies the current state.
80 //   void SetState(StateId s) final;
81 //
82 //   // Finds matches to a label at the current state, returning true if a match
83 //   // found. kNoLabel matches any non-consuming transitions, e.g., epsilon
84 //   // transitions, which do not require a matching symbol.
85 //   bool Find(Label label) final;
86 //
87 //   // Iterator methods. Note that initially and after SetState() these have
88 //   // undefined behavior until Find() is called.
89 //
90 //   bool Done() const final;
91 //
92 //   const Arc &Value() const final;
93 //
94 //   void Next() final;
95 //
96 //   // Returns final weight of a state.
97 //   Weight Final(StateId) const final;
98 //
99 //   // Indicates preference for being the side used for matching in
100 //   // composition. If the value is kRequirePriority, then it is
101 //   // mandatory that it be used. Calling this method without passing the
102 //   // current state of the matcher invalidates the state of the matcher.
103 //   ssize_t Priority(StateId s) final;
104 //
105 //   // This specifies the known FST properties as viewed from this matcher. It
106 //   // takes as argument the input FST's known properties.
107 //   uint64 Properties(uint64 props) const override;
108 //
109 //   // Returns matcher flags.
110 //   uint32 Flags() const override;
111 //
112 //   // Returns matcher FST.
113 //   const FST &GetFst() const override;
114 // };
115 
116 // Basic matcher flags.
117 
118 // Matcher needs to be used as the matching side in composition for
119 // at least one state (has kRequirePriority).
120 constexpr uint32 kRequireMatch = 0x00000001;
121 
122 // Flags used for basic matchers (see also lookahead.h).
123 constexpr uint32 kMatcherFlags = kRequireMatch;
124 
125 // Matcher priority that is mandatory.
126 constexpr ssize_t kRequirePriority = -1;
127 
128 // Matcher interface, templated on the Arc definition; used for matcher
129 // specializations that are returned by the InitMatcher FST method.
130 template <class A>
131 class MatcherBase {
132  public:
133   using Arc = A;
134   using Label = typename Arc::Label;
135   using StateId = typename Arc::StateId;
136   using Weight = typename Arc::Weight;
137 
~MatcherBase()138   virtual ~MatcherBase() {}
139 
140   // Virtual interface.
141 
142   virtual MatcherBase *Copy(bool safe = false) const = 0;
143   virtual MatchType Type(bool) const = 0;
144   virtual void SetState(StateId) = 0;
145   virtual bool Find(Label) = 0;
146   virtual bool Done() const = 0;
147   virtual const Arc &Value() const = 0;
148   virtual void Next() = 0;
149   virtual const Fst<Arc> &GetFst() const = 0;
150   virtual uint64 Properties(uint64) const = 0;
151 
152   // Trivial implementations that can be used by derived classes. Full
153   // devirtualization is expected for any derived class marked final.
Flags()154   virtual uint32 Flags() const { return 0; }
155 
Final(StateId s)156   virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); }
157 
Priority(StateId s)158   virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); }
159 };
160 
161 // A matcher that expects sorted labels on the side to be matched.
162 // If match_type == MATCH_INPUT, epsilons match the implicit self-loop
163 // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
164 // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
165 // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
166 template <class F>
167 class SortedMatcher : public MatcherBase<typename F::Arc> {
168  public:
169   using FST = F;
170   using Arc = typename FST::Arc;
171   using Label = typename Arc::Label;
172   using StateId = typename Arc::StateId;
173   using Weight = typename Arc::Weight;
174 
175   using MatcherBase<Arc>::Flags;
176   using MatcherBase<Arc>::Properties;
177 
178   // Labels >= binary_label will be searched for by binary search;
179   // o.w. linear search is used.
180   // This makes a copy of the FST.
181   SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1)
182       : SortedMatcher(fst.Copy(), match_type, binary_label) {
183     owned_fst_.reset(&fst_);
184   }
185 
186   // Labels >= binary_label will be searched for by binary search;
187   // o.w. linear search is used.
188   // This doesn't copy the FST.
189   SortedMatcher(const FST *fst, MatchType match_type, Label binary_label = 1)
190       : fst_(*fst),
191         state_(kNoStateId),
192         aiter_(nullptr),
193         match_type_(match_type),
194         binary_label_(binary_label),
195         match_label_(kNoLabel),
196         narcs_(0),
197         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
198         error_(false),
199         aiter_pool_(1) {
200     switch (match_type_) {
201       case MATCH_INPUT:
202       case MATCH_NONE:
203         break;
204       case MATCH_OUTPUT:
205         std::swap(loop_.ilabel, loop_.olabel);
206         break;
207       default:
208         FSTERROR() << "SortedMatcher: Bad match type";
209         match_type_ = MATCH_NONE;
210         error_ = true;
211     }
212   }
213 
214   // This makes a copy of the FST.
215   SortedMatcher(const SortedMatcher &matcher, bool safe = false)
216       : owned_fst_(matcher.fst_.Copy(safe)),
217         fst_(*owned_fst_),
218         state_(kNoStateId),
219         aiter_(nullptr),
220         match_type_(matcher.match_type_),
221         binary_label_(matcher.binary_label_),
222         match_label_(kNoLabel),
223         narcs_(0),
224         loop_(matcher.loop_),
225         error_(matcher.error_),
226         aiter_pool_(1) {}
227 
~SortedMatcher()228   ~SortedMatcher() override { Destroy(aiter_, &aiter_pool_); }
229 
230   SortedMatcher *Copy(bool safe = false) const override {
231     return new SortedMatcher(*this, safe);
232   }
233 
Type(bool test)234   MatchType Type(bool test) const override {
235     if (match_type_ == MATCH_NONE) return match_type_;
236     const auto true_prop =
237         match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
238     const auto false_prop =
239         match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
240     const auto props = fst_.Properties(true_prop | false_prop, test);
241     if (props & true_prop) {
242       return match_type_;
243     } else if (props & false_prop) {
244       return MATCH_NONE;
245     } else {
246       return MATCH_UNKNOWN;
247     }
248   }
249 
SetState(StateId s)250   void SetState(StateId s) final {
251     if (state_ == s) return;
252     state_ = s;
253     if (match_type_ == MATCH_NONE) {
254       FSTERROR() << "SortedMatcher: Bad match type";
255       error_ = true;
256     }
257     Destroy(aiter_, &aiter_pool_);
258     aiter_ = new (&aiter_pool_) ArcIterator<FST>(fst_, s);
259     aiter_->SetFlags(kArcNoCache, kArcNoCache);
260     narcs_ = internal::NumArcs(fst_, s);
261     loop_.nextstate = s;
262   }
263 
Find(Label match_label)264   bool Find(Label match_label) final {
265     exact_match_ = true;
266     if (error_) {
267       current_loop_ = false;
268       match_label_ = kNoLabel;
269       return false;
270     }
271     current_loop_ = match_label == 0;
272     match_label_ = match_label == kNoLabel ? 0 : match_label;
273     if (Search()) {
274       return true;
275     } else {
276       return current_loop_;
277     }
278   }
279 
280   // Positions matcher to the first position where inserting match_label would
281   // maintain the sort order.
LowerBound(Label label)282   void LowerBound(Label label) {
283     exact_match_ = false;
284     current_loop_ = false;
285     if (error_) {
286       match_label_ = kNoLabel;
287       return;
288     }
289     match_label_ = label;
290     Search();
291   }
292 
293   // After Find(), returns false if no more exact matches.
294   // After LowerBound(), returns false if no more arcs.
Done()295   bool Done() const final {
296     if (current_loop_) return false;
297     if (aiter_->Done()) return true;
298     if (!exact_match_) return false;
299     aiter_->SetFlags(
300         match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
301         kArcValueFlags);
302     return GetLabel() != match_label_;
303   }
304 
Value()305   const Arc &Value() const final {
306     if (current_loop_) return loop_;
307     aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
308     return aiter_->Value();
309   }
310 
Next()311   void Next() final {
312     if (current_loop_) {
313       current_loop_ = false;
314     } else {
315       aiter_->Next();
316     }
317   }
318 
Final(StateId s)319   Weight Final(StateId s) const final { return MatcherBase<Arc>::Final(s); }
320 
Priority(StateId s)321   ssize_t Priority(StateId s) final { return MatcherBase<Arc>::Priority(s); }
322 
GetFst()323   const FST &GetFst() const override { return fst_; }
324 
Properties(uint64 inprops)325   uint64 Properties(uint64 inprops) const override {
326     return inprops | (error_ ? kError : 0);
327   }
328 
Position()329   size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
330 
331  private:
GetLabel()332   Label GetLabel() const {
333     const auto &arc = aiter_->Value();
334     return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
335   }
336 
337   bool BinarySearch();
338   bool LinearSearch();
339   bool Search();
340 
341   std::unique_ptr<const FST> owned_fst_;  // FST ptr if owned.
342   const FST &fst_;                        // FST for matching.
343   StateId state_;                         // Matcher state.
344   ArcIterator<FST> *aiter_;               // Iterator for current state.
345   MatchType match_type_;                  // Type of match to perform.
346   Label binary_label_;                    // Least label for binary search.
347   Label match_label_;                     // Current label to be matched.
348   size_t narcs_;                          // Current state arc count.
349   Arc loop_;                              // For non-consuming symbols.
350   bool current_loop_;                     // Current arc is the implicit loop.
351   bool exact_match_;                      // Exact match or lower bound?
352   bool error_;                            // Error encountered?
353   MemoryPool<ArcIterator<FST>> aiter_pool_;  // Pool of arc iterators.
354 };
355 
356 // Returns true iff match to match_label_. The arc iterator is positioned at the
357 // lower bound, that is, the first element greater than or equal to
358 // match_label_, or the end if all elements are less than match_label_.
359 // If multiple elements are equal to the `match_label_`, returns the rightmost
360 // one.
361 template <class FST>
BinarySearch()362 inline bool SortedMatcher<FST>::BinarySearch() {
363   size_t size = narcs_;
364   if (size == 0) {
365     return false;
366   }
367   size_t high = size - 1;
368   while (size > 1) {
369     const size_t half = size / 2;
370     const size_t mid = high - half;
371     aiter_->Seek(mid);
372     if (GetLabel() >= match_label_) {
373       high = mid;
374     }
375     size -= half;
376   }
377   aiter_->Seek(high);
378   const auto label = GetLabel();
379   if (label == match_label_) {
380     return true;
381   }
382   if (label < match_label_) {
383     aiter_->Next();
384   }
385   return false;
386 }
387 
388 // Returns true iff match to match_label_, positioning arc iterator at lower
389 // bound.
390 template <class FST>
LinearSearch()391 inline bool SortedMatcher<FST>::LinearSearch() {
392   for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
393     const auto label = GetLabel();
394     if (label == match_label_) return true;
395     if (label > match_label_) break;
396   }
397   return false;
398 }
399 
400 // Returns true iff match to match_label_, positioning arc iterator at lower
401 // bound.
402 template <class FST>
Search()403 inline bool SortedMatcher<FST>::Search() {
404   aiter_->SetFlags(
405       match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
406       kArcValueFlags);
407   if (match_label_ >= binary_label_) {
408     return BinarySearch();
409   } else {
410     return LinearSearch();
411   }
412 }
413 
414 // A matcher that stores labels in a per-state hash table populated upon the
415 // first visit to that state. Sorting is not required. Treatment of
416 // epsilons are the same as with SortedMatcher.
417 template <class F>
418 class HashMatcher : public MatcherBase<typename F::Arc> {
419  public:
420   using FST = F;
421   using Arc = typename FST::Arc;
422   using Label = typename Arc::Label;
423   using StateId = typename Arc::StateId;
424   using Weight = typename Arc::Weight;
425 
426   using MatcherBase<Arc>::Flags;
427   using MatcherBase<Arc>::Final;
428   using MatcherBase<Arc>::Priority;
429 
430   // This makes a copy of the FST.
HashMatcher(const FST & fst,MatchType match_type)431   HashMatcher(const FST &fst, MatchType match_type)
432       : HashMatcher(fst.Copy(), match_type) {
433     owned_fst_.reset(&fst_);
434   }
435 
436   // This doesn't copy the FST.
HashMatcher(const FST * fst,MatchType match_type)437   HashMatcher(const FST *fst, MatchType match_type)
438       : fst_(*fst),
439         state_(kNoStateId),
440         match_type_(match_type),
441         loop_(kNoLabel, 0, Weight::One(), kNoStateId),
442         error_(false),
443         state_table_(std::make_shared<StateTable>()) {
444     switch (match_type_) {
445       case MATCH_INPUT:
446       case MATCH_NONE:
447         break;
448       case MATCH_OUTPUT:
449         std::swap(loop_.ilabel, loop_.olabel);
450         break;
451       default:
452         FSTERROR() << "HashMatcher: Bad match type";
453         match_type_ = MATCH_NONE;
454         error_ = true;
455     }
456   }
457 
458   // This makes a copy of the FST.
459   HashMatcher(const HashMatcher &matcher, bool safe = false)
460       : owned_fst_(matcher.fst_.Copy(safe)),
461         fst_(*owned_fst_),
462         state_(kNoStateId),
463         match_type_(matcher.match_type_),
464         loop_(matcher.loop_),
465         error_(matcher.error_),
466         state_table_(safe ? std::make_shared<StateTable>()
467                           : matcher.state_table_) {}
468 
469   HashMatcher *Copy(bool safe = false) const override {
470     return new HashMatcher(*this, safe);
471   }
472 
473   // The argument is ignored as there are no relevant properties to test.
Type(bool test)474   MatchType Type(bool test) const override { return match_type_; }
475 
476   void SetState(StateId s) final;
477 
Find(Label label)478   bool Find(Label label) final {
479     current_loop_ = label == 0;
480     if (label == 0) {
481       Search(label);
482       return true;
483     }
484     if (label == kNoLabel) label = 0;
485     return Search(label);
486   }
487 
Done()488   bool Done() const final {
489     if (current_loop_) return false;
490     return label_it_ == label_end_;
491   }
492 
Value()493   const Arc &Value() const final {
494     if (current_loop_) return loop_;
495     aiter_->Seek(label_it_->second);
496     return aiter_->Value();
497   }
498 
Next()499   void Next() final {
500     if (current_loop_) {
501       current_loop_ = false;
502     } else {
503       ++label_it_;
504     }
505   }
506 
GetFst()507   const FST &GetFst() const override { return fst_; }
508 
Properties(uint64 inprops)509   uint64 Properties(uint64 inprops) const override {
510     return inprops | (error_ ? kError : 0);
511   }
512 
513  private:
GetLabel()514   Label GetLabel() const {
515     const auto &arc = aiter_->Value();
516     return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
517   }
518 
519   bool Search(Label match_label);
520 
521   using LabelTable = std::unordered_multimap<Label, size_t>;
522   using StateTable = std::unordered_map<StateId, std::unique_ptr<LabelTable>>;
523 
524   std::unique_ptr<const FST> owned_fst_;  // ptr to FST if owned.
525   const FST &fst_;                        // FST for matching.
526   StateId state_;                         // Matcher state.
527   MatchType match_type_;
528   Arc loop_;           // The implicit loop itself.
529   bool current_loop_;  // Is the current arc the implicit loop?
530   bool error_;         // Error encountered?
531   std::unique_ptr<ArcIterator<FST>> aiter_;
532   std::shared_ptr<StateTable> state_table_;  // Table from state to label table.
533   LabelTable *label_table_;  // Pointer to current state's label table.
534   typename LabelTable::iterator label_it_;   // Position for label.
535   typename LabelTable::iterator label_end_;  // Position for last label + 1.
536 };
537 
538 template <class FST>
SetState(typename FST::Arc::StateId s)539 void HashMatcher<FST>::SetState(typename FST::Arc::StateId s) {
540   if (state_ == s) return;
541   // Resets everything for the state.
542   state_ = s;
543   loop_.nextstate = state_;
544   aiter_ = std::make_unique<ArcIterator<FST>>(fst_, state_);
545   if (match_type_ == MATCH_NONE) {
546     FSTERROR() << "HashMatcher: Bad match type";
547     error_ = true;
548   }
549   // Attempts to insert a new label table.
550   auto it_and_success = state_table_->emplace(
551       state_, std::make_unique<LabelTable>());
552   // Sets instance's pointer to the label table for this state.
553   label_table_ = it_and_success.first->second.get();
554   // If it already exists, no additional work is done and we simply return.
555   if (!it_and_success.second) return;
556   // Otherwise, populate this new table.
557   // Populates the label table.
558   label_table_->reserve(internal::NumArcs(fst_, state_));
559   const auto aiter_flags =
560       (match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue) |
561       kArcNoCache;
562   aiter_->SetFlags(aiter_flags, kArcFlags);
563   for (; !aiter_->Done(); aiter_->Next()) {
564     label_table_->emplace(GetLabel(), aiter_->Position());
565   }
566   aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
567 }
568 
569 template <class FST>
Search(typename FST::Arc::Label match_label)570 inline bool HashMatcher<FST>::Search(typename FST::Arc::Label match_label) {
571   auto range = label_table_->equal_range(match_label);
572   label_it_ = range.first;
573   label_end_ = range.second;
574   if (label_it_ == label_end_) return false;
575   aiter_->Seek(label_it_->second);
576   return true;
577 }
578 
// Controls whether a rewriting matcher rewrites both the input and the output
// side of a matched arc.
enum MatcherRewriteMode {
  MATCHER_REWRITE_AUTO = 0,    // Rewrites both sides iff acceptor.
  MATCHER_REWRITE_ALWAYS = 1,  // Rewrites both sides.
  MATCHER_REWRITE_NEVER = 2    // Rewrites the matched side only.
};
585 
586 // For any requested label that doesn't match at a state, this matcher
587 // considers the *unique* transition that matches the label 'phi_label'
588 // (phi = 'fail'), and recursively looks for a match at its
589 // destination. When 'phi_loop' is true, if no match is found but a
590 // phi self-loop is found, then the phi transition found is returned
591 // with the phi_label rewritten as the requested label (both sides if
592 // an acceptor, or if 'rewrite_both' is true and both input and output
593 // labels of the found transition are 'phi_label'). If 'phi_label' is
594 // kNoLabel, this special matching is not done. PhiMatcher is
595 // templated itself on a matcher, which is used to perform the
596 // underlying matching. By default, the underlying matcher is
597 // constructed by PhiMatcher. The user can instead pass in this
598 // object; in that case, PhiMatcher takes its ownership.
599 // Phi non-determinism not supported. No non-consuming symbols other
600 // than epsilon supported with the underlying template argument matcher.
601 template <class M>
602 class PhiMatcher : public MatcherBase<typename M::Arc> {
603  public:
604   using FST = typename M::FST;
605   using Arc = typename FST::Arc;
606   using Label = typename Arc::Label;
607   using StateId = typename Arc::StateId;
608   using Weight = typename Arc::Weight;
609 
610   // This makes a copy of the FST (w/o 'matcher' arg).
611   PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel,
612              bool phi_loop = true,
613              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
614              M *matcher = nullptr)
615       : matcher_(matcher ? matcher : new M(fst, match_type)),
616         match_type_(match_type),
617         phi_label_(phi_label),
618         state_(kNoStateId),
619         phi_loop_(phi_loop),
620         error_(false) {
621     if (match_type == MATCH_BOTH) {
622       FSTERROR() << "PhiMatcher: Bad match type";
623       match_type_ = MATCH_NONE;
624       error_ = true;
625     }
626     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
627       rewrite_both_ = fst.Properties(kAcceptor, true);
628     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
629       rewrite_both_ = true;
630     } else {
631       rewrite_both_ = false;
632     }
633   }
634 
635   // This doesn't copy the FST.
636   PhiMatcher(const FST *fst, MatchType match_type, Label phi_label = kNoLabel,
637              bool phi_loop = true,
638              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
639              M *matcher = nullptr)
640       : PhiMatcher(*fst, match_type, phi_label, phi_loop, rewrite_mode,
641                    matcher ? matcher : new M(fst, match_type)) {}
642 
643   // This makes a copy of the FST.
644   PhiMatcher(const PhiMatcher &matcher, bool safe = false)
645       : matcher_(new M(*matcher.matcher_, safe)),
646         match_type_(matcher.match_type_),
647         phi_label_(matcher.phi_label_),
648         rewrite_both_(matcher.rewrite_both_),
649         state_(kNoStateId),
650         phi_loop_(matcher.phi_loop_),
651         error_(matcher.error_) {}
652 
653   PhiMatcher *Copy(bool safe = false) const override {
654     return new PhiMatcher(*this, safe);
655   }
656 
Type(bool test)657   MatchType Type(bool test) const override { return matcher_->Type(test); }
658 
SetState(StateId s)659   void SetState(StateId s) final {
660     if (state_ == s) return;
661     matcher_->SetState(s);
662     state_ = s;
663     has_phi_ = phi_label_ != kNoLabel;
664   }
665 
666   bool Find(Label match_label) final;
667 
Done()668   bool Done() const final { return matcher_->Done(); }
669 
Value()670   const Arc &Value() const final {
671     if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
672       return matcher_->Value();
673     } else if (phi_match_ == 0) {  // Virtual epsilon loop.
674       phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
675       if (match_type_ == MATCH_OUTPUT) {
676         std::swap(phi_arc_.ilabel, phi_arc_.olabel);
677       }
678       return phi_arc_;
679     } else {
680       phi_arc_ = matcher_->Value();
681       phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
682       if (phi_match_ != kNoLabel) {  // Phi loop match.
683         if (rewrite_both_) {
684           if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_;
685           if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_;
686         } else if (match_type_ == MATCH_INPUT) {
687           phi_arc_.ilabel = phi_match_;
688         } else {
689           phi_arc_.olabel = phi_match_;
690         }
691       }
692       return phi_arc_;
693     }
694   }
695 
Next()696   void Next() final { matcher_->Next(); }
697 
Final(StateId s)698   Weight Final(StateId s) const final {
699     auto weight = matcher_->Final(s);
700     if (phi_label_ == kNoLabel || weight != Weight::Zero()) {
701       return weight;
702     }
703     weight = Weight::One();
704     matcher_->SetState(s);
705     while (matcher_->Final(s) == Weight::Zero()) {
706       if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) break;
707       weight = Times(weight, matcher_->Value().weight);
708       if (s == matcher_->Value().nextstate) {
709         return Weight::Zero();  // Does not follow phi self-loops.
710       }
711       s = matcher_->Value().nextstate;
712       matcher_->SetState(s);
713     }
714     weight = Times(weight, matcher_->Final(s));
715     return weight;
716   }
717 
Priority(StateId s)718   ssize_t Priority(StateId s) final {
719     if (phi_label_ != kNoLabel) {
720       matcher_->SetState(s);
721       const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_);
722       return has_phi ? kRequirePriority : matcher_->Priority(s);
723     } else {
724       return matcher_->Priority(s);
725     }
726   }
727 
GetFst()728   const FST &GetFst() const override { return matcher_->GetFst(); }
729 
730   uint64 Properties(uint64 props) const override;
731 
Flags()732   uint32 Flags() const override {
733     if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) {
734       return matcher_->Flags();
735     }
736     return matcher_->Flags() | kRequireMatch;
737   }
738 
PhiLabel()739   Label PhiLabel() const { return phi_label_; }
740 
741  private:
742   mutable std::unique_ptr<M> matcher_;
743   MatchType match_type_;  // Type of match requested.
744   Label phi_label_;       // Label that represents the phi transition.
745   bool rewrite_both_;     // Rewrite both sides when both are phi_label_?
746   bool has_phi_;          // Are there possibly phis at the current state?
747   Label phi_match_;       // Current label that matches phi loop.
748   mutable Arc phi_arc_;   // Arc to return.
749   StateId state_;         // Matcher state.
750   Weight phi_weight_;     // Product of the weights of phi transitions taken.
751   bool phi_loop_;         // When true, phi self-loop are allowed and treated
752                           // as rho (required for Aho-Corasick).
753   bool error_;            // Error encountered?
754 
755   PhiMatcher &operator=(const PhiMatcher &) = delete;
756 };
757 
758 template <class M>
Find(Label label)759 inline bool PhiMatcher<M>::Find(Label label) {
760   if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
761     FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
762     error_ = true;
763     return false;
764   }
765   matcher_->SetState(state_);
766   phi_match_ = kNoLabel;
767   phi_weight_ = Weight::One();
768   // If phi_label_ == 0, there are no more true epsilon arcs.
769   if (phi_label_ == 0) {
770     if (label == kNoLabel) {
771       return false;
772     }
773     if (label == 0) {  // but a virtual epsilon loop needs to be returned.
774       if (!matcher_->Find(kNoLabel)) {
775         return matcher_->Find(0);
776       } else {
777         phi_match_ = 0;
778         return true;
779       }
780     }
781   }
782   if (!has_phi_ || label == 0 || label == kNoLabel) {
783     return matcher_->Find(label);
784   }
785   auto s = state_;
786   while (!matcher_->Find(label)) {
787     // Look for phi transition (if phi_label_ == 0, we need to look
788     // for -1 to avoid getting the virtual self-loop)
789     if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false;
790     if (phi_loop_ && matcher_->Value().nextstate == s) {
791       phi_match_ = label;
792       return true;
793     }
794     phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
795     s = matcher_->Value().nextstate;
796     matcher_->Next();
797     if (!matcher_->Done()) {
798       FSTERROR() << "PhiMatcher: Phi non-determinism not supported";
799       error_ = true;
800     }
801     matcher_->SetState(s);
802   }
803   return true;
804 }
805 
806 template <class M>
Properties(uint64 inprops)807 inline uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
808   auto outprops = matcher_->Properties(inprops);
809   if (error_) outprops |= kError;
810   if (match_type_ == MATCH_NONE) {
811     return outprops;
812   } else if (match_type_ == MATCH_INPUT) {
813     if (phi_label_ == 0) {
814       outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
815       outprops |= kNoEpsilons | kNoIEpsilons;
816     }
817     if (rewrite_both_) {
818       return outprops &
819              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
820                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
821     } else {
822       return outprops &
823              ~(kODeterministic | kAcceptor | kString | kILabelSorted |
824                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
825     }
826   } else if (match_type_ == MATCH_OUTPUT) {
827     if (phi_label_ == 0) {
828       outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
829       outprops |= kNoEpsilons | kNoOEpsilons;
830     }
831     if (rewrite_both_) {
832       return outprops &
833              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
834                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
835     } else {
836       return outprops &
837              ~(kIDeterministic | kAcceptor | kString | kILabelSorted |
838                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
839     }
840   } else {
841     // Shouldn't ever get here.
842     FSTERROR() << "PhiMatcher: Bad match type: " << match_type_;
843     return 0;
844   }
845 }
846 
847 // For any requested label that doesn't match at a state, this matcher
848 // considers all transitions that match the label 'rho_label' (rho =
849 // 'rest'). Each such rho transition found is returned with the
850 // rho_label rewritten as the requested label (both sides if an
851 // acceptor, or if 'rewrite_both' is true and both input and output
852 // labels of the found transition are 'rho_label'). If 'rho_label' is
853 // kNoLabel, this special matching is not done. RhoMatcher is
854 // templated itself on a matcher, which is used to perform the
855 // underlying matching. By default, the underlying matcher is
856 // constructed by RhoMatcher. The user can instead pass in this
857 // object; in that case, RhoMatcher takes its ownership.
858 // No non-consuming symbols other than epsilon supported with
859 // the underlying template argument matcher.
860 template <class M>
861 class RhoMatcher : public MatcherBase<typename M::Arc> {
862  public:
863   using FST = typename M::FST;
864   using Arc = typename FST::Arc;
865   using Label = typename Arc::Label;
866   using StateId = typename Arc::StateId;
867   using Weight = typename Arc::Weight;
868 
869   // This makes a copy of the FST (w/o 'matcher' arg).
870   RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel,
871              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
872              M *matcher = nullptr)
873       : matcher_(matcher ? matcher : new M(fst, match_type)),
874         match_type_(match_type),
875         rho_label_(rho_label),
876         error_(false),
877         state_(kNoStateId),
878         has_rho_(false) {
879     if (match_type == MATCH_BOTH) {
880       FSTERROR() << "RhoMatcher: Bad match type";
881       match_type_ = MATCH_NONE;
882       error_ = true;
883     }
884     if (rho_label == 0) {
885       FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
886       rho_label_ = kNoLabel;
887       error_ = true;
888     }
889     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
890       rewrite_both_ = fst.Properties(kAcceptor, true);
891     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
892       rewrite_both_ = true;
893     } else {
894       rewrite_both_ = false;
895     }
896   }
897 
898   // This doesn't copy the FST.
899   RhoMatcher(const FST *fst, MatchType match_type, Label rho_label = kNoLabel,
900              MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
901              M *matcher = nullptr)
902       : RhoMatcher(*fst, match_type, rho_label, rewrite_mode,
903                    matcher ? matcher : new M(fst, match_type)) {}
904 
905   // This makes a copy of the FST.
906   RhoMatcher(const RhoMatcher &matcher, bool safe = false)
907       : matcher_(new M(*matcher.matcher_, safe)),
908         match_type_(matcher.match_type_),
909         rho_label_(matcher.rho_label_),
910         rewrite_both_(matcher.rewrite_both_),
911         error_(matcher.error_),
912         state_(kNoStateId),
913         has_rho_(false) {}
914 
915   RhoMatcher *Copy(bool safe = false) const override {
916     return new RhoMatcher(*this, safe);
917   }
918 
Type(bool test)919   MatchType Type(bool test) const override { return matcher_->Type(test); }
920 
SetState(StateId s)921   void SetState(StateId s) final {
922     if (state_ == s) return;
923     state_ = s;
924     matcher_->SetState(s);
925     has_rho_ = rho_label_ != kNoLabel;
926   }
927 
Find(Label label)928   bool Find(Label label) final {
929     if (label == rho_label_ && rho_label_ != kNoLabel) {
930       FSTERROR() << "RhoMatcher::Find: bad label (rho)";
931       error_ = true;
932       return false;
933     }
934     if (matcher_->Find(label)) {
935       rho_match_ = kNoLabel;
936       return true;
937     } else if (has_rho_ && label != 0 && label != kNoLabel &&
938                (has_rho_ = matcher_->Find(rho_label_))) {
939       rho_match_ = label;
940       return true;
941     } else {
942       return false;
943     }
944   }
945 
Done()946   bool Done() const final { return matcher_->Done(); }
947 
Value()948   const Arc &Value() const final {
949     if (rho_match_ == kNoLabel) {
950       return matcher_->Value();
951     } else {
952       rho_arc_ = matcher_->Value();
953       if (rewrite_both_) {
954         if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_;
955         if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_;
956       } else if (match_type_ == MATCH_INPUT) {
957         rho_arc_.ilabel = rho_match_;
958       } else {
959         rho_arc_.olabel = rho_match_;
960       }
961       return rho_arc_;
962     }
963   }
964 
Next()965   void Next() final { matcher_->Next(); }
966 
Final(StateId s)967   Weight Final(StateId s) const final { return matcher_->Final(s); }
968 
Priority(StateId s)969   ssize_t Priority(StateId s) final {
970     state_ = s;
971     matcher_->SetState(s);
972     has_rho_ = matcher_->Find(rho_label_);
973     if (has_rho_) {
974       return kRequirePriority;
975     } else {
976       return matcher_->Priority(s);
977     }
978   }
979 
GetFst()980   const FST &GetFst() const override { return matcher_->GetFst(); }
981 
982   uint64 Properties(uint64 props) const override;
983 
Flags()984   uint32 Flags() const override {
985     if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) {
986       return matcher_->Flags();
987     }
988     return matcher_->Flags() | kRequireMatch;
989   }
990 
RhoLabel()991   Label RhoLabel() const { return rho_label_; }
992 
993  private:
994   std::unique_ptr<M> matcher_;
995   MatchType match_type_;  // Type of match requested.
996   Label rho_label_;       // Label that represents the rho transition
997   bool rewrite_both_;     // Rewrite both sides when both are rho_label_?
998   Label rho_match_;       // Current label that matches rho transition.
999   mutable Arc rho_arc_;   // Arc to return when rho match.
1000   bool error_;            // Error encountered?
1001   StateId state_;         // Matcher state.
1002   bool has_rho_;          // Are there possibly rhos at the current state?
1003 };
1004 
1005 template <class M>
Properties(uint64 inprops)1006 inline uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
1007   auto outprops = matcher_->Properties(inprops);
1008   if (error_) outprops |= kError;
1009   if (match_type_ == MATCH_NONE) {
1010     return outprops;
1011   } else if (match_type_ == MATCH_INPUT) {
1012     if (rewrite_both_) {
1013       return outprops &
1014              ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
1015                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
1016     } else {
1017       return outprops & ~(kODeterministic | kAcceptor | kString |
1018                           kILabelSorted | kNotILabelSorted);
1019     }
1020   } else if (match_type_ == MATCH_OUTPUT) {
1021     if (rewrite_both_) {
1022       return outprops &
1023              ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
1024                kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
1025     } else {
1026       return outprops & ~(kIDeterministic | kAcceptor | kString |
1027                           kOLabelSorted | kNotOLabelSorted);
1028     }
1029   } else {
1030     // Shouldn't ever get here.
1031     FSTERROR() << "RhoMatcher: Bad match type: " << match_type_;
1032     return 0;
1033   }
1034 }
1035 
// For any requested label, this matcher considers all transitions
// that match the label 'sigma_label' (sigma = "any"), in addition
// to transitions with the requested label. Each such sigma
1039 // transition found is returned with the sigma_label rewritten as the
1040 // requested label (both sides if an acceptor, or if 'rewrite_both' is
1041 // true and both input and output labels of the found transition are
1042 // 'sigma_label'). If 'sigma_label' is kNoLabel, this special
1043 // matching is not done. SigmaMatcher is templated itself on a
1044 // matcher, which is used to perform the underlying matching. By
1045 // default, the underlying matcher is constructed by SigmaMatcher.
1046 // The user can instead pass in this object; in that case,
1047 // SigmaMatcher takes its ownership. No non-consuming symbols other
1048 // than epsilon supported with the underlying template argument matcher.
1049 template <class M>
1050 class SigmaMatcher : public MatcherBase<typename M::Arc> {
1051  public:
1052   using FST = typename M::FST;
1053   using Arc = typename FST::Arc;
1054   using Label = typename Arc::Label;
1055   using StateId = typename Arc::StateId;
1056   using Weight = typename Arc::Weight;
1057 
1058   // This makes a copy of the FST (w/o 'matcher' arg).
1059   SigmaMatcher(const FST &fst, MatchType match_type,
1060                Label sigma_label = kNoLabel,
1061                MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
1062                M *matcher = nullptr)
1063       : matcher_(matcher ? matcher : new M(fst, match_type)),
1064         match_type_(match_type),
1065         sigma_label_(sigma_label),
1066         error_(false),
1067         state_(kNoStateId) {
1068     if (match_type == MATCH_BOTH) {
1069       FSTERROR() << "SigmaMatcher: Bad match type";
1070       match_type_ = MATCH_NONE;
1071       error_ = true;
1072     }
1073     if (sigma_label == 0) {
1074       FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
1075       sigma_label_ = kNoLabel;
1076       error_ = true;
1077     }
1078     if (rewrite_mode == MATCHER_REWRITE_AUTO) {
1079       rewrite_both_ = fst.Properties(kAcceptor, true);
1080     } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
1081       rewrite_both_ = true;
1082     } else {
1083       rewrite_both_ = false;
1084     }
1085   }
1086 
1087   // This doesn't copy the FST.
1088   SigmaMatcher(const FST *fst, MatchType match_type,
1089                Label sigma_label = kNoLabel,
1090                MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
1091                M *matcher = nullptr)
1092       : SigmaMatcher(*fst, match_type, sigma_label, rewrite_mode,
1093                      matcher ? matcher : new M(fst, match_type)) {}
1094 
1095   // This makes a copy of the FST.
1096   SigmaMatcher(const SigmaMatcher &matcher, bool safe = false)
1097       : matcher_(new M(*matcher.matcher_, safe)),
1098         match_type_(matcher.match_type_),
1099         sigma_label_(matcher.sigma_label_),
1100         rewrite_both_(matcher.rewrite_both_),
1101         error_(matcher.error_),
1102         state_(kNoStateId) {}
1103 
1104   SigmaMatcher *Copy(bool safe = false) const override {
1105     return new SigmaMatcher(*this, safe);
1106   }
1107 
Type(bool test)1108   MatchType Type(bool test) const override { return matcher_->Type(test); }
1109 
SetState(StateId s)1110   void SetState(StateId s) final {
1111     if (state_ == s) return;
1112     state_ = s;
1113     matcher_->SetState(s);
1114     has_sigma_ =
1115         (sigma_label_ != kNoLabel) ? matcher_->Find(sigma_label_) : false;
1116   }
1117 
Find(Label match_label)1118   bool Find(Label match_label) final {
1119     match_label_ = match_label;
1120     if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
1121       FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
1122       error_ = true;
1123       return false;
1124     }
1125     if (matcher_->Find(match_label)) {
1126       sigma_match_ = kNoLabel;
1127       return true;
1128     } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
1129                matcher_->Find(sigma_label_)) {
1130       sigma_match_ = match_label;
1131       return true;
1132     } else {
1133       return false;
1134     }
1135   }
1136 
Done()1137   bool Done() const final { return matcher_->Done(); }
1138 
Value()1139   const Arc &Value() const final {
1140     if (sigma_match_ == kNoLabel) {
1141       return matcher_->Value();
1142     } else {
1143       sigma_arc_ = matcher_->Value();
1144       if (rewrite_both_) {
1145         if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_;
1146         if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_;
1147       } else if (match_type_ == MATCH_INPUT) {
1148         sigma_arc_.ilabel = sigma_match_;
1149       } else {
1150         sigma_arc_.olabel = sigma_match_;
1151       }
1152       return sigma_arc_;
1153     }
1154   }
1155 
Next()1156   void Next() final {
1157     matcher_->Next();
1158     if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
1159         (match_label_ > 0)) {
1160       matcher_->Find(sigma_label_);
1161       sigma_match_ = match_label_;
1162     }
1163   }
1164 
Final(StateId s)1165   Weight Final(StateId s) const final { return matcher_->Final(s); }
1166 
Priority(StateId s)1167   ssize_t Priority(StateId s) final {
1168     if (sigma_label_ != kNoLabel) {
1169       SetState(s);
1170       return has_sigma_ ? kRequirePriority : matcher_->Priority(s);
1171     } else {
1172       return matcher_->Priority(s);
1173     }
1174   }
1175 
GetFst()1176   const FST &GetFst() const override { return matcher_->GetFst(); }
1177 
1178   uint64 Properties(uint64 props) const override;
1179 
Flags()1180   uint32 Flags() const override {
1181     if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) {
1182       return matcher_->Flags();
1183     }
1184     return matcher_->Flags() | kRequireMatch;
1185   }
1186 
SigmaLabel()1187   Label SigmaLabel() const { return sigma_label_; }
1188 
1189  private:
1190   std::unique_ptr<M> matcher_;
1191   MatchType match_type_;   // Type of match requested.
1192   Label sigma_label_;      // Label that represents the sigma transition.
1193   bool rewrite_both_;      // Rewrite both sides when both are sigma_label_?
1194   bool has_sigma_;         // Are there sigmas at the current state?
1195   Label sigma_match_;      // Current label that matches sigma transition.
1196   mutable Arc sigma_arc_;  // Arc to return when sigma match.
1197   Label match_label_;      // Label being matched.
1198   bool error_;             // Error encountered?
1199   StateId state_;          // Matcher state.
1200 };
1201 
1202 template <class M>
Properties(uint64 inprops)1203 inline uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
1204   auto outprops = matcher_->Properties(inprops);
1205   if (error_) outprops |= kError;
1206   if (match_type_ == MATCH_NONE) {
1207     return outprops;
1208   } else if (rewrite_both_) {
1209     return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1210                         kNonODeterministic | kILabelSorted | kNotILabelSorted |
1211                         kOLabelSorted | kNotOLabelSorted | kString);
1212   } else if (match_type_ == MATCH_INPUT) {
1213     return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1214                         kNonODeterministic | kILabelSorted | kNotILabelSorted |
1215                         kString | kAcceptor);
1216   } else if (match_type_ == MATCH_OUTPUT) {
1217     return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1218                         kNonODeterministic | kOLabelSorted | kNotOLabelSorted |
1219                         kString | kAcceptor);
1220   } else {
1221     // Shouldn't ever get here.
1222     FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_;
1223     return 0;
1224   }
1225 }
1226 
// Flags for MultiEpsMatcher. These are bit flags that can be OR'ed together;
// the MultiEpsMatcher constructors default to both.

// Return multi-epsilon arcs for Find(kNoLabel).
const uint32 kMultiEpsList = 0x00000001;

// Return a kNoLabel loop for Find(multi_eps).
const uint32 kMultiEpsLoop = 0x00000002;
1234 
1235 // MultiEpsMatcher: allows treating multiple non-0 labels as
1236 // non-consuming labels in addition to 0 that is always
1237 // non-consuming. Precise behavior controlled by 'flags' argument. By
1238 // default, the underlying matcher is constructed by
1239 // MultiEpsMatcher. The user can instead pass in this object; in that
1240 // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
1241 // true.
1242 template <class M>
1243 class MultiEpsMatcher {
1244  public:
1245   using FST = typename M::FST;
1246   using Arc = typename FST::Arc;
1247   using Label = typename Arc::Label;
1248   using StateId = typename Arc::StateId;
1249   using Weight = typename Arc::Weight;
1250 
1251   // This makes a copy of the FST (w/o 'matcher' arg).
1252   MultiEpsMatcher(const FST &fst, MatchType match_type,
1253                   uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1254                   M *matcher = nullptr, bool own_matcher = true)
1255       : matcher_(matcher ? matcher : new M(fst, match_type)),
1256         flags_(flags),
1257         own_matcher_(matcher ? own_matcher : true) {
1258     Init(match_type);
1259   }
1260 
1261   // This doesn't copy the FST.
1262   MultiEpsMatcher(const FST *fst, MatchType match_type,
1263                   uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1264                   M *matcher = nullptr, bool own_matcher = true)
1265       : matcher_(matcher ? matcher : new M(fst, match_type)),
1266         flags_(flags),
1267         own_matcher_(matcher ? own_matcher : true) {
1268     Init(match_type);
1269   }
1270 
1271   // This makes a copy of the FST.
1272   MultiEpsMatcher(const MultiEpsMatcher &matcher, bool safe = false)
1273       : matcher_(new M(*matcher.matcher_, safe)),
1274         flags_(matcher.flags_),
1275         own_matcher_(true),
1276         multi_eps_labels_(matcher.multi_eps_labels_),
1277         loop_(matcher.loop_) {
1278     loop_.nextstate = kNoStateId;
1279   }
1280 
~MultiEpsMatcher()1281   ~MultiEpsMatcher() {
1282     if (own_matcher_) delete matcher_;
1283   }
1284 
1285   MultiEpsMatcher *Copy(bool safe = false) const {
1286     return new MultiEpsMatcher(*this, safe);
1287   }
1288 
Type(bool test)1289   MatchType Type(bool test) const { return matcher_->Type(test); }
1290 
SetState(StateId state)1291   void SetState(StateId state) {
1292     matcher_->SetState(state);
1293     loop_.nextstate = state;
1294   }
1295 
1296   bool Find(Label label);
1297 
Done()1298   bool Done() const { return done_; }
1299 
Value()1300   const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); }
1301 
Next()1302   void Next() {
1303     if (!current_loop_) {
1304       matcher_->Next();
1305       done_ = matcher_->Done();
1306       if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1307         ++multi_eps_iter_;
1308         while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1309                !matcher_->Find(*multi_eps_iter_)) {
1310           ++multi_eps_iter_;
1311         }
1312         if (multi_eps_iter_ != multi_eps_labels_.End()) {
1313           done_ = false;
1314         } else {
1315           done_ = !matcher_->Find(kNoLabel);
1316         }
1317       }
1318     } else {
1319       done_ = true;
1320     }
1321   }
1322 
GetFst()1323   const FST &GetFst() const { return matcher_->GetFst(); }
1324 
Properties(uint64 props)1325   uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1326 
GetMatcher()1327   const M *GetMatcher() const { return matcher_; }
1328 
Final(StateId s)1329   Weight Final(StateId s) const { return matcher_->Final(s); }
1330 
Flags()1331   uint32 Flags() const { return matcher_->Flags(); }
1332 
Priority(StateId s)1333   ssize_t Priority(StateId s) { return matcher_->Priority(s); }
1334 
AddMultiEpsLabel(Label label)1335   void AddMultiEpsLabel(Label label) {
1336     if (label == 0) {
1337       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1338     } else {
1339       multi_eps_labels_.Insert(label);
1340     }
1341   }
1342 
RemoveMultiEpsLabel(Label label)1343   void RemoveMultiEpsLabel(Label label) {
1344     if (label == 0) {
1345       FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1346     } else {
1347       multi_eps_labels_.Erase(label);
1348     }
1349   }
1350 
ClearMultiEpsLabels()1351   void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); }
1352 
1353  private:
Init(MatchType match_type)1354   void Init(MatchType match_type) {
1355     if (match_type == MATCH_INPUT) {
1356       loop_.ilabel = kNoLabel;
1357       loop_.olabel = 0;
1358     } else {
1359       loop_.ilabel = 0;
1360       loop_.olabel = kNoLabel;
1361     }
1362     loop_.weight = Weight::One();
1363     loop_.nextstate = kNoStateId;
1364   }
1365 
1366   M *matcher_;
1367   uint32 flags_;
1368   bool own_matcher_;  // Does this class delete the matcher?
1369 
1370   // Multi-eps label set.
1371   CompactSet<Label, kNoLabel> multi_eps_labels_;
1372   typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1373 
1374   bool current_loop_;  // Current arc is the implicit loop?
1375   mutable Arc loop_;   // For non-consuming symbols.
1376   bool done_;          // Matching done?
1377 
1378   MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete;
1379 };
1380 
1381 template <class M>
Find(Label label)1382 inline bool MultiEpsMatcher<M>::Find(Label label) {
1383   multi_eps_iter_ = multi_eps_labels_.End();
1384   current_loop_ = false;
1385   bool ret;
1386   if (label == 0) {
1387     ret = matcher_->Find(0);
1388   } else if (label == kNoLabel) {
1389     if (flags_ & kMultiEpsList) {
1390       // Returns all non-consuming arcs (including epsilon).
1391       multi_eps_iter_ = multi_eps_labels_.Begin();
1392       while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1393              !matcher_->Find(*multi_eps_iter_)) {
1394         ++multi_eps_iter_;
1395       }
1396       if (multi_eps_iter_ != multi_eps_labels_.End()) {
1397         ret = true;
1398       } else {
1399         ret = matcher_->Find(kNoLabel);
1400       }
1401     } else {
1402       // Returns all epsilon arcs.
1403       ret = matcher_->Find(kNoLabel);
1404     }
1405   } else if ((flags_ & kMultiEpsLoop) &&
1406              multi_eps_labels_.Find(label) != multi_eps_labels_.End()) {
1407     // Returns implicit loop.
1408     current_loop_ = true;
1409     ret = true;
1410   } else {
1411     ret = matcher_->Find(label);
1412   }
1413   done_ = !ret;
1414   return ret;
1415 }
1416 
1417 // This class discards any implicit matches (e.g., the implicit epsilon
1418 // self-loops in the SortedMatcher). Matchers are most often used in
1419 // composition/intersection where the implicit matches are needed
1420 // e.g. for epsilon processing. However, if a matcher is simply being
1421 // used to look-up explicit label matches, this class saves the user
1422 // from having to check for and discard the unwanted implicit matches
1423 // themselves.
1424 template <class M>
1425 class ExplicitMatcher : public MatcherBase<typename M::Arc> {
1426  public:
1427   using FST = typename M::FST;
1428   using Arc = typename FST::Arc;
1429   using Label = typename Arc::Label;
1430   using StateId = typename Arc::StateId;
1431   using Weight = typename Arc::Weight;
1432 
1433   // This makes a copy of the FST.
1434   ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr)
1435       : matcher_(matcher ? matcher : new M(fst, match_type)),
1436         match_type_(match_type),
1437         error_(false) {}
1438 
1439   // This doesn't copy the FST.
1440   ExplicitMatcher(const FST *fst, MatchType match_type, M *matcher = nullptr)
1441       : matcher_(matcher ? matcher : new M(fst, match_type)),
1442         match_type_(match_type),
1443         error_(false) {}
1444 
1445   // This makes a copy of the FST.
1446   ExplicitMatcher(const ExplicitMatcher &matcher, bool safe = false)
1447       : matcher_(new M(*matcher.matcher_, safe)),
1448         match_type_(matcher.match_type_),
1449         error_(matcher.error_) {}
1450 
1451   ExplicitMatcher *Copy(bool safe = false) const override {
1452     return new ExplicitMatcher(*this, safe);
1453   }
1454 
Type(bool test)1455   MatchType Type(bool test) const override { return matcher_->Type(test); }
1456 
SetState(StateId s)1457   void SetState(StateId s) final { matcher_->SetState(s); }
1458 
Find(Label label)1459   bool Find(Label label) final {
1460     matcher_->Find(label);
1461     CheckArc();
1462     return !Done();
1463   }
1464 
Done()1465   bool Done() const final { return matcher_->Done(); }
1466 
Value()1467   const Arc &Value() const final { return matcher_->Value(); }
1468 
Next()1469   void Next() final {
1470     matcher_->Next();
1471     CheckArc();
1472   }
1473 
Final(StateId s)1474   Weight Final(StateId s) const final { return matcher_->Final(s); }
1475 
Priority(StateId s)1476   ssize_t Priority(StateId s) final { return matcher_->Priority(s); }
1477 
GetFst()1478   const FST &GetFst() const final { return matcher_->GetFst(); }
1479 
Properties(uint64 inprops)1480   uint64 Properties(uint64 inprops) const override {
1481     return matcher_->Properties(inprops);
1482   }
1483 
GetMatcher()1484   const M *GetMatcher() const { return matcher_.get(); }
1485 
Flags()1486   uint32 Flags() const override { return matcher_->Flags(); }
1487 
1488  private:
1489   // Checks current arc if available and explicit. If not available, stops. If
1490   // not explicit, checks next ones.
CheckArc()1491   void CheckArc() {
1492     for (; !matcher_->Done(); matcher_->Next()) {
1493       const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel
1494                                                     : matcher_->Value().olabel;
1495       if (label != kNoLabel) return;
1496     }
1497   }
1498 
1499   std::unique_ptr<M> matcher_;
1500   MatchType match_type_;  // Type of match requested.
1501   bool error_;            // Error encountered?
1502 };
1503 
1504 // Generic matcher, templated on the FST definition.
1505 //
1506 // Here is a typical use:
1507 //
1508 //   Matcher<StdFst> matcher(fst, MATCH_INPUT);
1509 //   matcher.SetState(state);
1510 //   if (matcher.Find(label))
1511 //     for (; !matcher.Done(); matcher.Next()) {
1512 //       auto &arc = matcher.Value();
1513 //       ...
1514 //     }
1515 template <class F>
1516 class Matcher {
1517  public:
1518   using FST = F;
1519   using Arc = typename F::Arc;
1520   using Label = typename Arc::Label;
1521   using StateId = typename Arc::StateId;
1522   using Weight = typename Arc::Weight;
1523 
1524   // This makes a copy of the FST.
Matcher(const FST & fst,MatchType match_type)1525   Matcher(const FST &fst, MatchType match_type)
1526       : owned_fst_(fst.Copy()), base_(owned_fst_->InitMatcher(match_type)) {
1527     if (!base_)
1528       base_ =
1529           std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
1530   }
1531 
1532   // This doesn't copy the FST.
Matcher(const FST * fst,MatchType match_type)1533   Matcher(const FST *fst, MatchType match_type)
1534       : base_(fst->InitMatcher(match_type)) {
1535     if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
1536   }
1537 
1538   // This makes a copy of the FST.
1539   Matcher(const Matcher &matcher, bool safe = false)
1540       : base_(matcher.base_->Copy(safe)) {}
1541 
1542   // Takes ownership of the provided matcher.
Matcher(MatcherBase<Arc> * base_matcher)1543   explicit Matcher(MatcherBase<Arc> *base_matcher) : base_(base_matcher) {}
1544 
1545   Matcher *Copy(bool safe = false) const { return new Matcher(*this, safe); }
1546 
Type(bool test)1547   MatchType Type(bool test) const { return base_->Type(test); }
1548 
SetState(StateId s)1549   void SetState(StateId s) { base_->SetState(s); }
1550 
Find(Label label)1551   bool Find(Label label) { return base_->Find(label); }
1552 
Done()1553   bool Done() const { return base_->Done(); }
1554 
Value()1555   const Arc &Value() const { return base_->Value(); }
1556 
Next()1557   void Next() { base_->Next(); }
1558 
GetFst()1559   const FST &GetFst() const { return fst::down_cast<const FST &>(base_->GetFst()); }
1560 
Properties(uint64 props)1561   uint64 Properties(uint64 props) const { return base_->Properties(props); }
1562 
Final(StateId s)1563   Weight Final(StateId s) const { return base_->Final(s); }
1564 
Flags()1565   uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1566 
Priority(StateId s)1567   ssize_t Priority(StateId s) { return base_->Priority(s); }
1568 
1569  private:
1570   std::unique_ptr<const FST> owned_fst_;
1571   std::unique_ptr<MatcherBase<Arc>> base_;
1572 };
1573 
1574 }  // namespace fst
1575 
1576 #endif  // FST_MATCHER_H_
1577