1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Classes to allow matching labels leaving FST states.
19
20 #ifndef FST_MATCHER_H_
21 #define FST_MATCHER_H_
22
23 #include <algorithm>
24 #include <memory>
25 #include <utility>
26
27 #include <fst/types.h>
28 #include <fst/log.h>
29
30 #include <fst/mutable-fst.h> // for all internal FST accessors.
31
32 #include <unordered_map>
33
34 namespace fst {
35
36 // Matchers find and iterate through requested labels at FST states. In the
37 // simplest form, these are just some associative map or search keyed on labels.
38 // More generally, they may implement matching special labels that represent
39 // sets of labels such as sigma (all), rho (rest), or phi (fail). The Matcher
40 // interface is:
41 //
42 // template <class F>
43 // class Matcher {
44 // public:
45 // using FST = F;
46 // using Arc = typename FST::Arc;
47 // using Label = typename Arc::Label;
48 // using StateId = typename Arc::StateId;
49 // using Weight = typename Arc::Weight;
50 //
51 // // Required constructors. Note:
52 // // -- the constructors that copy the FST arg are useful for
53 // // letting the matcher manage the FST through copies
54 // // (esp with 'safe' copies); e.g. ComposeFst depends on this.
55 // // -- the constructor that does not copy is useful when the
//   //     FST is mutated during the lifetime of the matcher
57 // // (o.w. the matcher would have its own unmutated deep copy).
58 //
59 // // This makes a copy of the FST.
60 // Matcher(const FST &fst, MatchType type);
61 // // This doesn't copy the FST.
62 // Matcher(const FST *fst, MatchType type);
63 // // This makes a copy of the FST.
64 // // See Copy() below.
65 // Matcher(const Matcher &matcher, bool safe = false);
66 //
67 // // If safe = true, the copy is thread-safe. See Fst<>::Copy() for
68 // // further doc.
69 // Matcher *Copy(bool safe = false) const override;
70 //
71 // // Returns the match type that can be provided (depending on compatibility
72 // // of the input FST). It is either the requested match type, MATCH_NONE,
73 // // or MATCH_UNKNOWN. If test is false, a costly testing is avoided, but
74 // // MATCH_UNKNOWN may be returned. If test is true, a definite answer is
75 // // returned, but may involve more costly computation (e.g., visiting
76 // // the FST).
//   MatchType Type(bool test) const override;
78 //
79 // // Specifies the current state.
80 // void SetState(StateId s) final;
81 //
82 // // Finds matches to a label at the current state, returning true if a match
83 // // found. kNoLabel matches any non-consuming transitions, e.g., epsilon
84 // // transitions, which do not require a matching symbol.
85 // bool Find(Label label) final;
86 //
87 // // Iterator methods. Note that initially and after SetState() these have
88 // // undefined behavior until Find() is called.
89 //
90 // bool Done() const final;
91 //
92 // const Arc &Value() const final;
93 //
94 // void Next() final;
95 //
96 // // Returns final weight of a state.
97 // Weight Final(StateId) const final;
98 //
99 // // Indicates preference for being the side used for matching in
100 // // composition. If the value is kRequirePriority, then it is
101 // // mandatory that it be used. Calling this method without passing the
102 // // current state of the matcher invalidates the state of the matcher.
103 // ssize_t Priority(StateId s) final;
104 //
105 // // This specifies the known FST properties as viewed from this matcher. It
106 // // takes as argument the input FST's known properties.
107 // uint64 Properties(uint64 props) const override;
108 //
109 // // Returns matcher flags.
110 // uint32 Flags() const override;
111 //
112 // // Returns matcher FST.
113 // const FST &GetFst() const override;
114 // };
115
// Basic matcher flags.

// Flag bit: the matcher needs to be used as the matching side in composition
// for at least one state (i.e., it has kRequirePriority for some state).
constexpr uint32 kRequireMatch = 0x00000001;

// Mask of the flags used for basic matchers (see also lookahead.h). Currently
// this is just kRequireMatch.
constexpr uint32 kMatcherFlags = kRequireMatch;

// Matcher priority value indicating that use of the matcher is mandatory.
constexpr ssize_t kRequirePriority = -1;
127
128 // Matcher interface, templated on the Arc definition; used for matcher
129 // specializations that are returned by the InitMatcher FST method.
130 template <class A>
131 class MatcherBase {
132 public:
133 using Arc = A;
134 using Label = typename Arc::Label;
135 using StateId = typename Arc::StateId;
136 using Weight = typename Arc::Weight;
137
~MatcherBase()138 virtual ~MatcherBase() {}
139
140 // Virtual interface.
141
142 virtual MatcherBase *Copy(bool safe = false) const = 0;
143 virtual MatchType Type(bool) const = 0;
144 virtual void SetState(StateId) = 0;
145 virtual bool Find(Label) = 0;
146 virtual bool Done() const = 0;
147 virtual const Arc &Value() const = 0;
148 virtual void Next() = 0;
149 virtual const Fst<Arc> &GetFst() const = 0;
150 virtual uint64 Properties(uint64) const = 0;
151
152 // Trivial implementations that can be used by derived classes. Full
153 // devirtualization is expected for any derived class marked final.
Flags()154 virtual uint32 Flags() const { return 0; }
155
Final(StateId s)156 virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); }
157
Priority(StateId s)158 virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); }
159 };
160
161 // A matcher that expects sorted labels on the side to be matched.
162 // If match_type == MATCH_INPUT, epsilons match the implicit self-loop
163 // Arc(kNoLabel, 0, Weight::One(), current_state) as well as any
164 // actual epsilon transitions. If match_type == MATCH_OUTPUT, then
165 // Arc(0, kNoLabel, Weight::One(), current_state) is instead matched.
166 template <class F>
167 class SortedMatcher : public MatcherBase<typename F::Arc> {
168 public:
169 using FST = F;
170 using Arc = typename FST::Arc;
171 using Label = typename Arc::Label;
172 using StateId = typename Arc::StateId;
173 using Weight = typename Arc::Weight;
174
175 using MatcherBase<Arc>::Flags;
176 using MatcherBase<Arc>::Properties;
177
178 // Labels >= binary_label will be searched for by binary search;
179 // o.w. linear search is used.
180 // This makes a copy of the FST.
181 SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1)
182 : SortedMatcher(fst.Copy(), match_type, binary_label) {
183 owned_fst_.reset(&fst_);
184 }
185
186 // Labels >= binary_label will be searched for by binary search;
187 // o.w. linear search is used.
188 // This doesn't copy the FST.
189 SortedMatcher(const FST *fst, MatchType match_type, Label binary_label = 1)
190 : fst_(*fst),
191 state_(kNoStateId),
192 aiter_(nullptr),
193 match_type_(match_type),
194 binary_label_(binary_label),
195 match_label_(kNoLabel),
196 narcs_(0),
197 loop_(kNoLabel, 0, Weight::One(), kNoStateId),
198 error_(false),
199 aiter_pool_(1) {
200 switch (match_type_) {
201 case MATCH_INPUT:
202 case MATCH_NONE:
203 break;
204 case MATCH_OUTPUT:
205 std::swap(loop_.ilabel, loop_.olabel);
206 break;
207 default:
208 FSTERROR() << "SortedMatcher: Bad match type";
209 match_type_ = MATCH_NONE;
210 error_ = true;
211 }
212 }
213
214 // This makes a copy of the FST.
215 SortedMatcher(const SortedMatcher &matcher, bool safe = false)
216 : owned_fst_(matcher.fst_.Copy(safe)),
217 fst_(*owned_fst_),
218 state_(kNoStateId),
219 aiter_(nullptr),
220 match_type_(matcher.match_type_),
221 binary_label_(matcher.binary_label_),
222 match_label_(kNoLabel),
223 narcs_(0),
224 loop_(matcher.loop_),
225 error_(matcher.error_),
226 aiter_pool_(1) {}
227
~SortedMatcher()228 ~SortedMatcher() override { Destroy(aiter_, &aiter_pool_); }
229
230 SortedMatcher *Copy(bool safe = false) const override {
231 return new SortedMatcher(*this, safe);
232 }
233
Type(bool test)234 MatchType Type(bool test) const override {
235 if (match_type_ == MATCH_NONE) return match_type_;
236 const auto true_prop =
237 match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted;
238 const auto false_prop =
239 match_type_ == MATCH_INPUT ? kNotILabelSorted : kNotOLabelSorted;
240 const auto props = fst_.Properties(true_prop | false_prop, test);
241 if (props & true_prop) {
242 return match_type_;
243 } else if (props & false_prop) {
244 return MATCH_NONE;
245 } else {
246 return MATCH_UNKNOWN;
247 }
248 }
249
SetState(StateId s)250 void SetState(StateId s) final {
251 if (state_ == s) return;
252 state_ = s;
253 if (match_type_ == MATCH_NONE) {
254 FSTERROR() << "SortedMatcher: Bad match type";
255 error_ = true;
256 }
257 Destroy(aiter_, &aiter_pool_);
258 aiter_ = new (&aiter_pool_) ArcIterator<FST>(fst_, s);
259 aiter_->SetFlags(kArcNoCache, kArcNoCache);
260 narcs_ = internal::NumArcs(fst_, s);
261 loop_.nextstate = s;
262 }
263
Find(Label match_label)264 bool Find(Label match_label) final {
265 exact_match_ = true;
266 if (error_) {
267 current_loop_ = false;
268 match_label_ = kNoLabel;
269 return false;
270 }
271 current_loop_ = match_label == 0;
272 match_label_ = match_label == kNoLabel ? 0 : match_label;
273 if (Search()) {
274 return true;
275 } else {
276 return current_loop_;
277 }
278 }
279
280 // Positions matcher to the first position where inserting match_label would
281 // maintain the sort order.
LowerBound(Label label)282 void LowerBound(Label label) {
283 exact_match_ = false;
284 current_loop_ = false;
285 if (error_) {
286 match_label_ = kNoLabel;
287 return;
288 }
289 match_label_ = label;
290 Search();
291 }
292
293 // After Find(), returns false if no more exact matches.
294 // After LowerBound(), returns false if no more arcs.
Done()295 bool Done() const final {
296 if (current_loop_) return false;
297 if (aiter_->Done()) return true;
298 if (!exact_match_) return false;
299 aiter_->SetFlags(
300 match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
301 kArcValueFlags);
302 return GetLabel() != match_label_;
303 }
304
Value()305 const Arc &Value() const final {
306 if (current_loop_) return loop_;
307 aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
308 return aiter_->Value();
309 }
310
Next()311 void Next() final {
312 if (current_loop_) {
313 current_loop_ = false;
314 } else {
315 aiter_->Next();
316 }
317 }
318
Final(StateId s)319 Weight Final(StateId s) const final { return MatcherBase<Arc>::Final(s); }
320
Priority(StateId s)321 ssize_t Priority(StateId s) final { return MatcherBase<Arc>::Priority(s); }
322
GetFst()323 const FST &GetFst() const override { return fst_; }
324
Properties(uint64 inprops)325 uint64 Properties(uint64 inprops) const override {
326 return inprops | (error_ ? kError : 0);
327 }
328
Position()329 size_t Position() const { return aiter_ ? aiter_->Position() : 0; }
330
331 private:
GetLabel()332 Label GetLabel() const {
333 const auto &arc = aiter_->Value();
334 return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
335 }
336
337 bool BinarySearch();
338 bool LinearSearch();
339 bool Search();
340
341 std::unique_ptr<const FST> owned_fst_; // FST ptr if owned.
342 const FST &fst_; // FST for matching.
343 StateId state_; // Matcher state.
344 ArcIterator<FST> *aiter_; // Iterator for current state.
345 MatchType match_type_; // Type of match to perform.
346 Label binary_label_; // Least label for binary search.
347 Label match_label_; // Current label to be matched.
348 size_t narcs_; // Current state arc count.
349 Arc loop_; // For non-consuming symbols.
350 bool current_loop_; // Current arc is the implicit loop.
351 bool exact_match_; // Exact match or lower bound?
352 bool error_; // Error encountered?
353 MemoryPool<ArcIterator<FST>> aiter_pool_; // Pool of arc iterators.
354 };
355
356 // Returns true iff match to match_label_. The arc iterator is positioned at the
357 // lower bound, that is, the first element greater than or equal to
358 // match_label_, or the end if all elements are less than match_label_.
359 // If multiple elements are equal to the `match_label_`, returns the rightmost
360 // one.
361 template <class FST>
BinarySearch()362 inline bool SortedMatcher<FST>::BinarySearch() {
363 size_t size = narcs_;
364 if (size == 0) {
365 return false;
366 }
367 size_t high = size - 1;
368 while (size > 1) {
369 const size_t half = size / 2;
370 const size_t mid = high - half;
371 aiter_->Seek(mid);
372 if (GetLabel() >= match_label_) {
373 high = mid;
374 }
375 size -= half;
376 }
377 aiter_->Seek(high);
378 const auto label = GetLabel();
379 if (label == match_label_) {
380 return true;
381 }
382 if (label < match_label_) {
383 aiter_->Next();
384 }
385 return false;
386 }
387
388 // Returns true iff match to match_label_, positioning arc iterator at lower
389 // bound.
390 template <class FST>
LinearSearch()391 inline bool SortedMatcher<FST>::LinearSearch() {
392 for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) {
393 const auto label = GetLabel();
394 if (label == match_label_) return true;
395 if (label > match_label_) break;
396 }
397 return false;
398 }
399
400 // Returns true iff match to match_label_, positioning arc iterator at lower
401 // bound.
402 template <class FST>
Search()403 inline bool SortedMatcher<FST>::Search() {
404 aiter_->SetFlags(
405 match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue,
406 kArcValueFlags);
407 if (match_label_ >= binary_label_) {
408 return BinarySearch();
409 } else {
410 return LinearSearch();
411 }
412 }
413
414 // A matcher that stores labels in a per-state hash table populated upon the
415 // first visit to that state. Sorting is not required. Treatment of
416 // epsilons are the same as with SortedMatcher.
417 template <class F>
418 class HashMatcher : public MatcherBase<typename F::Arc> {
419 public:
420 using FST = F;
421 using Arc = typename FST::Arc;
422 using Label = typename Arc::Label;
423 using StateId = typename Arc::StateId;
424 using Weight = typename Arc::Weight;
425
426 using MatcherBase<Arc>::Flags;
427 using MatcherBase<Arc>::Final;
428 using MatcherBase<Arc>::Priority;
429
430 // This makes a copy of the FST.
HashMatcher(const FST & fst,MatchType match_type)431 HashMatcher(const FST &fst, MatchType match_type)
432 : HashMatcher(fst.Copy(), match_type) {
433 owned_fst_.reset(&fst_);
434 }
435
436 // This doesn't copy the FST.
HashMatcher(const FST * fst,MatchType match_type)437 HashMatcher(const FST *fst, MatchType match_type)
438 : fst_(*fst),
439 state_(kNoStateId),
440 match_type_(match_type),
441 loop_(kNoLabel, 0, Weight::One(), kNoStateId),
442 error_(false),
443 state_table_(std::make_shared<StateTable>()) {
444 switch (match_type_) {
445 case MATCH_INPUT:
446 case MATCH_NONE:
447 break;
448 case MATCH_OUTPUT:
449 std::swap(loop_.ilabel, loop_.olabel);
450 break;
451 default:
452 FSTERROR() << "HashMatcher: Bad match type";
453 match_type_ = MATCH_NONE;
454 error_ = true;
455 }
456 }
457
458 // This makes a copy of the FST.
459 HashMatcher(const HashMatcher &matcher, bool safe = false)
460 : owned_fst_(matcher.fst_.Copy(safe)),
461 fst_(*owned_fst_),
462 state_(kNoStateId),
463 match_type_(matcher.match_type_),
464 loop_(matcher.loop_),
465 error_(matcher.error_),
466 state_table_(safe ? std::make_shared<StateTable>()
467 : matcher.state_table_) {}
468
469 HashMatcher *Copy(bool safe = false) const override {
470 return new HashMatcher(*this, safe);
471 }
472
473 // The argument is ignored as there are no relevant properties to test.
Type(bool test)474 MatchType Type(bool test) const override { return match_type_; }
475
476 void SetState(StateId s) final;
477
Find(Label label)478 bool Find(Label label) final {
479 current_loop_ = label == 0;
480 if (label == 0) {
481 Search(label);
482 return true;
483 }
484 if (label == kNoLabel) label = 0;
485 return Search(label);
486 }
487
Done()488 bool Done() const final {
489 if (current_loop_) return false;
490 return label_it_ == label_end_;
491 }
492
Value()493 const Arc &Value() const final {
494 if (current_loop_) return loop_;
495 aiter_->Seek(label_it_->second);
496 return aiter_->Value();
497 }
498
Next()499 void Next() final {
500 if (current_loop_) {
501 current_loop_ = false;
502 } else {
503 ++label_it_;
504 }
505 }
506
GetFst()507 const FST &GetFst() const override { return fst_; }
508
Properties(uint64 inprops)509 uint64 Properties(uint64 inprops) const override {
510 return inprops | (error_ ? kError : 0);
511 }
512
513 private:
GetLabel()514 Label GetLabel() const {
515 const auto &arc = aiter_->Value();
516 return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel;
517 }
518
519 bool Search(Label match_label);
520
521 using LabelTable = std::unordered_multimap<Label, size_t>;
522 using StateTable = std::unordered_map<StateId, std::unique_ptr<LabelTable>>;
523
524 std::unique_ptr<const FST> owned_fst_; // ptr to FST if owned.
525 const FST &fst_; // FST for matching.
526 StateId state_; // Matcher state.
527 MatchType match_type_;
528 Arc loop_; // The implicit loop itself.
529 bool current_loop_; // Is the current arc the implicit loop?
530 bool error_; // Error encountered?
531 std::unique_ptr<ArcIterator<FST>> aiter_;
532 std::shared_ptr<StateTable> state_table_; // Table from state to label table.
533 LabelTable *label_table_; // Pointer to current state's label table.
534 typename LabelTable::iterator label_it_; // Position for label.
535 typename LabelTable::iterator label_end_; // Position for last label + 1.
536 };
537
538 template <class FST>
SetState(typename FST::Arc::StateId s)539 void HashMatcher<FST>::SetState(typename FST::Arc::StateId s) {
540 if (state_ == s) return;
541 // Resets everything for the state.
542 state_ = s;
543 loop_.nextstate = state_;
544 aiter_ = std::make_unique<ArcIterator<FST>>(fst_, state_);
545 if (match_type_ == MATCH_NONE) {
546 FSTERROR() << "HashMatcher: Bad match type";
547 error_ = true;
548 }
549 // Attempts to insert a new label table.
550 auto it_and_success = state_table_->emplace(
551 state_, std::make_unique<LabelTable>());
552 // Sets instance's pointer to the label table for this state.
553 label_table_ = it_and_success.first->second.get();
554 // If it already exists, no additional work is done and we simply return.
555 if (!it_and_success.second) return;
556 // Otherwise, populate this new table.
557 // Populates the label table.
558 label_table_->reserve(internal::NumArcs(fst_, state_));
559 const auto aiter_flags =
560 (match_type_ == MATCH_INPUT ? kArcILabelValue : kArcOLabelValue) |
561 kArcNoCache;
562 aiter_->SetFlags(aiter_flags, kArcFlags);
563 for (; !aiter_->Done(); aiter_->Next()) {
564 label_table_->emplace(GetLabel(), aiter_->Position());
565 }
566 aiter_->SetFlags(kArcValueFlags, kArcValueFlags);
567 }
568
569 template <class FST>
Search(typename FST::Arc::Label match_label)570 inline bool HashMatcher<FST>::Search(typename FST::Arc::Label match_label) {
571 auto range = label_table_->equal_range(match_label);
572 label_it_ = range.first;
573 label_end_ = range.second;
574 if (label_it_ == label_end_) return false;
575 aiter_->Seek(label_it_->second);
576 return true;
577 }
578
// Specifies whether we rewrite both the input and output sides during matching.
enum MatcherRewriteMode {
  MATCHER_REWRITE_AUTO = 0,    // Rewrites both sides iff acceptor.
  MATCHER_REWRITE_ALWAYS = 1,  // Rewrites both sides unconditionally.
  MATCHER_REWRITE_NEVER = 2    // Rewrites the matched side only.
};
585
586 // For any requested label that doesn't match at a state, this matcher
587 // considers the *unique* transition that matches the label 'phi_label'
588 // (phi = 'fail'), and recursively looks for a match at its
589 // destination. When 'phi_loop' is true, if no match is found but a
590 // phi self-loop is found, then the phi transition found is returned
591 // with the phi_label rewritten as the requested label (both sides if
592 // an acceptor, or if 'rewrite_both' is true and both input and output
593 // labels of the found transition are 'phi_label'). If 'phi_label' is
594 // kNoLabel, this special matching is not done. PhiMatcher is
595 // templated itself on a matcher, which is used to perform the
596 // underlying matching. By default, the underlying matcher is
597 // constructed by PhiMatcher. The user can instead pass in this
598 // object; in that case, PhiMatcher takes its ownership.
599 // Phi non-determinism not supported. No non-consuming symbols other
600 // than epsilon supported with the underlying template argument matcher.
601 template <class M>
602 class PhiMatcher : public MatcherBase<typename M::Arc> {
603 public:
604 using FST = typename M::FST;
605 using Arc = typename FST::Arc;
606 using Label = typename Arc::Label;
607 using StateId = typename Arc::StateId;
608 using Weight = typename Arc::Weight;
609
610 // This makes a copy of the FST (w/o 'matcher' arg).
611 PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel,
612 bool phi_loop = true,
613 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
614 M *matcher = nullptr)
615 : matcher_(matcher ? matcher : new M(fst, match_type)),
616 match_type_(match_type),
617 phi_label_(phi_label),
618 state_(kNoStateId),
619 phi_loop_(phi_loop),
620 error_(false) {
621 if (match_type == MATCH_BOTH) {
622 FSTERROR() << "PhiMatcher: Bad match type";
623 match_type_ = MATCH_NONE;
624 error_ = true;
625 }
626 if (rewrite_mode == MATCHER_REWRITE_AUTO) {
627 rewrite_both_ = fst.Properties(kAcceptor, true);
628 } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
629 rewrite_both_ = true;
630 } else {
631 rewrite_both_ = false;
632 }
633 }
634
635 // This doesn't copy the FST.
636 PhiMatcher(const FST *fst, MatchType match_type, Label phi_label = kNoLabel,
637 bool phi_loop = true,
638 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
639 M *matcher = nullptr)
640 : PhiMatcher(*fst, match_type, phi_label, phi_loop, rewrite_mode,
641 matcher ? matcher : new M(fst, match_type)) {}
642
643 // This makes a copy of the FST.
644 PhiMatcher(const PhiMatcher &matcher, bool safe = false)
645 : matcher_(new M(*matcher.matcher_, safe)),
646 match_type_(matcher.match_type_),
647 phi_label_(matcher.phi_label_),
648 rewrite_both_(matcher.rewrite_both_),
649 state_(kNoStateId),
650 phi_loop_(matcher.phi_loop_),
651 error_(matcher.error_) {}
652
653 PhiMatcher *Copy(bool safe = false) const override {
654 return new PhiMatcher(*this, safe);
655 }
656
Type(bool test)657 MatchType Type(bool test) const override { return matcher_->Type(test); }
658
SetState(StateId s)659 void SetState(StateId s) final {
660 if (state_ == s) return;
661 matcher_->SetState(s);
662 state_ = s;
663 has_phi_ = phi_label_ != kNoLabel;
664 }
665
666 bool Find(Label match_label) final;
667
Done()668 bool Done() const final { return matcher_->Done(); }
669
Value()670 const Arc &Value() const final {
671 if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) {
672 return matcher_->Value();
673 } else if (phi_match_ == 0) { // Virtual epsilon loop.
674 phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_);
675 if (match_type_ == MATCH_OUTPUT) {
676 std::swap(phi_arc_.ilabel, phi_arc_.olabel);
677 }
678 return phi_arc_;
679 } else {
680 phi_arc_ = matcher_->Value();
681 phi_arc_.weight = Times(phi_weight_, phi_arc_.weight);
682 if (phi_match_ != kNoLabel) { // Phi loop match.
683 if (rewrite_both_) {
684 if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_;
685 if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_;
686 } else if (match_type_ == MATCH_INPUT) {
687 phi_arc_.ilabel = phi_match_;
688 } else {
689 phi_arc_.olabel = phi_match_;
690 }
691 }
692 return phi_arc_;
693 }
694 }
695
Next()696 void Next() final { matcher_->Next(); }
697
Final(StateId s)698 Weight Final(StateId s) const final {
699 auto weight = matcher_->Final(s);
700 if (phi_label_ == kNoLabel || weight != Weight::Zero()) {
701 return weight;
702 }
703 weight = Weight::One();
704 matcher_->SetState(s);
705 while (matcher_->Final(s) == Weight::Zero()) {
706 if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) break;
707 weight = Times(weight, matcher_->Value().weight);
708 if (s == matcher_->Value().nextstate) {
709 return Weight::Zero(); // Does not follow phi self-loops.
710 }
711 s = matcher_->Value().nextstate;
712 matcher_->SetState(s);
713 }
714 weight = Times(weight, matcher_->Final(s));
715 return weight;
716 }
717
Priority(StateId s)718 ssize_t Priority(StateId s) final {
719 if (phi_label_ != kNoLabel) {
720 matcher_->SetState(s);
721 const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_);
722 return has_phi ? kRequirePriority : matcher_->Priority(s);
723 } else {
724 return matcher_->Priority(s);
725 }
726 }
727
GetFst()728 const FST &GetFst() const override { return matcher_->GetFst(); }
729
730 uint64 Properties(uint64 props) const override;
731
Flags()732 uint32 Flags() const override {
733 if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) {
734 return matcher_->Flags();
735 }
736 return matcher_->Flags() | kRequireMatch;
737 }
738
PhiLabel()739 Label PhiLabel() const { return phi_label_; }
740
741 private:
742 mutable std::unique_ptr<M> matcher_;
743 MatchType match_type_; // Type of match requested.
744 Label phi_label_; // Label that represents the phi transition.
745 bool rewrite_both_; // Rewrite both sides when both are phi_label_?
746 bool has_phi_; // Are there possibly phis at the current state?
747 Label phi_match_; // Current label that matches phi loop.
748 mutable Arc phi_arc_; // Arc to return.
749 StateId state_; // Matcher state.
750 Weight phi_weight_; // Product of the weights of phi transitions taken.
751 bool phi_loop_; // When true, phi self-loop are allowed and treated
752 // as rho (required for Aho-Corasick).
753 bool error_; // Error encountered?
754
755 PhiMatcher &operator=(const PhiMatcher &) = delete;
756 };
757
758 template <class M>
Find(Label label)759 inline bool PhiMatcher<M>::Find(Label label) {
760 if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) {
761 FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_;
762 error_ = true;
763 return false;
764 }
765 matcher_->SetState(state_);
766 phi_match_ = kNoLabel;
767 phi_weight_ = Weight::One();
768 // If phi_label_ == 0, there are no more true epsilon arcs.
769 if (phi_label_ == 0) {
770 if (label == kNoLabel) {
771 return false;
772 }
773 if (label == 0) { // but a virtual epsilon loop needs to be returned.
774 if (!matcher_->Find(kNoLabel)) {
775 return matcher_->Find(0);
776 } else {
777 phi_match_ = 0;
778 return true;
779 }
780 }
781 }
782 if (!has_phi_ || label == 0 || label == kNoLabel) {
783 return matcher_->Find(label);
784 }
785 auto s = state_;
786 while (!matcher_->Find(label)) {
787 // Look for phi transition (if phi_label_ == 0, we need to look
788 // for -1 to avoid getting the virtual self-loop)
789 if (!matcher_->Find(phi_label_ == 0 ? -1 : phi_label_)) return false;
790 if (phi_loop_ && matcher_->Value().nextstate == s) {
791 phi_match_ = label;
792 return true;
793 }
794 phi_weight_ = Times(phi_weight_, matcher_->Value().weight);
795 s = matcher_->Value().nextstate;
796 matcher_->Next();
797 if (!matcher_->Done()) {
798 FSTERROR() << "PhiMatcher: Phi non-determinism not supported";
799 error_ = true;
800 }
801 matcher_->SetState(s);
802 }
803 return true;
804 }
805
806 template <class M>
Properties(uint64 inprops)807 inline uint64 PhiMatcher<M>::Properties(uint64 inprops) const {
808 auto outprops = matcher_->Properties(inprops);
809 if (error_) outprops |= kError;
810 if (match_type_ == MATCH_NONE) {
811 return outprops;
812 } else if (match_type_ == MATCH_INPUT) {
813 if (phi_label_ == 0) {
814 outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
815 outprops |= kNoEpsilons | kNoIEpsilons;
816 }
817 if (rewrite_both_) {
818 return outprops &
819 ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
820 kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
821 } else {
822 return outprops &
823 ~(kODeterministic | kAcceptor | kString | kILabelSorted |
824 kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
825 }
826 } else if (match_type_ == MATCH_OUTPUT) {
827 if (phi_label_ == 0) {
828 outprops &= ~(kEpsilons | kIEpsilons | kOEpsilons);
829 outprops |= kNoEpsilons | kNoOEpsilons;
830 }
831 if (rewrite_both_) {
832 return outprops &
833 ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
834 kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
835 } else {
836 return outprops &
837 ~(kIDeterministic | kAcceptor | kString | kILabelSorted |
838 kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
839 }
840 } else {
841 // Shouldn't ever get here.
842 FSTERROR() << "PhiMatcher: Bad match type: " << match_type_;
843 return 0;
844 }
845 }
846
847 // For any requested label that doesn't match at a state, this matcher
848 // considers all transitions that match the label 'rho_label' (rho =
849 // 'rest'). Each such rho transition found is returned with the
850 // rho_label rewritten as the requested label (both sides if an
851 // acceptor, or if 'rewrite_both' is true and both input and output
852 // labels of the found transition are 'rho_label'). If 'rho_label' is
853 // kNoLabel, this special matching is not done. RhoMatcher is
854 // templated itself on a matcher, which is used to perform the
855 // underlying matching. By default, the underlying matcher is
856 // constructed by RhoMatcher. The user can instead pass in this
857 // object; in that case, RhoMatcher takes its ownership.
858 // No non-consuming symbols other than epsilon supported with
859 // the underlying template argument matcher.
860 template <class M>
861 class RhoMatcher : public MatcherBase<typename M::Arc> {
862 public:
863 using FST = typename M::FST;
864 using Arc = typename FST::Arc;
865 using Label = typename Arc::Label;
866 using StateId = typename Arc::StateId;
867 using Weight = typename Arc::Weight;
868
869 // This makes a copy of the FST (w/o 'matcher' arg).
870 RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel,
871 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
872 M *matcher = nullptr)
873 : matcher_(matcher ? matcher : new M(fst, match_type)),
874 match_type_(match_type),
875 rho_label_(rho_label),
876 error_(false),
877 state_(kNoStateId),
878 has_rho_(false) {
879 if (match_type == MATCH_BOTH) {
880 FSTERROR() << "RhoMatcher: Bad match type";
881 match_type_ = MATCH_NONE;
882 error_ = true;
883 }
884 if (rho_label == 0) {
885 FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label";
886 rho_label_ = kNoLabel;
887 error_ = true;
888 }
889 if (rewrite_mode == MATCHER_REWRITE_AUTO) {
890 rewrite_both_ = fst.Properties(kAcceptor, true);
891 } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
892 rewrite_both_ = true;
893 } else {
894 rewrite_both_ = false;
895 }
896 }
897
898 // This doesn't copy the FST.
899 RhoMatcher(const FST *fst, MatchType match_type, Label rho_label = kNoLabel,
900 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
901 M *matcher = nullptr)
902 : RhoMatcher(*fst, match_type, rho_label, rewrite_mode,
903 matcher ? matcher : new M(fst, match_type)) {}
904
905 // This makes a copy of the FST.
906 RhoMatcher(const RhoMatcher &matcher, bool safe = false)
907 : matcher_(new M(*matcher.matcher_, safe)),
908 match_type_(matcher.match_type_),
909 rho_label_(matcher.rho_label_),
910 rewrite_both_(matcher.rewrite_both_),
911 error_(matcher.error_),
912 state_(kNoStateId),
913 has_rho_(false) {}
914
915 RhoMatcher *Copy(bool safe = false) const override {
916 return new RhoMatcher(*this, safe);
917 }
918
Type(bool test)919 MatchType Type(bool test) const override { return matcher_->Type(test); }
920
SetState(StateId s)921 void SetState(StateId s) final {
922 if (state_ == s) return;
923 state_ = s;
924 matcher_->SetState(s);
925 has_rho_ = rho_label_ != kNoLabel;
926 }
927
Find(Label label)928 bool Find(Label label) final {
929 if (label == rho_label_ && rho_label_ != kNoLabel) {
930 FSTERROR() << "RhoMatcher::Find: bad label (rho)";
931 error_ = true;
932 return false;
933 }
934 if (matcher_->Find(label)) {
935 rho_match_ = kNoLabel;
936 return true;
937 } else if (has_rho_ && label != 0 && label != kNoLabel &&
938 (has_rho_ = matcher_->Find(rho_label_))) {
939 rho_match_ = label;
940 return true;
941 } else {
942 return false;
943 }
944 }
945
Done()946 bool Done() const final { return matcher_->Done(); }
947
Value()948 const Arc &Value() const final {
949 if (rho_match_ == kNoLabel) {
950 return matcher_->Value();
951 } else {
952 rho_arc_ = matcher_->Value();
953 if (rewrite_both_) {
954 if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_;
955 if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_;
956 } else if (match_type_ == MATCH_INPUT) {
957 rho_arc_.ilabel = rho_match_;
958 } else {
959 rho_arc_.olabel = rho_match_;
960 }
961 return rho_arc_;
962 }
963 }
964
Next()965 void Next() final { matcher_->Next(); }
966
Final(StateId s)967 Weight Final(StateId s) const final { return matcher_->Final(s); }
968
Priority(StateId s)969 ssize_t Priority(StateId s) final {
970 state_ = s;
971 matcher_->SetState(s);
972 has_rho_ = matcher_->Find(rho_label_);
973 if (has_rho_) {
974 return kRequirePriority;
975 } else {
976 return matcher_->Priority(s);
977 }
978 }
979
GetFst()980 const FST &GetFst() const override { return matcher_->GetFst(); }
981
982 uint64 Properties(uint64 props) const override;
983
Flags()984 uint32 Flags() const override {
985 if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) {
986 return matcher_->Flags();
987 }
988 return matcher_->Flags() | kRequireMatch;
989 }
990
RhoLabel()991 Label RhoLabel() const { return rho_label_; }
992
993 private:
994 std::unique_ptr<M> matcher_;
995 MatchType match_type_; // Type of match requested.
996 Label rho_label_; // Label that represents the rho transition
997 bool rewrite_both_; // Rewrite both sides when both are rho_label_?
998 Label rho_match_; // Current label that matches rho transition.
999 mutable Arc rho_arc_; // Arc to return when rho match.
1000 bool error_; // Error encountered?
1001 StateId state_; // Matcher state.
1002 bool has_rho_; // Are there possibly rhos at the current state?
1003 };
1004
1005 template <class M>
Properties(uint64 inprops)1006 inline uint64 RhoMatcher<M>::Properties(uint64 inprops) const {
1007 auto outprops = matcher_->Properties(inprops);
1008 if (error_) outprops |= kError;
1009 if (match_type_ == MATCH_NONE) {
1010 return outprops;
1011 } else if (match_type_ == MATCH_INPUT) {
1012 if (rewrite_both_) {
1013 return outprops &
1014 ~(kODeterministic | kNonODeterministic | kString | kILabelSorted |
1015 kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
1016 } else {
1017 return outprops & ~(kODeterministic | kAcceptor | kString |
1018 kILabelSorted | kNotILabelSorted);
1019 }
1020 } else if (match_type_ == MATCH_OUTPUT) {
1021 if (rewrite_both_) {
1022 return outprops &
1023 ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted |
1024 kNotILabelSorted | kOLabelSorted | kNotOLabelSorted);
1025 } else {
1026 return outprops & ~(kIDeterministic | kAcceptor | kString |
1027 kOLabelSorted | kNotOLabelSorted);
1028 }
1029 } else {
1030 // Shouldn't ever get here.
1031 FSTERROR() << "RhoMatcher: Bad match type: " << match_type_;
1032 return 0;
1033 }
1034 }
1035
// For any requested label, this matcher considers all transitions
// that match the label 'sigma_label' (sigma = "any"), in addition to
// transitions with the requested label. Each such sigma
// transition found is returned with the sigma_label rewritten as the
// requested label (both sides if an acceptor, or if 'rewrite_both' is
// true and both input and output labels of the found transition are
// 'sigma_label'). If 'sigma_label' is kNoLabel, this special
// matching is not done. SigmaMatcher is templated itself on a
// matcher, which is used to perform the underlying matching. By
// default, the underlying matcher is constructed by SigmaMatcher.
// The user can instead pass in this object; in that case,
// SigmaMatcher takes ownership of it. No non-consuming symbols other
// than epsilon are supported with the underlying template argument matcher.
1049 template <class M>
1050 class SigmaMatcher : public MatcherBase<typename M::Arc> {
1051 public:
1052 using FST = typename M::FST;
1053 using Arc = typename FST::Arc;
1054 using Label = typename Arc::Label;
1055 using StateId = typename Arc::StateId;
1056 using Weight = typename Arc::Weight;
1057
1058 // This makes a copy of the FST (w/o 'matcher' arg).
1059 SigmaMatcher(const FST &fst, MatchType match_type,
1060 Label sigma_label = kNoLabel,
1061 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
1062 M *matcher = nullptr)
1063 : matcher_(matcher ? matcher : new M(fst, match_type)),
1064 match_type_(match_type),
1065 sigma_label_(sigma_label),
1066 error_(false),
1067 state_(kNoStateId) {
1068 if (match_type == MATCH_BOTH) {
1069 FSTERROR() << "SigmaMatcher: Bad match type";
1070 match_type_ = MATCH_NONE;
1071 error_ = true;
1072 }
1073 if (sigma_label == 0) {
1074 FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label";
1075 sigma_label_ = kNoLabel;
1076 error_ = true;
1077 }
1078 if (rewrite_mode == MATCHER_REWRITE_AUTO) {
1079 rewrite_both_ = fst.Properties(kAcceptor, true);
1080 } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) {
1081 rewrite_both_ = true;
1082 } else {
1083 rewrite_both_ = false;
1084 }
1085 }
1086
1087 // This doesn't copy the FST.
1088 SigmaMatcher(const FST *fst, MatchType match_type,
1089 Label sigma_label = kNoLabel,
1090 MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO,
1091 M *matcher = nullptr)
1092 : SigmaMatcher(*fst, match_type, sigma_label, rewrite_mode,
1093 matcher ? matcher : new M(fst, match_type)) {}
1094
1095 // This makes a copy of the FST.
1096 SigmaMatcher(const SigmaMatcher &matcher, bool safe = false)
1097 : matcher_(new M(*matcher.matcher_, safe)),
1098 match_type_(matcher.match_type_),
1099 sigma_label_(matcher.sigma_label_),
1100 rewrite_both_(matcher.rewrite_both_),
1101 error_(matcher.error_),
1102 state_(kNoStateId) {}
1103
1104 SigmaMatcher *Copy(bool safe = false) const override {
1105 return new SigmaMatcher(*this, safe);
1106 }
1107
Type(bool test)1108 MatchType Type(bool test) const override { return matcher_->Type(test); }
1109
SetState(StateId s)1110 void SetState(StateId s) final {
1111 if (state_ == s) return;
1112 state_ = s;
1113 matcher_->SetState(s);
1114 has_sigma_ =
1115 (sigma_label_ != kNoLabel) ? matcher_->Find(sigma_label_) : false;
1116 }
1117
Find(Label match_label)1118 bool Find(Label match_label) final {
1119 match_label_ = match_label;
1120 if (match_label == sigma_label_ && sigma_label_ != kNoLabel) {
1121 FSTERROR() << "SigmaMatcher::Find: bad label (sigma)";
1122 error_ = true;
1123 return false;
1124 }
1125 if (matcher_->Find(match_label)) {
1126 sigma_match_ = kNoLabel;
1127 return true;
1128 } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel &&
1129 matcher_->Find(sigma_label_)) {
1130 sigma_match_ = match_label;
1131 return true;
1132 } else {
1133 return false;
1134 }
1135 }
1136
Done()1137 bool Done() const final { return matcher_->Done(); }
1138
Value()1139 const Arc &Value() const final {
1140 if (sigma_match_ == kNoLabel) {
1141 return matcher_->Value();
1142 } else {
1143 sigma_arc_ = matcher_->Value();
1144 if (rewrite_both_) {
1145 if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_;
1146 if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_;
1147 } else if (match_type_ == MATCH_INPUT) {
1148 sigma_arc_.ilabel = sigma_match_;
1149 } else {
1150 sigma_arc_.olabel = sigma_match_;
1151 }
1152 return sigma_arc_;
1153 }
1154 }
1155
Next()1156 void Next() final {
1157 matcher_->Next();
1158 if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) &&
1159 (match_label_ > 0)) {
1160 matcher_->Find(sigma_label_);
1161 sigma_match_ = match_label_;
1162 }
1163 }
1164
Final(StateId s)1165 Weight Final(StateId s) const final { return matcher_->Final(s); }
1166
Priority(StateId s)1167 ssize_t Priority(StateId s) final {
1168 if (sigma_label_ != kNoLabel) {
1169 SetState(s);
1170 return has_sigma_ ? kRequirePriority : matcher_->Priority(s);
1171 } else {
1172 return matcher_->Priority(s);
1173 }
1174 }
1175
GetFst()1176 const FST &GetFst() const override { return matcher_->GetFst(); }
1177
1178 uint64 Properties(uint64 props) const override;
1179
Flags()1180 uint32 Flags() const override {
1181 if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) {
1182 return matcher_->Flags();
1183 }
1184 return matcher_->Flags() | kRequireMatch;
1185 }
1186
SigmaLabel()1187 Label SigmaLabel() const { return sigma_label_; }
1188
1189 private:
1190 std::unique_ptr<M> matcher_;
1191 MatchType match_type_; // Type of match requested.
1192 Label sigma_label_; // Label that represents the sigma transition.
1193 bool rewrite_both_; // Rewrite both sides when both are sigma_label_?
1194 bool has_sigma_; // Are there sigmas at the current state?
1195 Label sigma_match_; // Current label that matches sigma transition.
1196 mutable Arc sigma_arc_; // Arc to return when sigma match.
1197 Label match_label_; // Label being matched.
1198 bool error_; // Error encountered?
1199 StateId state_; // Matcher state.
1200 };
1201
1202 template <class M>
Properties(uint64 inprops)1203 inline uint64 SigmaMatcher<M>::Properties(uint64 inprops) const {
1204 auto outprops = matcher_->Properties(inprops);
1205 if (error_) outprops |= kError;
1206 if (match_type_ == MATCH_NONE) {
1207 return outprops;
1208 } else if (rewrite_both_) {
1209 return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1210 kNonODeterministic | kILabelSorted | kNotILabelSorted |
1211 kOLabelSorted | kNotOLabelSorted | kString);
1212 } else if (match_type_ == MATCH_INPUT) {
1213 return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1214 kNonODeterministic | kILabelSorted | kNotILabelSorted |
1215 kString | kAcceptor);
1216 } else if (match_type_ == MATCH_OUTPUT) {
1217 return outprops & ~(kIDeterministic | kNonIDeterministic | kODeterministic |
1218 kNonODeterministic | kOLabelSorted | kNotOLabelSorted |
1219 kString | kAcceptor);
1220 } else {
1221 // Shouldn't ever get here.
1222 FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_;
1223 return 0;
1224 }
1225 }
1226
1227 // Flags for MultiEpsMatcher.
1228
1229 // Return multi-epsilon arcs for Find(kNoLabel).
1230 const uint32 kMultiEpsList = 0x00000001;
1231
1232 // Return a kNolabel loop for Find(multi_eps).
1233 const uint32 kMultiEpsLoop = 0x00000002;
1234
1235 // MultiEpsMatcher: allows treating multiple non-0 labels as
1236 // non-consuming labels in addition to 0 that is always
1237 // non-consuming. Precise behavior controlled by 'flags' argument. By
1238 // default, the underlying matcher is constructed by
1239 // MultiEpsMatcher. The user can instead pass in this object; in that
1240 // case, MultiEpsMatcher takes its ownership iff 'own_matcher' is
1241 // true.
1242 template <class M>
1243 class MultiEpsMatcher {
1244 public:
1245 using FST = typename M::FST;
1246 using Arc = typename FST::Arc;
1247 using Label = typename Arc::Label;
1248 using StateId = typename Arc::StateId;
1249 using Weight = typename Arc::Weight;
1250
1251 // This makes a copy of the FST (w/o 'matcher' arg).
1252 MultiEpsMatcher(const FST &fst, MatchType match_type,
1253 uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1254 M *matcher = nullptr, bool own_matcher = true)
1255 : matcher_(matcher ? matcher : new M(fst, match_type)),
1256 flags_(flags),
1257 own_matcher_(matcher ? own_matcher : true) {
1258 Init(match_type);
1259 }
1260
1261 // This doesn't copy the FST.
1262 MultiEpsMatcher(const FST *fst, MatchType match_type,
1263 uint32 flags = (kMultiEpsLoop | kMultiEpsList),
1264 M *matcher = nullptr, bool own_matcher = true)
1265 : matcher_(matcher ? matcher : new M(fst, match_type)),
1266 flags_(flags),
1267 own_matcher_(matcher ? own_matcher : true) {
1268 Init(match_type);
1269 }
1270
1271 // This makes a copy of the FST.
1272 MultiEpsMatcher(const MultiEpsMatcher &matcher, bool safe = false)
1273 : matcher_(new M(*matcher.matcher_, safe)),
1274 flags_(matcher.flags_),
1275 own_matcher_(true),
1276 multi_eps_labels_(matcher.multi_eps_labels_),
1277 loop_(matcher.loop_) {
1278 loop_.nextstate = kNoStateId;
1279 }
1280
~MultiEpsMatcher()1281 ~MultiEpsMatcher() {
1282 if (own_matcher_) delete matcher_;
1283 }
1284
1285 MultiEpsMatcher *Copy(bool safe = false) const {
1286 return new MultiEpsMatcher(*this, safe);
1287 }
1288
Type(bool test)1289 MatchType Type(bool test) const { return matcher_->Type(test); }
1290
SetState(StateId state)1291 void SetState(StateId state) {
1292 matcher_->SetState(state);
1293 loop_.nextstate = state;
1294 }
1295
1296 bool Find(Label label);
1297
Done()1298 bool Done() const { return done_; }
1299
Value()1300 const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); }
1301
Next()1302 void Next() {
1303 if (!current_loop_) {
1304 matcher_->Next();
1305 done_ = matcher_->Done();
1306 if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) {
1307 ++multi_eps_iter_;
1308 while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1309 !matcher_->Find(*multi_eps_iter_)) {
1310 ++multi_eps_iter_;
1311 }
1312 if (multi_eps_iter_ != multi_eps_labels_.End()) {
1313 done_ = false;
1314 } else {
1315 done_ = !matcher_->Find(kNoLabel);
1316 }
1317 }
1318 } else {
1319 done_ = true;
1320 }
1321 }
1322
GetFst()1323 const FST &GetFst() const { return matcher_->GetFst(); }
1324
Properties(uint64 props)1325 uint64 Properties(uint64 props) const { return matcher_->Properties(props); }
1326
GetMatcher()1327 const M *GetMatcher() const { return matcher_; }
1328
Final(StateId s)1329 Weight Final(StateId s) const { return matcher_->Final(s); }
1330
Flags()1331 uint32 Flags() const { return matcher_->Flags(); }
1332
Priority(StateId s)1333 ssize_t Priority(StateId s) { return matcher_->Priority(s); }
1334
AddMultiEpsLabel(Label label)1335 void AddMultiEpsLabel(Label label) {
1336 if (label == 0) {
1337 FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1338 } else {
1339 multi_eps_labels_.Insert(label);
1340 }
1341 }
1342
RemoveMultiEpsLabel(Label label)1343 void RemoveMultiEpsLabel(Label label) {
1344 if (label == 0) {
1345 FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0";
1346 } else {
1347 multi_eps_labels_.Erase(label);
1348 }
1349 }
1350
ClearMultiEpsLabels()1351 void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); }
1352
1353 private:
Init(MatchType match_type)1354 void Init(MatchType match_type) {
1355 if (match_type == MATCH_INPUT) {
1356 loop_.ilabel = kNoLabel;
1357 loop_.olabel = 0;
1358 } else {
1359 loop_.ilabel = 0;
1360 loop_.olabel = kNoLabel;
1361 }
1362 loop_.weight = Weight::One();
1363 loop_.nextstate = kNoStateId;
1364 }
1365
1366 M *matcher_;
1367 uint32 flags_;
1368 bool own_matcher_; // Does this class delete the matcher?
1369
1370 // Multi-eps label set.
1371 CompactSet<Label, kNoLabel> multi_eps_labels_;
1372 typename CompactSet<Label, kNoLabel>::const_iterator multi_eps_iter_;
1373
1374 bool current_loop_; // Current arc is the implicit loop?
1375 mutable Arc loop_; // For non-consuming symbols.
1376 bool done_; // Matching done?
1377
1378 MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete;
1379 };
1380
1381 template <class M>
Find(Label label)1382 inline bool MultiEpsMatcher<M>::Find(Label label) {
1383 multi_eps_iter_ = multi_eps_labels_.End();
1384 current_loop_ = false;
1385 bool ret;
1386 if (label == 0) {
1387 ret = matcher_->Find(0);
1388 } else if (label == kNoLabel) {
1389 if (flags_ & kMultiEpsList) {
1390 // Returns all non-consuming arcs (including epsilon).
1391 multi_eps_iter_ = multi_eps_labels_.Begin();
1392 while ((multi_eps_iter_ != multi_eps_labels_.End()) &&
1393 !matcher_->Find(*multi_eps_iter_)) {
1394 ++multi_eps_iter_;
1395 }
1396 if (multi_eps_iter_ != multi_eps_labels_.End()) {
1397 ret = true;
1398 } else {
1399 ret = matcher_->Find(kNoLabel);
1400 }
1401 } else {
1402 // Returns all epsilon arcs.
1403 ret = matcher_->Find(kNoLabel);
1404 }
1405 } else if ((flags_ & kMultiEpsLoop) &&
1406 multi_eps_labels_.Find(label) != multi_eps_labels_.End()) {
1407 // Returns implicit loop.
1408 current_loop_ = true;
1409 ret = true;
1410 } else {
1411 ret = matcher_->Find(label);
1412 }
1413 done_ = !ret;
1414 return ret;
1415 }
1416
1417 // This class discards any implicit matches (e.g., the implicit epsilon
1418 // self-loops in the SortedMatcher). Matchers are most often used in
1419 // composition/intersection where the implicit matches are needed
1420 // e.g. for epsilon processing. However, if a matcher is simply being
1421 // used to look-up explicit label matches, this class saves the user
1422 // from having to check for and discard the unwanted implicit matches
1423 // themselves.
1424 template <class M>
1425 class ExplicitMatcher : public MatcherBase<typename M::Arc> {
1426 public:
1427 using FST = typename M::FST;
1428 using Arc = typename FST::Arc;
1429 using Label = typename Arc::Label;
1430 using StateId = typename Arc::StateId;
1431 using Weight = typename Arc::Weight;
1432
1433 // This makes a copy of the FST.
1434 ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr)
1435 : matcher_(matcher ? matcher : new M(fst, match_type)),
1436 match_type_(match_type),
1437 error_(false) {}
1438
1439 // This doesn't copy the FST.
1440 ExplicitMatcher(const FST *fst, MatchType match_type, M *matcher = nullptr)
1441 : matcher_(matcher ? matcher : new M(fst, match_type)),
1442 match_type_(match_type),
1443 error_(false) {}
1444
1445 // This makes a copy of the FST.
1446 ExplicitMatcher(const ExplicitMatcher &matcher, bool safe = false)
1447 : matcher_(new M(*matcher.matcher_, safe)),
1448 match_type_(matcher.match_type_),
1449 error_(matcher.error_) {}
1450
1451 ExplicitMatcher *Copy(bool safe = false) const override {
1452 return new ExplicitMatcher(*this, safe);
1453 }
1454
Type(bool test)1455 MatchType Type(bool test) const override { return matcher_->Type(test); }
1456
SetState(StateId s)1457 void SetState(StateId s) final { matcher_->SetState(s); }
1458
Find(Label label)1459 bool Find(Label label) final {
1460 matcher_->Find(label);
1461 CheckArc();
1462 return !Done();
1463 }
1464
Done()1465 bool Done() const final { return matcher_->Done(); }
1466
Value()1467 const Arc &Value() const final { return matcher_->Value(); }
1468
Next()1469 void Next() final {
1470 matcher_->Next();
1471 CheckArc();
1472 }
1473
Final(StateId s)1474 Weight Final(StateId s) const final { return matcher_->Final(s); }
1475
Priority(StateId s)1476 ssize_t Priority(StateId s) final { return matcher_->Priority(s); }
1477
GetFst()1478 const FST &GetFst() const final { return matcher_->GetFst(); }
1479
Properties(uint64 inprops)1480 uint64 Properties(uint64 inprops) const override {
1481 return matcher_->Properties(inprops);
1482 }
1483
GetMatcher()1484 const M *GetMatcher() const { return matcher_.get(); }
1485
Flags()1486 uint32 Flags() const override { return matcher_->Flags(); }
1487
1488 private:
1489 // Checks current arc if available and explicit. If not available, stops. If
1490 // not explicit, checks next ones.
CheckArc()1491 void CheckArc() {
1492 for (; !matcher_->Done(); matcher_->Next()) {
1493 const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel
1494 : matcher_->Value().olabel;
1495 if (label != kNoLabel) return;
1496 }
1497 }
1498
1499 std::unique_ptr<M> matcher_;
1500 MatchType match_type_; // Type of match requested.
1501 bool error_; // Error encountered?
1502 };
1503
1504 // Generic matcher, templated on the FST definition.
1505 //
1506 // Here is a typical use:
1507 //
1508 // Matcher<StdFst> matcher(fst, MATCH_INPUT);
1509 // matcher.SetState(state);
1510 // if (matcher.Find(label))
1511 // for (; !matcher.Done(); matcher.Next()) {
1512 // auto &arc = matcher.Value();
1513 // ...
1514 // }
1515 template <class F>
1516 class Matcher {
1517 public:
1518 using FST = F;
1519 using Arc = typename F::Arc;
1520 using Label = typename Arc::Label;
1521 using StateId = typename Arc::StateId;
1522 using Weight = typename Arc::Weight;
1523
1524 // This makes a copy of the FST.
Matcher(const FST & fst,MatchType match_type)1525 Matcher(const FST &fst, MatchType match_type)
1526 : owned_fst_(fst.Copy()), base_(owned_fst_->InitMatcher(match_type)) {
1527 if (!base_)
1528 base_ =
1529 std::make_unique<SortedMatcher<FST>>(owned_fst_.get(), match_type);
1530 }
1531
1532 // This doesn't copy the FST.
Matcher(const FST * fst,MatchType match_type)1533 Matcher(const FST *fst, MatchType match_type)
1534 : base_(fst->InitMatcher(match_type)) {
1535 if (!base_) base_ = std::make_unique<SortedMatcher<FST>>(fst, match_type);
1536 }
1537
1538 // This makes a copy of the FST.
1539 Matcher(const Matcher &matcher, bool safe = false)
1540 : base_(matcher.base_->Copy(safe)) {}
1541
1542 // Takes ownership of the provided matcher.
Matcher(MatcherBase<Arc> * base_matcher)1543 explicit Matcher(MatcherBase<Arc> *base_matcher) : base_(base_matcher) {}
1544
1545 Matcher *Copy(bool safe = false) const { return new Matcher(*this, safe); }
1546
Type(bool test)1547 MatchType Type(bool test) const { return base_->Type(test); }
1548
SetState(StateId s)1549 void SetState(StateId s) { base_->SetState(s); }
1550
Find(Label label)1551 bool Find(Label label) { return base_->Find(label); }
1552
Done()1553 bool Done() const { return base_->Done(); }
1554
Value()1555 const Arc &Value() const { return base_->Value(); }
1556
Next()1557 void Next() { base_->Next(); }
1558
GetFst()1559 const FST &GetFst() const { return fst::down_cast<const FST &>(base_->GetFst()); }
1560
Properties(uint64 props)1561 uint64 Properties(uint64 props) const { return base_->Properties(props); }
1562
Final(StateId s)1563 Weight Final(StateId s) const { return base_->Final(s); }
1564
Flags()1565 uint32 Flags() const { return base_->Flags() & kMatcherFlags; }
1566
Priority(StateId s)1567 ssize_t Priority(StateId s) { return base_->Priority(s); }
1568
1569 private:
1570 std::unique_ptr<const FST> owned_fst_;
1571 std::unique_ptr<MatcherBase<Arc>> base_;
1572 };
1573
1574 } // namespace fst
1575
1576 #endif // FST_MATCHER_H_
1577