1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Class to compute the composition of two FSTs.
19 
20 #ifndef FST_COMPOSE_H_
21 #define FST_COMPOSE_H_
22 
23 #include <algorithm>
24 #include <memory>
25 #include <utility>
26 
27 #include <fst/types.h>
28 #include <fst/log.h>
29 
30 #include <fst/cache.h>
31 #include <fst/compose-filter.h>
32 #include <fst/fst-decl.h>  // For optional argument declarations
33 #include <fst/lookahead-filter.h>
34 #include <fst/matcher.h>
35 #include <fst/state-table.h>
36 #include <fst/test-properties.h>
37 
38 
39 namespace fst {
40 
41 // Delayed composition options templated on the arc type, the matcher,
42 // the composition filter, and the composition state table. By
43 // default, the matchers, filter, and state table are constructed by
44 // composition. If set below, the user can instead pass in these
45 // objects; in that case, ComposeFst takes their ownership. This
46 // version controls composition implemented between generic Fst<Arc>
47 // types and a shared matcher type M for Fst<Arc>. This should be
48 // adequate for most applications, giving a reasonable tradeoff
49 // between efficiency and code sharing (but see ComposeFstImplOptions).
50 template <class Arc, class M = Matcher<Fst<Arc>>,
51           class Filter = SequenceComposeFilter<M>,
52           class StateTable =
53               GenericComposeStateTable<Arc, typename Filter::FilterState>>
54 struct ComposeFstOptions : public CacheOptions {
55   M *matcher1;              // FST1 matcher.
56   M *matcher2;              // FST2 matcher.
57   Filter *filter;           // Composition filter.
58   StateTable *state_table;  // Composition state table.
59 
60   explicit ComposeFstOptions(const CacheOptions &opts = CacheOptions(),
61                              M *matcher1 = nullptr, M *matcher2 = nullptr,
62                              Filter *filter = nullptr,
63                              StateTable *state_table = nullptr)
CacheOptionsComposeFstOptions64       : CacheOptions(opts),
65         matcher1(matcher1),
66         matcher2(matcher2),
67         filter(filter),
68         state_table(state_table) {}
69 };
70 
// Forward declaration of ComposeFstMatcher, which is befriended by
// ComposeFstImpl and ComposeFst before its definition further down in this
// file. The template parameters are, in order, the cache store, the
// composition filter and the composition state table.
template <class C, class F, class T>
class ComposeFstMatcher;
74 
75 // Delayed composition options templated on the two matcher types, the
76 // composition filter, the composition state table and the cache store. By
77 // default, the matchers, filter, state table and cache store are constructed
78 // by composition. If set below, the user can instead pass in these objects; in
79 // that case, ComposeFst takes their ownership. This version controls
80 // composition implemented using arbitrary matchers (of the same arc type but
81 // otherwise arbitrary FST type). The user must ensure the matchers are
82 // compatible. These options permit the most efficient use, but shares the
83 // least code. This is for advanced use only in the most demanding or
84 // specialized applications that can benefit from it; otherwise, prefer
85 // ComposeFstOptions).
86 template <class M1, class M2, class Filter = SequenceComposeFilter<M1, M2>,
87           class StateTable = GenericComposeStateTable<
88               typename M1::Arc, typename Filter::FilterState>,
89           class CacheStore = DefaultCacheStore<typename M1::Arc>>
90 struct ComposeFstImplOptions : public CacheImplOptions<CacheStore> {
91   M1 *matcher1;    // FST1 matcher (see matcher.h)....
92   M2 *matcher2;    // FST2 matcher.
93   Filter *filter;  // Composition filter (see compose-filter.h).
94   StateTable
95       *state_table;      // Composition state table (see compose-state-table.h).
96   bool own_state_table;  // ComposeFstImpl takes ownership of 'state_table'?
97   bool allow_noncommute;  // Allow non-commutative weights
98 
99   explicit ComposeFstImplOptions(const CacheOptions &opts,
100                                  M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
101                                  Filter *filter = nullptr,
102                                  StateTable *state_table = nullptr)
103       : CacheImplOptions<CacheStore>(opts),
104         matcher1(matcher1),
105         matcher2(matcher2),
106         filter(filter),
107         state_table(state_table),
108         own_state_table(true),
109         allow_noncommute(false) {}
110 
111   explicit ComposeFstImplOptions(const CacheImplOptions<CacheStore> &opts,
112                                  M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
113                                  Filter *filter = nullptr,
114                                  StateTable *state_table = nullptr)
115       : CacheImplOptions<CacheStore>(opts),
116         matcher1(matcher1),
117         matcher2(matcher2),
118         filter(filter),
119         state_table(state_table),
120         own_state_table(true),
121         allow_noncommute(false) {}
122 
ComposeFstImplOptionsComposeFstImplOptions123   ComposeFstImplOptions()
124       : matcher1(nullptr),
125         matcher2(nullptr),
126         filter(nullptr),
127         state_table(nullptr),
128         own_state_table(true),
129         allow_noncommute(false) {}
130 };
131 
132 namespace internal {
133 
134 // Implementation of delayed composition. This base class is common to the
135 // variants with different matchers, composition filters and state tables.
136 template <class Arc, class CacheStore = DefaultCacheStore<Arc>,
137           class F = ComposeFst<Arc, CacheStore>>
138 class ComposeFstImplBase
139     : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
140  public:
141   using FST = F;
142   using Label = typename Arc::Label;
143   using StateId = typename Arc::StateId;
144   using Weight = typename Arc::Weight;
145 
146   using State = typename CacheStore::State;
147   using CacheImpl = CacheBaseImpl<State, CacheStore>;
148 
149   using FstImpl<Arc>::SetType;
150   using FstImpl<Arc>::SetProperties;
151   using FstImpl<Arc>::Properties;
152   using FstImpl<Arc>::SetInputSymbols;
153   using FstImpl<Arc>::SetOutputSymbols;
154 
155   using CacheImpl::HasArcs;
156   using CacheImpl::HasFinal;
157   using CacheImpl::HasStart;
158   using CacheImpl::SetFinal;
159   using CacheImpl::SetStart;
160 
ComposeFstImplBase(const CacheImplOptions<CacheStore> & opts)161   explicit ComposeFstImplBase(const CacheImplOptions<CacheStore> &opts)
162       : CacheImpl(opts) {}
163 
ComposeFstImplBase(const CacheOptions & opts)164   explicit ComposeFstImplBase(const CacheOptions &opts) : CacheImpl(opts) {}
165 
ComposeFstImplBase(const ComposeFstImplBase & impl)166   ComposeFstImplBase(const ComposeFstImplBase &impl) : CacheImpl(impl, true) {
167     SetType(impl.Type());
168     SetProperties(impl.Properties(), kCopyProperties);
169     SetInputSymbols(impl.InputSymbols());
170     SetOutputSymbols(impl.OutputSymbols());
171   }
172 
173   virtual ComposeFstImplBase *Copy() const = 0;
174 
~ComposeFstImplBase()175   ~ComposeFstImplBase() override {}
176 
Start()177   StateId Start() {
178     if (!HasStart()) {
179       const auto start = ComputeStart();
180       if (start != kNoStateId) SetStart(start);
181     }
182     return CacheImpl::Start();
183   }
184 
Final(StateId s)185   Weight Final(StateId s) {
186     if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
187     return CacheImpl::Final(s);
188   }
189 
190   virtual void Expand(StateId s) = 0;
191 
NumArcs(StateId s)192   size_t NumArcs(StateId s) {
193     if (!HasArcs(s)) Expand(s);
194     return CacheImpl::NumArcs(s);
195   }
196 
NumInputEpsilons(StateId s)197   size_t NumInputEpsilons(StateId s) {
198     if (!HasArcs(s)) Expand(s);
199     return CacheImpl::NumInputEpsilons(s);
200   }
201 
NumOutputEpsilons(StateId s)202   size_t NumOutputEpsilons(StateId s) {
203     if (!HasArcs(s)) Expand(s);
204     return CacheImpl::NumOutputEpsilons(s);
205   }
206 
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)207   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
208     if (!HasArcs(s)) Expand(s);
209     CacheImpl::InitArcIterator(s, data);
210   }
211 
InitMatcher(const F & fst,MatchType match_type)212   virtual MatcherBase<Arc> *InitMatcher(const F &fst,
213                                         MatchType match_type) const {
214     // Use the default matcher if no override is provided.
215     return nullptr;
216   }
217 
218  protected:
219   virtual StateId ComputeStart() = 0;
220   virtual Weight ComputeFinal(StateId s) = 0;
221 };
222 
223 // Implementation of delayed composition templated on the matchers (see
224 // matcher.h), composition filter (see compose-filter.h) and the composition
225 // state table (see compose-state-table.h).
226 template <class CacheStore, class Filter, class StateTable>
227 class ComposeFstImpl
228     : public ComposeFstImplBase<typename CacheStore::Arc, CacheStore> {
229  public:
230   using Matcher1 = typename Filter::Matcher1;
231   using Matcher2 = typename Filter::Matcher2;
232 
233   using FST1 = typename Matcher1::FST;
234   using FST2 = typename Matcher2::FST;
235 
236   using Arc = typename CacheStore::Arc;
237   using Label = typename Arc::Label;
238   using StateId = typename Arc::StateId;
239   using Weight = typename Arc::Weight;
240 
241   using FilterState = typename Filter::FilterState;
242   using State = typename CacheStore::State;
243 
244   using CacheImpl = CacheBaseImpl<State, CacheStore>;
245 
246   using StateTuple = typename StateTable::StateTuple;
247 
248   friend class ComposeFstMatcher<CacheStore, Filter, StateTable>;
249 
250   using FstImpl<Arc>::SetInputSymbols;
251   using FstImpl<Arc>::SetOutputSymbols;
252   using FstImpl<Arc>::SetType;
253   using FstImpl<Arc>::SetProperties;
254 
255   template <class M1, class M2>
256   ComposeFstImpl(const FST1 &fst1, const FST2 &fst2,
257                  const ComposeFstImplOptions<M1, M2, Filter, StateTable,
258                                              CacheStore> &opts);
259 
ComposeFstImpl(const ComposeFstImpl & impl)260   ComposeFstImpl(const ComposeFstImpl &impl)
261       : ComposeFstImplBase<Arc, CacheStore>(impl),
262         filter_(new Filter(*impl.filter_, true)),
263         matcher1_(filter_->GetMatcher1()),
264         matcher2_(filter_->GetMatcher2()),
265         fst1_(matcher1_->GetFst()),
266         fst2_(matcher2_->GetFst()),
267         state_table_(new StateTable(*impl.state_table_)),
268         own_state_table_(true),
269         match_type_(impl.match_type_) {}
270 
~ComposeFstImpl()271   ~ComposeFstImpl() override {
272     if (own_state_table_) delete state_table_;
273   }
274 
Copy()275   ComposeFstImpl *Copy() const override { return new ComposeFstImpl(*this); }
276 
Properties()277   uint64 Properties() const override { return Properties(kFstProperties); }
278 
279   // Sets error if found, and returns other FST impl properties.
Properties(uint64 mask)280   uint64 Properties(uint64 mask) const override {
281     if ((mask & kError) &&
282         (fst1_.Properties(kError, false) || fst2_.Properties(kError, false) ||
283          (matcher1_->Properties(0) & kError) ||
284          (matcher2_->Properties(0) & kError) |
285              (filter_->Properties(0) & kError) ||
286          state_table_->Error())) {
287       SetProperties(kError, kError);
288     }
289     return FstImpl<Arc>::Properties(mask);
290   }
291 
292   // Arranges it so that the first arg to OrderedExpand is the Fst
293   // that will be matched on.
Expand(StateId s)294   void Expand(StateId s) override {
295     const auto &tuple = state_table_->Tuple(s);
296     const auto s1 = tuple.StateId1();
297     const auto s2 = tuple.StateId2();
298     filter_->SetState(s1, s2, tuple.GetFilterState());
299     if (MatchInput(s1, s2)) {
300       OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true);
301     } else {
302       OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false);
303     }
304   }
305 
GetFst1()306   const FST1 &GetFst1() const { return fst1_; }
307 
GetFst2()308   const FST2 &GetFst2() const { return fst2_; }
309 
GetMatcher1()310   const Matcher1 *GetMatcher1() const { return matcher1_; }
311 
GetMatcher1()312   Matcher1 *GetMatcher1() { return matcher1_; }
313 
GetMatcher2()314   const Matcher2 *GetMatcher2() const { return matcher2_; }
315 
GetMatcher2()316   Matcher2 *GetMatcher2() { return matcher2_; }
317 
GetFilter()318   const Filter *GetFilter() const { return filter_.get(); }
319 
GetFilter()320   Filter *GetFilter() { return filter_.get(); }
321 
GetStateTable()322   const StateTable *GetStateTable() const { return state_table_; }
323 
GetStateTable()324   StateTable *GetStateTable() { return state_table_; }
325 
InitMatcher(const ComposeFst<Arc,CacheStore> & fst,MatchType match_type)326   MatcherBase<Arc> *InitMatcher(const ComposeFst<Arc, CacheStore> &fst,
327                                 MatchType match_type) const override {
328     const auto test_props = match_type == MATCH_INPUT
329                                 ? kFstProperties & ~kILabelInvariantProperties
330                                 : kFstProperties & ~kOLabelInvariantProperties;
331     // If both matchers support 'match_type' and we have a guarantee that a
332     // call to 'filter_->FilterArc(arc1, arc2)' will not modify the ilabel of
333     // arc1 when MATCH_INPUT or the olabel or arc2 when MATCH_OUTPUT, then
334     // ComposeFstMatcher can be used.
335     if ((matcher1_->Type(false) == match_type) &&
336         (matcher2_->Type(false) == match_type) &&
337         (filter_->Properties(test_props) == test_props)) {
338       return new ComposeFstMatcher<CacheStore, Filter, StateTable>(&fst,
339                                                                    match_type);
340     }
341     return nullptr;
342   }
343 
344  private:
345   // This does that actual matching of labels in the composition. The
346   // arguments are ordered so matching is called on state 'sa' of
347   // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg
348   // determines whether the input or output label of arcs at 'sb' is
349   // the one to match on.
350   template <class FST, class Matcher>
OrderedExpand(StateId s,const Fst<Arc> &,StateId sa,const FST & fstb,StateId sb,Matcher * matchera,bool match_input)351   void OrderedExpand(StateId s, const Fst<Arc> &, StateId sa, const FST &fstb,
352                      StateId sb, Matcher *matchera, bool match_input) {
353     matchera->SetState(sa);
354     // First processes non-consuming symbols (e.g., epsilons) on FSTA.
355     const Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0,
356                    Weight::One(), sb);
357     MatchArc(s, matchera, loop, match_input);
358     // Then processes matches on FSTB.
359     for (ArcIterator<FST> iterb(fstb, sb); !iterb.Done(); iterb.Next()) {
360       MatchArc(s, matchera, iterb.Value(), match_input);
361     }
362     CacheImpl::SetArcs(s);
363   }
364 
365   // Matches a single transition from 'fstb' against 'fata' at 's'.
366   template <class Matcher>
MatchArc(StateId s,Matcher * matchera,const Arc & arc,bool match_input)367   void MatchArc(StateId s, Matcher *matchera, const Arc &arc,
368                 bool match_input) {
369     if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) {
370       for (; !matchera->Done(); matchera->Next()) {
371         auto arca = matchera->Value();
372         auto arcb = arc;
373         if (match_input) {
374           const auto &fs = filter_->FilterArc(&arcb, &arca);
375           if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
376         } else {
377           const auto &fs = filter_->FilterArc(&arca, &arcb);
378           if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
379         }
380       }
381     }
382   }
383 
384   // Add a matching transition at 's'.
AddArc(StateId s,const Arc & arc1,const Arc & arc2,const FilterState & f)385   void AddArc(StateId s, const Arc &arc1, const Arc &arc2,
386               const FilterState &f) {
387     const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
388     CacheImpl::EmplaceArc(s, arc1.ilabel, arc2.olabel,
389                           Times(arc1.weight, arc2.weight),
390                           state_table_->FindState(tuple));
391   }
392 
ComputeStart()393   StateId ComputeStart() override {
394     const auto s1 = fst1_.Start();
395     if (s1 == kNoStateId) return kNoStateId;
396     const auto s2 = fst2_.Start();
397     if (s2 == kNoStateId) return kNoStateId;
398     const auto &fs = filter_->Start();
399     const StateTuple tuple(s1, s2, fs);
400     return state_table_->FindState(tuple);
401   }
402 
ComputeFinal(StateId s)403   Weight ComputeFinal(StateId s) override {
404     const auto &tuple = state_table_->Tuple(s);
405     const auto s1 = tuple.StateId1();
406     auto final1 = matcher1_->Final(s1);
407     if (final1 == Weight::Zero()) return final1;
408     const auto s2 = tuple.StateId2();
409     auto final2 = matcher2_->Final(s2);
410     if (final2 == Weight::Zero()) return final2;
411     filter_->SetState(s1, s2, tuple.GetFilterState());
412     filter_->FilterFinal(&final1, &final2);
413     return Times(final1, final2);
414   }
415 
416   // Determines which side to match on per composition state.
MatchInput(StateId s1,StateId s2)417   bool MatchInput(StateId s1, StateId s2) {
418     switch (match_type_) {
419       case MATCH_INPUT:
420         return true;
421       case MATCH_OUTPUT:
422         return false;
423       default:  // MATCH_BOTH
424         const auto priority1 = matcher1_->Priority(s1);
425         const auto priority2 = matcher2_->Priority(s2);
426         if (priority1 == kRequirePriority && priority2 == kRequirePriority) {
427           FSTERROR() << "ComposeFst: Both sides can't require match";
428           SetProperties(kError, kError);
429           return true;
430         }
431         if (priority1 == kRequirePriority) return false;
432         if (priority2 == kRequirePriority) {
433           return true;
434         }
435         return priority1 <= priority2;
436     }
437   }
438 
439   // Identifies and verifies the capabilities of the matcher to be used for
440   // composition.
441   void SetMatchType();
442 
443   std::unique_ptr<Filter> filter_;
444   Matcher1 *matcher1_;  // Borrowed reference.
445   Matcher2 *matcher2_;  // Borrowed reference.
446   const FST1 &fst1_;
447   const FST2 &fst2_;
448   StateTable *state_table_;
449   bool own_state_table_;
450 
451   MatchType match_type_;
452 };
453 
454 template <class CacheStore, class Filter, class StateTable>
455 template <class M1, class M2>
ComposeFstImpl(const FST1 & fst1,const FST2 & fst2,const ComposeFstImplOptions<M1,M2,Filter,StateTable,CacheStore> & opts)456 ComposeFstImpl<CacheStore, Filter, StateTable>::ComposeFstImpl(
457     const FST1 &fst1, const FST2 &fst2,
458     const ComposeFstImplOptions<M1, M2, Filter, StateTable, CacheStore> &opts)
459     : ComposeFstImplBase<Arc, CacheStore>(opts),
460       filter_(opts.filter
461                   ? opts.filter
462                   : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
463       matcher1_(filter_->GetMatcher1()),
464       matcher2_(filter_->GetMatcher2()),
465       fst1_(matcher1_->GetFst()),
466       fst2_(matcher2_->GetFst()),
467       state_table_(opts.state_table ? opts.state_table
468                                     : new StateTable(fst1_, fst2_)),
469       own_state_table_(opts.state_table ? opts.own_state_table : true) {
470   SetType("compose");
471   if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
472     FSTERROR() << "ComposeFst: Output symbol table of 1st argument "
473                << "does not match input symbol table of 2nd argument";
474     SetProperties(kError, kError);
475   }
476   SetInputSymbols(fst1_.InputSymbols());
477   SetOutputSymbols(fst2_.OutputSymbols());
478   SetMatchType();
479   VLOG(2) << "ComposeFstImpl: Match type: " << match_type_;
480   if (match_type_ == MATCH_NONE) SetProperties(kError, kError);
481   const auto fprops1 = fst1.Properties(kFstProperties, false);
482   const auto fprops2 = fst2.Properties(kFstProperties, false);
483   const auto mprops1 = matcher1_->Properties(fprops1);
484   const auto mprops2 = matcher2_->Properties(fprops2);
485   const auto cprops = ComposeProperties(mprops1, mprops2);
486   SetProperties(filter_->Properties(cprops), kCopyProperties);
487   if (state_table_->Error()) SetProperties(kError, kError);
488 }
489 
490 template <class CacheStore, class Filter, class StateTable>
SetMatchType()491 void ComposeFstImpl<CacheStore, Filter, StateTable>::SetMatchType() {
492   // Ensures any required matching is possible and known.
493   if ((matcher1_->Flags() & kRequireMatch) &&
494       matcher1_->Type(true) != MATCH_OUTPUT) {
495     FSTERROR() << "ComposeFst: 1st argument cannot perform required matching "
496                << "(sort?).";
497     match_type_ = MATCH_NONE;
498     return;
499   }
500   if ((matcher2_->Flags() & kRequireMatch) &&
501       matcher2_->Type(true) != MATCH_INPUT) {
502     FSTERROR() << "ComposeFst: 2nd argument cannot perform required matching "
503                << "(sort?).";
504     match_type_ = MATCH_NONE;
505     return;
506   }
507   // Finds which sides to match on (favoring minimal testing of capabilities).
508   const auto type1 = matcher1_->Type(false);
509   const auto type2 = matcher2_->Type(false);
510   if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) {
511     match_type_ = MATCH_BOTH;
512   } else if (type1 == MATCH_OUTPUT) {
513     match_type_ = MATCH_OUTPUT;
514   } else if (type2 == MATCH_INPUT) {
515     match_type_ = MATCH_INPUT;
516   } else if (matcher1_->Type(true) == MATCH_OUTPUT) {
517     match_type_ = MATCH_OUTPUT;
518   } else if (matcher2_->Type(true) == MATCH_INPUT) {
519     match_type_ = MATCH_INPUT;
520   } else {
521     FSTERROR() << "ComposeFst: 1st argument cannot match on output labels "
522                << "and 2nd argument cannot match on input labels (sort?).";
523     match_type_ = MATCH_NONE;
524   }
525 }
526 
527 }  // namespace internal
528 
529 // Computes the composition of two transducers. This version is a delayed FST.
// If FST1 transduces string x to y with weight a and FST2 transduces y to z
// with weight b, then their composition transduces string x to z with weight
// Times(a, b).
533 //
534 // The output labels of the first transducer or the input labels of the second
535 // transducer must be sorted (with the default matcher). The weights need to
536 // form a commutative semiring (valid for TropicalWeight and LogWeight).
537 //
538 // Complexity:
539 //
540 // Assuming the first FST is unsorted and the second is sorted,
541 //
542 //   Time: O(v1 v2 d1 (log d2 + m2)),
543 //   Space: O(v1 v2)
544 //
545 // where vi = # of states visited, di = maximum out-degree, and mi the
546 // maximum multiplicity of the states visited, for the ith FST. Constant time
547 // and space to visit an input state or arc is assumed and exclusive of caching.
548 //
549 // Caveats:
550 // - ComposeFst does not trim its output (since it is a delayed operation).
551 // - The efficiency of composition can be strongly affected by several factors:
552 //   - the choice of which transducer is sorted - prefer sorting the FST
553 //     that has the greater average out-degree.
554 //   - the amount of non-determinism
555 //   - the presence and location of epsilon transitions - avoid epsilon
556 //     transitions on the output side of the first transducer or
557 //     the input side of the second transducer or prefer placing
558 //     them later in a path since they delay matching and can
559 //     introduce non-coaccessible states and transitions.
560 //
561 // This class attaches interface to implementation and handles reference
562 // counting, delegating most methods to ImplToFst. The CacheStore specifies the
563 // cache store (default declared in fst-decl.h).
564 template <class A, class CacheStore /* = DefaultCacheStore<A> */>
565 class ComposeFst
566     : public ImplToFst<internal::ComposeFstImplBase<A, CacheStore>> {
567  public:
568   using Arc = A;
569   using StateId = typename Arc::StateId;
570   using Weight = typename Arc::Weight;
571 
572   using Store = CacheStore;
573   using State = typename CacheStore::State;
574 
575   using Impl = internal::ComposeFstImplBase<A, CacheStore>;
576 
577   friend class ArcIterator<ComposeFst<Arc, CacheStore>>;
578   friend class StateIterator<ComposeFst<Arc, CacheStore>>;
579   template <class, class, class>
580   friend class ComposeFstMatcher;
581 
582   // Compose specifying only caching options.
583   ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
584              const CacheOptions &opts = CacheOptions())
CreateBase(fst1,fst2,opts)585       : ImplToFst<Impl>(CreateBase(fst1, fst2, opts)) {}
586 
587   // Compose specifying one shared matcher type M. Requires that the input FSTs
588   // and matcher FST types be Fst<Arc>. Recommended for best code-sharing and
589   // matcher compatiblity.
590   template <class Matcher, class Filter, class StateTuple>
ComposeFst(const Fst<Arc> & fst1,const Fst<Arc> & fst2,const ComposeFstOptions<Arc,Matcher,Filter,StateTuple> & opts)591   ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
592              const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts)
593       : ImplToFst<Impl>(CreateBase1(fst1, fst2, opts)) {}
594 
595   // Compose specifying two matcher types Matcher1 and Matcher2. Requires input
596   // FST (of the same Arc type, but o.w. arbitrary) match the corresponding
597   // matcher FST types). Recommended only for advanced use in demanding or
598   // specialized applications due to potential code bloat and matcher
599   // incompatibilities.
600   template <class Matcher1, class Matcher2, class Filter, class StateTuple>
ComposeFst(const typename Matcher1::FST & fst1,const typename Matcher2::FST & fst2,const ComposeFstImplOptions<Matcher1,Matcher2,Filter,StateTuple,CacheStore> & opts)601   ComposeFst(const typename Matcher1::FST &fst1,
602              const typename Matcher2::FST &fst2,
603              const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
604                                          CacheStore> &opts)
605       : ImplToFst<Impl>(CreateBase2(fst1, fst2, opts)) {}
606 
607   // See Fst<>::Copy() for doc.
608   ComposeFst(const ComposeFst &fst, bool safe = false)
609       : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
610                              : fst.GetSharedImpl()) {}
611 
612   // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc.
613   ComposeFst *Copy(bool safe = false) const override {
614     return new ComposeFst(*this, safe);
615   }
616 
617   inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
618 
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)619   void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
620     GetMutableImpl()->InitArcIterator(s, data);
621   }
622 
InitMatcher(MatchType match_type)623   MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
624     return GetImpl()->InitMatcher(*this, match_type);
625   }
626 
627  protected:
628   using ImplToFst<Impl>::GetImpl;
629   using ImplToFst<Impl>::GetMutableImpl;
630 
ComposeFst(std::shared_ptr<Impl> impl)631   explicit ComposeFst(std::shared_ptr<Impl> impl) : ImplToFst<Impl>(impl) {}
632 
633   // Create compose implementation specifying two matcher types.
634   template <class Matcher1, class Matcher2, class Filter, class StateTuple>
CreateBase2(const typename Matcher1::FST & fst1,const typename Matcher2::FST & fst2,const ComposeFstImplOptions<Matcher1,Matcher2,Filter,StateTuple,CacheStore> & opts)635   static std::shared_ptr<Impl> CreateBase2(
636       const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2,
637       const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
638                                   CacheStore> &opts) {
639     auto impl = std::make_shared<
640         internal::ComposeFstImpl<CacheStore, Filter, StateTuple>>(fst1, fst2,
641                                                                   opts);
642     if (!(Weight::Properties() & kCommutative) && !opts.allow_noncommute) {
643       const auto props1 = fst1.Properties(kUnweighted, true);
644       const auto props2 = fst2.Properties(kUnweighted, true);
645       if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) {
646         FSTERROR() << "ComposeFst: Weights must be a commutative semiring: "
647                    << Weight::Type();
648         impl->SetProperties(kError, kError);
649       }
650     }
651     return impl;
652   }
653 
654   // Create compose implementation specifying one matcher type; requires that
655   // input and matcher FST types be Fst<Arc>.
656   template <class Matcher, class Filter, class StateTuple>
CreateBase1(const Fst<Arc> & fst1,const Fst<Arc> & fst2,const ComposeFstOptions<Arc,Matcher,Filter,StateTuple> & opts)657   static std::shared_ptr<Impl> CreateBase1(
658       const Fst<Arc> &fst1, const Fst<Arc> &fst2,
659       const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts) {
660     ComposeFstImplOptions<Matcher, Matcher, Filter, StateTuple, CacheStore>
661         nopts(opts, opts.matcher1, opts.matcher2, opts.filter,
662               opts.state_table);
663     return CreateBase2(fst1, fst2, nopts);
664   }
665 
666   // Create compose implementation specifying no matcher type.
CreateBase(const Fst<Arc> & fst1,const Fst<Arc> & fst2,const CacheOptions & opts)667   static std::shared_ptr<Impl> CreateBase(const Fst<Arc> &fst1,
668                                           const Fst<Arc> &fst2,
669                                           const CacheOptions &opts) {
670     switch (LookAheadMatchType(fst1, fst2)) {  // Check for lookahead matchers
671       default:
672       case MATCH_NONE: {  // Default composition (no look-ahead).
673         ComposeFstOptions<Arc> nopts(opts);
674         return CreateBase1(fst1, fst2, nopts);
675       }
676       case MATCH_OUTPUT: {  // Lookahead on fst1.
677         using M = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::FstMatcher;
678         using F = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::ComposeFilter;
679         ComposeFstOptions<Arc, M, F> nopts(opts);
680         return CreateBase1(fst1, fst2, nopts);
681       }
682       case MATCH_INPUT: {  // Lookahead on fst2
683         using M = typename DefaultLookAhead<Arc, MATCH_INPUT>::FstMatcher;
684         using F = typename DefaultLookAhead<Arc, MATCH_INPUT>::ComposeFilter;
685         ComposeFstOptions<Arc, M, F> nopts(opts);
686         return CreateBase1(fst1, fst2, nopts);
687       }
688     }
689   }
690 
691  private:
692   ComposeFst &operator=(const ComposeFst &fst) = delete;
693 };
694 
695 // Specialization for ComposeFst.
696 template <class Arc, class CacheStore>
697 class StateIterator<ComposeFst<Arc, CacheStore>>
698     : public CacheStateIterator<ComposeFst<Arc, CacheStore>> {
699  public:
StateIterator(const ComposeFst<Arc,CacheStore> & fst)700   explicit StateIterator(const ComposeFst<Arc, CacheStore> &fst)
701       : CacheStateIterator<ComposeFst<Arc, CacheStore>>(fst,
702                                                         fst.GetMutableImpl()) {}
703 };
704 
705 // Specialization for ComposeFst.
706 template <class Arc, class CacheStore>
707 class ArcIterator<ComposeFst<Arc, CacheStore>>
708     : public CacheArcIterator<ComposeFst<Arc, CacheStore>> {
709  public:
710   using StateId = typename Arc::StateId;
711 
ArcIterator(const ComposeFst<Arc,CacheStore> & fst,StateId s)712   ArcIterator(const ComposeFst<Arc, CacheStore> &fst, StateId s)
713       : CacheArcIterator<ComposeFst<Arc, CacheStore>>(fst.GetMutableImpl(), s) {
714     if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
715   }
716 };
717 
718 template <class Arc, class CacheStore>
InitStateIterator(StateIteratorData<Arc> * data)719 inline void ComposeFst<Arc, CacheStore>::InitStateIterator(
720     StateIteratorData<Arc> *data) const {
721   data->base =
722       std::make_unique<StateIterator<ComposeFst<Arc, CacheStore>>>(*this);
723 }
724 
725 // Specialized matcher for ComposeFst. Supports MATCH_INPUT or MATCH_OUTPUT,
726 // iff the underlying matchers for the two FSTS being composed support
727 // MATCH_INPUT or MATCH_OUTPUT, respectively.
728 template <class CacheStore, class Filter, class StateTable>
729 class ComposeFstMatcher : public MatcherBase<typename CacheStore::Arc> {
730  public:
731   using Arc = typename CacheStore::Arc;
732   using Label = typename Arc::Label;
733   using StateId = typename Arc::StateId;
734   using Weight = typename Arc::Weight;
735 
736   using Matcher1 = typename Filter::Matcher1;
737   using Matcher2 = typename Filter::Matcher2;
738   using FilterState = typename Filter::FilterState;
739 
740   using StateTuple = typename StateTable::StateTuple;
741   using Impl = internal::ComposeFstImpl<CacheStore, Filter, StateTable>;
742 
743   // The compose FST arg must match the filter and state table types.
744   // This makes a copy of the FST.
ComposeFstMatcher(const ComposeFst<Arc,CacheStore> & fst,MatchType match_type)745   ComposeFstMatcher(const ComposeFst<Arc, CacheStore> &fst,
746                     MatchType match_type)
747       : owned_fst_(fst.Copy()),
748         fst_(*owned_fst_),
749         impl_(fst::down_cast<const Impl *>(fst_.GetImpl())),
750         s_(kNoStateId),
751         match_type_(match_type),
752         matcher1_(impl_->matcher1_->Copy()),
753         matcher2_(impl_->matcher2_->Copy()),
754         current_loop_(false),
755         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
756     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
757   }
758 
759   // The compose FST arg must match the filter and state table types.
760   // This doesn't copy the FST (although it may copy components).
ComposeFstMatcher(const ComposeFst<Arc,CacheStore> * fst,MatchType match_type)761   ComposeFstMatcher(const ComposeFst<Arc, CacheStore> *fst,
762                     MatchType match_type)
763       : fst_(*fst),
764         impl_(fst::down_cast<const Impl *>(fst_.GetImpl())),
765         s_(kNoStateId),
766         match_type_(match_type),
767         matcher1_(impl_->matcher1_->Copy()),
768         matcher2_(impl_->matcher2_->Copy()),
769         current_loop_(false),
770         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
771     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
772   }
773 
774   // This makes a copy of the FST.
775   ComposeFstMatcher(
776       const ComposeFstMatcher<CacheStore, Filter, StateTable> &matcher,
777       bool safe = false)
778       : owned_fst_(matcher.fst_.Copy(safe)),
779         fst_(*owned_fst_),
780         impl_(fst::down_cast<const Impl *>(fst_.GetImpl())),
781         s_(kNoStateId),
782         match_type_(matcher.match_type_),
783         matcher1_(matcher.matcher1_->Copy(safe)),
784         matcher2_(matcher.matcher2_->Copy(safe)),
785         current_loop_(false),
786         loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
787     if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
788   }
789 
790   ComposeFstMatcher *Copy(bool safe = false) const override {
791     return new ComposeFstMatcher(*this, safe);
792   }
793 
Type(bool test)794   MatchType Type(bool test) const override {
795     if ((matcher1_->Type(test) == MATCH_NONE) ||
796         (matcher2_->Type(test) == MATCH_NONE)) {
797       return MATCH_NONE;
798     }
799     if (((matcher1_->Type(test) == MATCH_UNKNOWN) &&
800          (matcher2_->Type(test) == MATCH_UNKNOWN)) ||
801         ((matcher1_->Type(test) == MATCH_UNKNOWN) &&
802          (matcher2_->Type(test) == match_type_)) ||
803         ((matcher1_->Type(test) == match_type_) &&
804          (matcher2_->Type(test) == MATCH_UNKNOWN))) {
805       return MATCH_UNKNOWN;
806     }
807     if ((matcher1_->Type(test) == match_type_) &&
808         (matcher2_->Type(test) == match_type_)) {
809       return match_type_;
810     }
811     return MATCH_NONE;
812   }
813 
GetFst()814   const Fst<Arc> &GetFst() const override { return fst_; }
815 
Properties(uint64 inprops)816   uint64 Properties(uint64 inprops) const override { return inprops; }
817 
SetState(StateId s)818   void SetState(StateId s) final {
819     if (s_ == s) return;
820     s_ = s;
821     const auto &tuple = impl_->state_table_->Tuple(s);
822     matcher1_->SetState(tuple.StateId1());
823     matcher2_->SetState(tuple.StateId2());
824     loop_.nextstate = s_;
825   }
826 
Find(Label label)827   bool Find(Label label) final {
828     bool found = false;
829     current_loop_ = false;
830     if (label == 0) {
831       current_loop_ = true;
832       found = true;
833     }
834     if (match_type_ == MATCH_INPUT) {
835       found = found || FindLabel(label, matcher1_.get(), matcher2_.get());
836     } else {  // match_type_ == MATCH_OUTPUT
837       found = found || FindLabel(label, matcher2_.get(), matcher1_.get());
838     }
839     return found;
840   }
841 
Done()842   bool Done() const final {
843     return !current_loop_ && matcher1_->Done() && matcher2_->Done();
844   }
845 
Value()846   const Arc &Value() const final { return current_loop_ ? loop_ : arc_; }
847 
Next()848   void Next() final {
849     if (current_loop_) {
850       current_loop_ = false;
851     } else if (match_type_ == MATCH_INPUT) {
852       FindNext(matcher1_.get(), matcher2_.get());
853     } else {  // match_type_ == MATCH_OUTPUT
854       FindNext(matcher2_.get(), matcher1_.get());
855     }
856   }
857 
Priority(StateId s)858   ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
859 
860  private:
861   // Processes a match with the filter and creates resulting arc.
MatchArc(StateId s,Arc * arc1,Arc * arc2)862   bool MatchArc(StateId s, Arc *arc1, Arc *arc2) {
863     const auto &fs = impl_->filter_->FilterArc(arc1, arc2);
864     if (fs == FilterState::NoState()) return false;
865     const StateTuple tuple(arc1->nextstate, arc2->nextstate, fs);
866     arc_.ilabel = arc1->ilabel;
867     arc_.olabel = arc2->olabel;
868     arc_.weight = Times(arc1->weight, arc2->weight);
869     arc_.nextstate = impl_->state_table_->FindState(tuple);
870     return true;
871   }
872 
873   // Finds the first match allowed by the filter.
874   template <class MatcherA, class MatcherB>
FindLabel(Label label,MatcherA * matchera,MatcherB * matcherb)875   bool FindLabel(Label label, MatcherA *matchera, MatcherB *matcherb) {
876     if (matchera->Find(label)) {
877       matcherb->Find(match_type_ == MATCH_INPUT ? matchera->Value().olabel
878                                                 : matchera->Value().ilabel);
879       return FindNext(matchera, matcherb);
880     }
881     return false;
882   }
883 
884   // Finds the next match allowed by the filter, returning true iff such a
885   // match is found.
886   template <class MatcherA, class MatcherB>
FindNext(MatcherA * matchera,MatcherB * matcherb)887   bool FindNext(MatcherA *matchera, MatcherB *matcherb) {
888     // State when entering this function:
889     // 'matchera' is pointed to a match x, y for label x, and a match for y was
890     // requested on 'matcherb'.
891     while (!matchera->Done() || !matcherb->Done()) {
892       if (matcherb->Done()) {
893         // If no more matches for y on 'matcherb', moves forward on 'matchera'
894         // until a match x, y' is found such that there is a match for y' on
895         // 'matcherb'.
896         matchera->Next();
897         while (!matchera->Done() &&
898                !matcherb->Find(match_type_ == MATCH_INPUT
899                                    ? matchera->Value().olabel
900                                    : matchera->Value().ilabel)) {
901           matchera->Next();
902         }
903       }
904       while (!matcherb->Done()) {
905         // 'matchera' is pointing to a match x, y' ('arca') and 'matcherb' is
906         // pointing to a match y', z' ('arcb'). If combining these two arcs is
907         // allowed by the filter (hence resulting in an arc x, z') return true.
908         // Position 'matcherb' on the next potential match for y' before
909         // returning.
910         auto arca = matchera->Value();
911         auto arcb = matcherb->Value();
912         // Position 'matcherb' on the next potential match for y'.
913         matcherb->Next();
914         // Returns true If combining these two arcs is allowed by the filter
915         // (hence resulting in an arc x, z'); otherwise consider next match
916         // for y' on 'matcherb'.
917         if (match_type_ == MATCH_INPUT) {
918           return MatchArc(s_, &arca, &arcb);
919         } else {
920           return MatchArc(s_, &arcb, &arca);
921         }
922       }
923     }
924     // Both 'matchera' and 'matcherb' are done, no more match to analyse.
925     return false;
926   }
927 
928   std::unique_ptr<const ComposeFst<Arc, CacheStore>> owned_fst_;
929   const ComposeFst<Arc, CacheStore> &fst_;
930   const Impl *impl_;
931   StateId s_;
932   MatchType match_type_;
933   std::unique_ptr<Matcher1> matcher1_;
934   std::unique_ptr<Matcher2> matcher2_;
935   bool current_loop_;
936   Arc loop_;
937   Arc arc_;
938 };
939 
// Useful alias when using StdArc: ComposeFst instantiated on StdArc, with the
// remaining template arguments left at their defaults.
using StdComposeFst = ComposeFst<StdArc>;
942 
// Pre-defined composition filters selectable through ComposeOptions. The
// per-value notes name the filter class that Compose() instantiates.
enum ComposeFilter {
  AUTO_FILTER,          // Only cache options are passed to ComposeFst.
  NULL_FILTER,          // NullComposeFilter.
  TRIVIAL_FILTER,       // TrivialComposeFilter.
  SEQUENCE_FILTER,      // SequenceComposeFilter.
  ALT_SEQUENCE_FILTER,  // AltSequenceComposeFilter.
  MATCH_FILTER,         // MatchComposeFilter.
  NO_MATCH_FILTER       // NoMatchComposeFilter.
};
952 
953 struct ComposeOptions {
954   bool connect;               // Connect output?
955   ComposeFilter filter_type;  // Pre-defined filter to use.
956 
957   explicit ComposeOptions(bool connect = true,
958                           ComposeFilter filter_type = AUTO_FILTER)
connectComposeOptions959       : connect(connect), filter_type(filter_type) {}
960 };
961 
962 // Computes the composition of two transducers. This version writes
963 // the composed FST into a MutableFst. If FST1 transduces string x to
964 // y with weight a and FST2 transduces y to z with weight b, then
965 // their composition transduces string x to z with weight
966 // Times(a, b).
967 //
968 // The output labels of the first transducer or the input labels of
969 // the second transducer must be sorted. The weights need to form a
970 // commutative semiring (valid for TropicalWeight and LogWeight).
971 //
972 // Complexity:
973 //
974 // Assuming the first FST is unsorted and the second is sorted:
975 //
976 //   Time: O(V1 V2 D1 (log D2 + M2)),
977 //   Space: O(V1 V2 D1 M2)
978 //
979 // where Vi = # of states, Di = maximum out-degree, and Mi is the maximum
980 // multiplicity, for the ith FST.
981 //
982 // Caveats:
983 //
984 // - Compose trims its output.
985 // - The efficiency of composition can be strongly affected by several factors:
986 //   - the choice of which transducer is sorted - prefer sorting the FST
987 //     that has the greater average out-degree.
988 //   - the amount of non-determinism
989 //   - the presence and location of epsilon transitions - avoid epsilon
990 //     transitions on the output side of the first transducer or
991 //     the input side of the second transducer or prefer placing
992 //     them later in a path since they delay matching and can
993 //     introduce non-coaccessible states and transitions.
994 template <class Arc>
995 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
996              MutableFst<Arc> *ofst,
997              const ComposeOptions &opts = ComposeOptions()) {
998   using M = Matcher<Fst<Arc>>;
999   // In each case, we cache only the last state for fastest copy.
1000   switch (opts.filter_type) {
1001     case AUTO_FILTER: {
1002       CacheOptions nopts;
1003       nopts.gc_limit = 0;
1004       *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
1005       break;
1006     }
1007     case NULL_FILTER: {
1008       ComposeFstOptions<Arc, M, NullComposeFilter<M>> copts;
1009       copts.gc_limit = 0;
1010       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1011       break;
1012     }
1013     case SEQUENCE_FILTER: {
1014       ComposeFstOptions<Arc, M, SequenceComposeFilter<M>> copts;
1015       copts.gc_limit = 0;
1016       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1017       break;
1018     }
1019     case ALT_SEQUENCE_FILTER: {
1020       ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>> copts;
1021       copts.gc_limit = 0;
1022       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1023       break;
1024     }
1025     case MATCH_FILTER: {
1026       ComposeFstOptions<Arc, M, MatchComposeFilter<M>> copts;
1027       copts.gc_limit = 0;
1028       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1029       break;
1030     }
1031     case NO_MATCH_FILTER: {
1032       ComposeFstOptions<Arc, M, NoMatchComposeFilter<M>> copts;
1033       copts.gc_limit = 0;
1034       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1035       break;
1036     }
1037     case TRIVIAL_FILTER: {
1038       ComposeFstOptions<Arc, M, TrivialComposeFilter<M>> copts;
1039       copts.gc_limit = 0;
1040       *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1041       break;
1042     }
1043   }
1044   if (opts.connect) Connect(ofst);
1045 }
1046 
1047 }  // namespace fst
1048 
1049 #endif  // FST_COMPOSE_H_
1050