1 // Copyright 2005-2020 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the 'License');
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an 'AS IS' BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // See www.openfst.org for extensive documentation on this weighted
16 // finite-state transducer library.
17 //
18 // Class to compute the composition of two FSTs.
19
20 #ifndef FST_COMPOSE_H_
21 #define FST_COMPOSE_H_
22
23 #include <algorithm>
24 #include <memory>
25 #include <utility>
26
27 #include <fst/types.h>
28 #include <fst/log.h>
29
30 #include <fst/cache.h>
31 #include <fst/compose-filter.h>
32 #include <fst/fst-decl.h> // For optional argument declarations
33 #include <fst/lookahead-filter.h>
34 #include <fst/matcher.h>
35 #include <fst/state-table.h>
36 #include <fst/test-properties.h>
37
38
39 namespace fst {
40
41 // Delayed composition options templated on the arc type, the matcher,
42 // the composition filter, and the composition state table. By
43 // default, the matchers, filter, and state table are constructed by
44 // composition. If set below, the user can instead pass in these
45 // objects; in that case, ComposeFst takes their ownership. This
46 // version controls composition implemented between generic Fst<Arc>
47 // types and a shared matcher type M for Fst<Arc>. This should be
48 // adequate for most applications, giving a reasonable tradeoff
49 // between efficiency and code sharing (but see ComposeFstImplOptions).
50 template <class Arc, class M = Matcher<Fst<Arc>>,
51 class Filter = SequenceComposeFilter<M>,
52 class StateTable =
53 GenericComposeStateTable<Arc, typename Filter::FilterState>>
54 struct ComposeFstOptions : public CacheOptions {
55 M *matcher1; // FST1 matcher.
56 M *matcher2; // FST2 matcher.
57 Filter *filter; // Composition filter.
58 StateTable *state_table; // Composition state table.
59
60 explicit ComposeFstOptions(const CacheOptions &opts = CacheOptions(),
61 M *matcher1 = nullptr, M *matcher2 = nullptr,
62 Filter *filter = nullptr,
63 StateTable *state_table = nullptr)
CacheOptionsComposeFstOptions64 : CacheOptions(opts),
65 matcher1(matcher1),
66 matcher2(matcher2),
67 filter(filter),
68 state_table(state_table) {}
69 };
70
71 // Forward declaration of ComposeFstMatcher.
72 template <class C, class F, class T>
73 class ComposeFstMatcher;
74
75 // Delayed composition options templated on the two matcher types, the
76 // composition filter, the composition state table and the cache store. By
77 // default, the matchers, filter, state table and cache store are constructed
78 // by composition. If set below, the user can instead pass in these objects; in
79 // that case, ComposeFst takes their ownership. This version controls
80 // composition implemented using arbitrary matchers (of the same arc type but
81 // otherwise arbitrary FST type). The user must ensure the matchers are
82 // compatible. These options permit the most efficient use, but shares the
83 // least code. This is for advanced use only in the most demanding or
84 // specialized applications that can benefit from it; otherwise, prefer
85 // ComposeFstOptions).
86 template <class M1, class M2, class Filter = SequenceComposeFilter<M1, M2>,
87 class StateTable = GenericComposeStateTable<
88 typename M1::Arc, typename Filter::FilterState>,
89 class CacheStore = DefaultCacheStore<typename M1::Arc>>
90 struct ComposeFstImplOptions : public CacheImplOptions<CacheStore> {
91 M1 *matcher1; // FST1 matcher (see matcher.h)....
92 M2 *matcher2; // FST2 matcher.
93 Filter *filter; // Composition filter (see compose-filter.h).
94 StateTable
95 *state_table; // Composition state table (see compose-state-table.h).
96 bool own_state_table; // ComposeFstImpl takes ownership of 'state_table'?
97 bool allow_noncommute; // Allow non-commutative weights
98
99 explicit ComposeFstImplOptions(const CacheOptions &opts,
100 M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
101 Filter *filter = nullptr,
102 StateTable *state_table = nullptr)
103 : CacheImplOptions<CacheStore>(opts),
104 matcher1(matcher1),
105 matcher2(matcher2),
106 filter(filter),
107 state_table(state_table),
108 own_state_table(true),
109 allow_noncommute(false) {}
110
111 explicit ComposeFstImplOptions(const CacheImplOptions<CacheStore> &opts,
112 M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
113 Filter *filter = nullptr,
114 StateTable *state_table = nullptr)
115 : CacheImplOptions<CacheStore>(opts),
116 matcher1(matcher1),
117 matcher2(matcher2),
118 filter(filter),
119 state_table(state_table),
120 own_state_table(true),
121 allow_noncommute(false) {}
122
ComposeFstImplOptionsComposeFstImplOptions123 ComposeFstImplOptions()
124 : matcher1(nullptr),
125 matcher2(nullptr),
126 filter(nullptr),
127 state_table(nullptr),
128 own_state_table(true),
129 allow_noncommute(false) {}
130 };
131
132 namespace internal {
133
134 // Implementation of delayed composition. This base class is common to the
135 // variants with different matchers, composition filters and state tables.
136 template <class Arc, class CacheStore = DefaultCacheStore<Arc>,
137 class F = ComposeFst<Arc, CacheStore>>
138 class ComposeFstImplBase
139 : public CacheBaseImpl<typename CacheStore::State, CacheStore> {
140 public:
141 using FST = F;
142 using Label = typename Arc::Label;
143 using StateId = typename Arc::StateId;
144 using Weight = typename Arc::Weight;
145
146 using State = typename CacheStore::State;
147 using CacheImpl = CacheBaseImpl<State, CacheStore>;
148
149 using FstImpl<Arc>::SetType;
150 using FstImpl<Arc>::SetProperties;
151 using FstImpl<Arc>::Properties;
152 using FstImpl<Arc>::SetInputSymbols;
153 using FstImpl<Arc>::SetOutputSymbols;
154
155 using CacheImpl::HasArcs;
156 using CacheImpl::HasFinal;
157 using CacheImpl::HasStart;
158 using CacheImpl::SetFinal;
159 using CacheImpl::SetStart;
160
ComposeFstImplBase(const CacheImplOptions<CacheStore> & opts)161 explicit ComposeFstImplBase(const CacheImplOptions<CacheStore> &opts)
162 : CacheImpl(opts) {}
163
ComposeFstImplBase(const CacheOptions & opts)164 explicit ComposeFstImplBase(const CacheOptions &opts) : CacheImpl(opts) {}
165
ComposeFstImplBase(const ComposeFstImplBase & impl)166 ComposeFstImplBase(const ComposeFstImplBase &impl) : CacheImpl(impl, true) {
167 SetType(impl.Type());
168 SetProperties(impl.Properties(), kCopyProperties);
169 SetInputSymbols(impl.InputSymbols());
170 SetOutputSymbols(impl.OutputSymbols());
171 }
172
173 virtual ComposeFstImplBase *Copy() const = 0;
174
~ComposeFstImplBase()175 ~ComposeFstImplBase() override {}
176
Start()177 StateId Start() {
178 if (!HasStart()) {
179 const auto start = ComputeStart();
180 if (start != kNoStateId) SetStart(start);
181 }
182 return CacheImpl::Start();
183 }
184
Final(StateId s)185 Weight Final(StateId s) {
186 if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
187 return CacheImpl::Final(s);
188 }
189
190 virtual void Expand(StateId s) = 0;
191
NumArcs(StateId s)192 size_t NumArcs(StateId s) {
193 if (!HasArcs(s)) Expand(s);
194 return CacheImpl::NumArcs(s);
195 }
196
NumInputEpsilons(StateId s)197 size_t NumInputEpsilons(StateId s) {
198 if (!HasArcs(s)) Expand(s);
199 return CacheImpl::NumInputEpsilons(s);
200 }
201
NumOutputEpsilons(StateId s)202 size_t NumOutputEpsilons(StateId s) {
203 if (!HasArcs(s)) Expand(s);
204 return CacheImpl::NumOutputEpsilons(s);
205 }
206
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)207 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) {
208 if (!HasArcs(s)) Expand(s);
209 CacheImpl::InitArcIterator(s, data);
210 }
211
InitMatcher(const F & fst,MatchType match_type)212 virtual MatcherBase<Arc> *InitMatcher(const F &fst,
213 MatchType match_type) const {
214 // Use the default matcher if no override is provided.
215 return nullptr;
216 }
217
218 protected:
219 virtual StateId ComputeStart() = 0;
220 virtual Weight ComputeFinal(StateId s) = 0;
221 };
222
223 // Implementation of delayed composition templated on the matchers (see
224 // matcher.h), composition filter (see compose-filter.h) and the composition
225 // state table (see compose-state-table.h).
226 template <class CacheStore, class Filter, class StateTable>
227 class ComposeFstImpl
228 : public ComposeFstImplBase<typename CacheStore::Arc, CacheStore> {
229 public:
230 using Matcher1 = typename Filter::Matcher1;
231 using Matcher2 = typename Filter::Matcher2;
232
233 using FST1 = typename Matcher1::FST;
234 using FST2 = typename Matcher2::FST;
235
236 using Arc = typename CacheStore::Arc;
237 using Label = typename Arc::Label;
238 using StateId = typename Arc::StateId;
239 using Weight = typename Arc::Weight;
240
241 using FilterState = typename Filter::FilterState;
242 using State = typename CacheStore::State;
243
244 using CacheImpl = CacheBaseImpl<State, CacheStore>;
245
246 using StateTuple = typename StateTable::StateTuple;
247
248 friend class ComposeFstMatcher<CacheStore, Filter, StateTable>;
249
250 using FstImpl<Arc>::SetInputSymbols;
251 using FstImpl<Arc>::SetOutputSymbols;
252 using FstImpl<Arc>::SetType;
253 using FstImpl<Arc>::SetProperties;
254
255 template <class M1, class M2>
256 ComposeFstImpl(const FST1 &fst1, const FST2 &fst2,
257 const ComposeFstImplOptions<M1, M2, Filter, StateTable,
258 CacheStore> &opts);
259
ComposeFstImpl(const ComposeFstImpl & impl)260 ComposeFstImpl(const ComposeFstImpl &impl)
261 : ComposeFstImplBase<Arc, CacheStore>(impl),
262 filter_(new Filter(*impl.filter_, true)),
263 matcher1_(filter_->GetMatcher1()),
264 matcher2_(filter_->GetMatcher2()),
265 fst1_(matcher1_->GetFst()),
266 fst2_(matcher2_->GetFst()),
267 state_table_(new StateTable(*impl.state_table_)),
268 own_state_table_(true),
269 match_type_(impl.match_type_) {}
270
~ComposeFstImpl()271 ~ComposeFstImpl() override {
272 if (own_state_table_) delete state_table_;
273 }
274
Copy()275 ComposeFstImpl *Copy() const override { return new ComposeFstImpl(*this); }
276
Properties()277 uint64 Properties() const override { return Properties(kFstProperties); }
278
279 // Sets error if found, and returns other FST impl properties.
Properties(uint64 mask)280 uint64 Properties(uint64 mask) const override {
281 if ((mask & kError) &&
282 (fst1_.Properties(kError, false) || fst2_.Properties(kError, false) ||
283 (matcher1_->Properties(0) & kError) ||
284 (matcher2_->Properties(0) & kError) |
285 (filter_->Properties(0) & kError) ||
286 state_table_->Error())) {
287 SetProperties(kError, kError);
288 }
289 return FstImpl<Arc>::Properties(mask);
290 }
291
292 // Arranges it so that the first arg to OrderedExpand is the Fst
293 // that will be matched on.
Expand(StateId s)294 void Expand(StateId s) override {
295 const auto &tuple = state_table_->Tuple(s);
296 const auto s1 = tuple.StateId1();
297 const auto s2 = tuple.StateId2();
298 filter_->SetState(s1, s2, tuple.GetFilterState());
299 if (MatchInput(s1, s2)) {
300 OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true);
301 } else {
302 OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false);
303 }
304 }
305
GetFst1()306 const FST1 &GetFst1() const { return fst1_; }
307
GetFst2()308 const FST2 &GetFst2() const { return fst2_; }
309
GetMatcher1()310 const Matcher1 *GetMatcher1() const { return matcher1_; }
311
GetMatcher1()312 Matcher1 *GetMatcher1() { return matcher1_; }
313
GetMatcher2()314 const Matcher2 *GetMatcher2() const { return matcher2_; }
315
GetMatcher2()316 Matcher2 *GetMatcher2() { return matcher2_; }
317
GetFilter()318 const Filter *GetFilter() const { return filter_.get(); }
319
GetFilter()320 Filter *GetFilter() { return filter_.get(); }
321
GetStateTable()322 const StateTable *GetStateTable() const { return state_table_; }
323
GetStateTable()324 StateTable *GetStateTable() { return state_table_; }
325
InitMatcher(const ComposeFst<Arc,CacheStore> & fst,MatchType match_type)326 MatcherBase<Arc> *InitMatcher(const ComposeFst<Arc, CacheStore> &fst,
327 MatchType match_type) const override {
328 const auto test_props = match_type == MATCH_INPUT
329 ? kFstProperties & ~kILabelInvariantProperties
330 : kFstProperties & ~kOLabelInvariantProperties;
331 // If both matchers support 'match_type' and we have a guarantee that a
332 // call to 'filter_->FilterArc(arc1, arc2)' will not modify the ilabel of
333 // arc1 when MATCH_INPUT or the olabel or arc2 when MATCH_OUTPUT, then
334 // ComposeFstMatcher can be used.
335 if ((matcher1_->Type(false) == match_type) &&
336 (matcher2_->Type(false) == match_type) &&
337 (filter_->Properties(test_props) == test_props)) {
338 return new ComposeFstMatcher<CacheStore, Filter, StateTable>(&fst,
339 match_type);
340 }
341 return nullptr;
342 }
343
344 private:
345 // This does that actual matching of labels in the composition. The
346 // arguments are ordered so matching is called on state 'sa' of
347 // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg
348 // determines whether the input or output label of arcs at 'sb' is
349 // the one to match on.
350 template <class FST, class Matcher>
OrderedExpand(StateId s,const Fst<Arc> &,StateId sa,const FST & fstb,StateId sb,Matcher * matchera,bool match_input)351 void OrderedExpand(StateId s, const Fst<Arc> &, StateId sa, const FST &fstb,
352 StateId sb, Matcher *matchera, bool match_input) {
353 matchera->SetState(sa);
354 // First processes non-consuming symbols (e.g., epsilons) on FSTA.
355 const Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0,
356 Weight::One(), sb);
357 MatchArc(s, matchera, loop, match_input);
358 // Then processes matches on FSTB.
359 for (ArcIterator<FST> iterb(fstb, sb); !iterb.Done(); iterb.Next()) {
360 MatchArc(s, matchera, iterb.Value(), match_input);
361 }
362 CacheImpl::SetArcs(s);
363 }
364
365 // Matches a single transition from 'fstb' against 'fata' at 's'.
366 template <class Matcher>
MatchArc(StateId s,Matcher * matchera,const Arc & arc,bool match_input)367 void MatchArc(StateId s, Matcher *matchera, const Arc &arc,
368 bool match_input) {
369 if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) {
370 for (; !matchera->Done(); matchera->Next()) {
371 auto arca = matchera->Value();
372 auto arcb = arc;
373 if (match_input) {
374 const auto &fs = filter_->FilterArc(&arcb, &arca);
375 if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
376 } else {
377 const auto &fs = filter_->FilterArc(&arca, &arcb);
378 if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
379 }
380 }
381 }
382 }
383
384 // Add a matching transition at 's'.
AddArc(StateId s,const Arc & arc1,const Arc & arc2,const FilterState & f)385 void AddArc(StateId s, const Arc &arc1, const Arc &arc2,
386 const FilterState &f) {
387 const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
388 CacheImpl::EmplaceArc(s, arc1.ilabel, arc2.olabel,
389 Times(arc1.weight, arc2.weight),
390 state_table_->FindState(tuple));
391 }
392
ComputeStart()393 StateId ComputeStart() override {
394 const auto s1 = fst1_.Start();
395 if (s1 == kNoStateId) return kNoStateId;
396 const auto s2 = fst2_.Start();
397 if (s2 == kNoStateId) return kNoStateId;
398 const auto &fs = filter_->Start();
399 const StateTuple tuple(s1, s2, fs);
400 return state_table_->FindState(tuple);
401 }
402
ComputeFinal(StateId s)403 Weight ComputeFinal(StateId s) override {
404 const auto &tuple = state_table_->Tuple(s);
405 const auto s1 = tuple.StateId1();
406 auto final1 = matcher1_->Final(s1);
407 if (final1 == Weight::Zero()) return final1;
408 const auto s2 = tuple.StateId2();
409 auto final2 = matcher2_->Final(s2);
410 if (final2 == Weight::Zero()) return final2;
411 filter_->SetState(s1, s2, tuple.GetFilterState());
412 filter_->FilterFinal(&final1, &final2);
413 return Times(final1, final2);
414 }
415
416 // Determines which side to match on per composition state.
MatchInput(StateId s1,StateId s2)417 bool MatchInput(StateId s1, StateId s2) {
418 switch (match_type_) {
419 case MATCH_INPUT:
420 return true;
421 case MATCH_OUTPUT:
422 return false;
423 default: // MATCH_BOTH
424 const auto priority1 = matcher1_->Priority(s1);
425 const auto priority2 = matcher2_->Priority(s2);
426 if (priority1 == kRequirePriority && priority2 == kRequirePriority) {
427 FSTERROR() << "ComposeFst: Both sides can't require match";
428 SetProperties(kError, kError);
429 return true;
430 }
431 if (priority1 == kRequirePriority) return false;
432 if (priority2 == kRequirePriority) {
433 return true;
434 }
435 return priority1 <= priority2;
436 }
437 }
438
439 // Identifies and verifies the capabilities of the matcher to be used for
440 // composition.
441 void SetMatchType();
442
443 std::unique_ptr<Filter> filter_;
444 Matcher1 *matcher1_; // Borrowed reference.
445 Matcher2 *matcher2_; // Borrowed reference.
446 const FST1 &fst1_;
447 const FST2 &fst2_;
448 StateTable *state_table_;
449 bool own_state_table_;
450
451 MatchType match_type_;
452 };
453
454 template <class CacheStore, class Filter, class StateTable>
455 template <class M1, class M2>
ComposeFstImpl(const FST1 & fst1,const FST2 & fst2,const ComposeFstImplOptions<M1,M2,Filter,StateTable,CacheStore> & opts)456 ComposeFstImpl<CacheStore, Filter, StateTable>::ComposeFstImpl(
457 const FST1 &fst1, const FST2 &fst2,
458 const ComposeFstImplOptions<M1, M2, Filter, StateTable, CacheStore> &opts)
459 : ComposeFstImplBase<Arc, CacheStore>(opts),
460 filter_(opts.filter
461 ? opts.filter
462 : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
463 matcher1_(filter_->GetMatcher1()),
464 matcher2_(filter_->GetMatcher2()),
465 fst1_(matcher1_->GetFst()),
466 fst2_(matcher2_->GetFst()),
467 state_table_(opts.state_table ? opts.state_table
468 : new StateTable(fst1_, fst2_)),
469 own_state_table_(opts.state_table ? opts.own_state_table : true) {
470 SetType("compose");
471 if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
472 FSTERROR() << "ComposeFst: Output symbol table of 1st argument "
473 << "does not match input symbol table of 2nd argument";
474 SetProperties(kError, kError);
475 }
476 SetInputSymbols(fst1_.InputSymbols());
477 SetOutputSymbols(fst2_.OutputSymbols());
478 SetMatchType();
479 VLOG(2) << "ComposeFstImpl: Match type: " << match_type_;
480 if (match_type_ == MATCH_NONE) SetProperties(kError, kError);
481 const auto fprops1 = fst1.Properties(kFstProperties, false);
482 const auto fprops2 = fst2.Properties(kFstProperties, false);
483 const auto mprops1 = matcher1_->Properties(fprops1);
484 const auto mprops2 = matcher2_->Properties(fprops2);
485 const auto cprops = ComposeProperties(mprops1, mprops2);
486 SetProperties(filter_->Properties(cprops), kCopyProperties);
487 if (state_table_->Error()) SetProperties(kError, kError);
488 }
489
490 template <class CacheStore, class Filter, class StateTable>
SetMatchType()491 void ComposeFstImpl<CacheStore, Filter, StateTable>::SetMatchType() {
492 // Ensures any required matching is possible and known.
493 if ((matcher1_->Flags() & kRequireMatch) &&
494 matcher1_->Type(true) != MATCH_OUTPUT) {
495 FSTERROR() << "ComposeFst: 1st argument cannot perform required matching "
496 << "(sort?).";
497 match_type_ = MATCH_NONE;
498 return;
499 }
500 if ((matcher2_->Flags() & kRequireMatch) &&
501 matcher2_->Type(true) != MATCH_INPUT) {
502 FSTERROR() << "ComposeFst: 2nd argument cannot perform required matching "
503 << "(sort?).";
504 match_type_ = MATCH_NONE;
505 return;
506 }
507 // Finds which sides to match on (favoring minimal testing of capabilities).
508 const auto type1 = matcher1_->Type(false);
509 const auto type2 = matcher2_->Type(false);
510 if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) {
511 match_type_ = MATCH_BOTH;
512 } else if (type1 == MATCH_OUTPUT) {
513 match_type_ = MATCH_OUTPUT;
514 } else if (type2 == MATCH_INPUT) {
515 match_type_ = MATCH_INPUT;
516 } else if (matcher1_->Type(true) == MATCH_OUTPUT) {
517 match_type_ = MATCH_OUTPUT;
518 } else if (matcher2_->Type(true) == MATCH_INPUT) {
519 match_type_ = MATCH_INPUT;
520 } else {
521 FSTERROR() << "ComposeFst: 1st argument cannot match on output labels "
522 << "and 2nd argument cannot match on input labels (sort?).";
523 match_type_ = MATCH_NONE;
524 }
525 }
526
527 } // namespace internal
528
529 // Computes the composition of two transducers. This version is a delayed FST.
530 // If FST1 transduces string x to y with weight a and FST2 transduces y to z
531 // with weight b, then their composition transduces string x to z with weight
// Times(a, b).
533 //
534 // The output labels of the first transducer or the input labels of the second
535 // transducer must be sorted (with the default matcher). The weights need to
536 // form a commutative semiring (valid for TropicalWeight and LogWeight).
537 //
538 // Complexity:
539 //
540 // Assuming the first FST is unsorted and the second is sorted,
541 //
542 // Time: O(v1 v2 d1 (log d2 + m2)),
543 // Space: O(v1 v2)
544 //
545 // where vi = # of states visited, di = maximum out-degree, and mi the
546 // maximum multiplicity of the states visited, for the ith FST. Constant time
547 // and space to visit an input state or arc is assumed and exclusive of caching.
548 //
549 // Caveats:
550 // - ComposeFst does not trim its output (since it is a delayed operation).
551 // - The efficiency of composition can be strongly affected by several factors:
552 // - the choice of which transducer is sorted - prefer sorting the FST
553 // that has the greater average out-degree.
554 // - the amount of non-determinism
555 // - the presence and location of epsilon transitions - avoid epsilon
556 // transitions on the output side of the first transducer or
557 // the input side of the second transducer or prefer placing
558 // them later in a path since they delay matching and can
559 // introduce non-coaccessible states and transitions.
560 //
561 // This class attaches interface to implementation and handles reference
562 // counting, delegating most methods to ImplToFst. The CacheStore specifies the
563 // cache store (default declared in fst-decl.h).
564 template <class A, class CacheStore /* = DefaultCacheStore<A> */>
565 class ComposeFst
566 : public ImplToFst<internal::ComposeFstImplBase<A, CacheStore>> {
567 public:
568 using Arc = A;
569 using StateId = typename Arc::StateId;
570 using Weight = typename Arc::Weight;
571
572 using Store = CacheStore;
573 using State = typename CacheStore::State;
574
575 using Impl = internal::ComposeFstImplBase<A, CacheStore>;
576
577 friend class ArcIterator<ComposeFst<Arc, CacheStore>>;
578 friend class StateIterator<ComposeFst<Arc, CacheStore>>;
579 template <class, class, class>
580 friend class ComposeFstMatcher;
581
582 // Compose specifying only caching options.
583 ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
584 const CacheOptions &opts = CacheOptions())
CreateBase(fst1,fst2,opts)585 : ImplToFst<Impl>(CreateBase(fst1, fst2, opts)) {}
586
587 // Compose specifying one shared matcher type M. Requires that the input FSTs
588 // and matcher FST types be Fst<Arc>. Recommended for best code-sharing and
589 // matcher compatiblity.
590 template <class Matcher, class Filter, class StateTuple>
ComposeFst(const Fst<Arc> & fst1,const Fst<Arc> & fst2,const ComposeFstOptions<Arc,Matcher,Filter,StateTuple> & opts)591 ComposeFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
592 const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts)
593 : ImplToFst<Impl>(CreateBase1(fst1, fst2, opts)) {}
594
595 // Compose specifying two matcher types Matcher1 and Matcher2. Requires input
596 // FST (of the same Arc type, but o.w. arbitrary) match the corresponding
597 // matcher FST types). Recommended only for advanced use in demanding or
598 // specialized applications due to potential code bloat and matcher
599 // incompatibilities.
600 template <class Matcher1, class Matcher2, class Filter, class StateTuple>
ComposeFst(const typename Matcher1::FST & fst1,const typename Matcher2::FST & fst2,const ComposeFstImplOptions<Matcher1,Matcher2,Filter,StateTuple,CacheStore> & opts)601 ComposeFst(const typename Matcher1::FST &fst1,
602 const typename Matcher2::FST &fst2,
603 const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
604 CacheStore> &opts)
605 : ImplToFst<Impl>(CreateBase2(fst1, fst2, opts)) {}
606
607 // See Fst<>::Copy() for doc.
608 ComposeFst(const ComposeFst &fst, bool safe = false)
609 : ImplToFst<Impl>(safe ? std::shared_ptr<Impl>(fst.GetImpl()->Copy())
610 : fst.GetSharedImpl()) {}
611
612 // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc.
613 ComposeFst *Copy(bool safe = false) const override {
614 return new ComposeFst(*this, safe);
615 }
616
617 inline void InitStateIterator(StateIteratorData<Arc> *data) const override;
618
InitArcIterator(StateId s,ArcIteratorData<Arc> * data)619 void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const override {
620 GetMutableImpl()->InitArcIterator(s, data);
621 }
622
InitMatcher(MatchType match_type)623 MatcherBase<Arc> *InitMatcher(MatchType match_type) const override {
624 return GetImpl()->InitMatcher(*this, match_type);
625 }
626
627 protected:
628 using ImplToFst<Impl>::GetImpl;
629 using ImplToFst<Impl>::GetMutableImpl;
630
ComposeFst(std::shared_ptr<Impl> impl)631 explicit ComposeFst(std::shared_ptr<Impl> impl) : ImplToFst<Impl>(impl) {}
632
633 // Create compose implementation specifying two matcher types.
634 template <class Matcher1, class Matcher2, class Filter, class StateTuple>
CreateBase2(const typename Matcher1::FST & fst1,const typename Matcher2::FST & fst2,const ComposeFstImplOptions<Matcher1,Matcher2,Filter,StateTuple,CacheStore> & opts)635 static std::shared_ptr<Impl> CreateBase2(
636 const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2,
637 const ComposeFstImplOptions<Matcher1, Matcher2, Filter, StateTuple,
638 CacheStore> &opts) {
639 auto impl = std::make_shared<
640 internal::ComposeFstImpl<CacheStore, Filter, StateTuple>>(fst1, fst2,
641 opts);
642 if (!(Weight::Properties() & kCommutative) && !opts.allow_noncommute) {
643 const auto props1 = fst1.Properties(kUnweighted, true);
644 const auto props2 = fst2.Properties(kUnweighted, true);
645 if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) {
646 FSTERROR() << "ComposeFst: Weights must be a commutative semiring: "
647 << Weight::Type();
648 impl->SetProperties(kError, kError);
649 }
650 }
651 return impl;
652 }
653
654 // Create compose implementation specifying one matcher type; requires that
655 // input and matcher FST types be Fst<Arc>.
656 template <class Matcher, class Filter, class StateTuple>
CreateBase1(const Fst<Arc> & fst1,const Fst<Arc> & fst2,const ComposeFstOptions<Arc,Matcher,Filter,StateTuple> & opts)657 static std::shared_ptr<Impl> CreateBase1(
658 const Fst<Arc> &fst1, const Fst<Arc> &fst2,
659 const ComposeFstOptions<Arc, Matcher, Filter, StateTuple> &opts) {
660 ComposeFstImplOptions<Matcher, Matcher, Filter, StateTuple, CacheStore>
661 nopts(opts, opts.matcher1, opts.matcher2, opts.filter,
662 opts.state_table);
663 return CreateBase2(fst1, fst2, nopts);
664 }
665
666 // Create compose implementation specifying no matcher type.
CreateBase(const Fst<Arc> & fst1,const Fst<Arc> & fst2,const CacheOptions & opts)667 static std::shared_ptr<Impl> CreateBase(const Fst<Arc> &fst1,
668 const Fst<Arc> &fst2,
669 const CacheOptions &opts) {
670 switch (LookAheadMatchType(fst1, fst2)) { // Check for lookahead matchers
671 default:
672 case MATCH_NONE: { // Default composition (no look-ahead).
673 ComposeFstOptions<Arc> nopts(opts);
674 return CreateBase1(fst1, fst2, nopts);
675 }
676 case MATCH_OUTPUT: { // Lookahead on fst1.
677 using M = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::FstMatcher;
678 using F = typename DefaultLookAhead<Arc, MATCH_OUTPUT>::ComposeFilter;
679 ComposeFstOptions<Arc, M, F> nopts(opts);
680 return CreateBase1(fst1, fst2, nopts);
681 }
682 case MATCH_INPUT: { // Lookahead on fst2
683 using M = typename DefaultLookAhead<Arc, MATCH_INPUT>::FstMatcher;
684 using F = typename DefaultLookAhead<Arc, MATCH_INPUT>::ComposeFilter;
685 ComposeFstOptions<Arc, M, F> nopts(opts);
686 return CreateBase1(fst1, fst2, nopts);
687 }
688 }
689 }
690
691 private:
692 ComposeFst &operator=(const ComposeFst &fst) = delete;
693 };
694
695 // Specialization for ComposeFst.
696 template <class Arc, class CacheStore>
697 class StateIterator<ComposeFst<Arc, CacheStore>>
698 : public CacheStateIterator<ComposeFst<Arc, CacheStore>> {
699 public:
StateIterator(const ComposeFst<Arc,CacheStore> & fst)700 explicit StateIterator(const ComposeFst<Arc, CacheStore> &fst)
701 : CacheStateIterator<ComposeFst<Arc, CacheStore>>(fst,
702 fst.GetMutableImpl()) {}
703 };
704
705 // Specialization for ComposeFst.
706 template <class Arc, class CacheStore>
707 class ArcIterator<ComposeFst<Arc, CacheStore>>
708 : public CacheArcIterator<ComposeFst<Arc, CacheStore>> {
709 public:
710 using StateId = typename Arc::StateId;
711
ArcIterator(const ComposeFst<Arc,CacheStore> & fst,StateId s)712 ArcIterator(const ComposeFst<Arc, CacheStore> &fst, StateId s)
713 : CacheArcIterator<ComposeFst<Arc, CacheStore>>(fst.GetMutableImpl(), s) {
714 if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
715 }
716 };
717
718 template <class Arc, class CacheStore>
InitStateIterator(StateIteratorData<Arc> * data)719 inline void ComposeFst<Arc, CacheStore>::InitStateIterator(
720 StateIteratorData<Arc> *data) const {
721 data->base =
722 std::make_unique<StateIterator<ComposeFst<Arc, CacheStore>>>(*this);
723 }
724
725 // Specialized matcher for ComposeFst. Supports MATCH_INPUT or MATCH_OUTPUT,
726 // iff the underlying matchers for the two FSTS being composed support
727 // MATCH_INPUT or MATCH_OUTPUT, respectively.
728 template <class CacheStore, class Filter, class StateTable>
729 class ComposeFstMatcher : public MatcherBase<typename CacheStore::Arc> {
730 public:
731 using Arc = typename CacheStore::Arc;
732 using Label = typename Arc::Label;
733 using StateId = typename Arc::StateId;
734 using Weight = typename Arc::Weight;
735
736 using Matcher1 = typename Filter::Matcher1;
737 using Matcher2 = typename Filter::Matcher2;
738 using FilterState = typename Filter::FilterState;
739
740 using StateTuple = typename StateTable::StateTuple;
741 using Impl = internal::ComposeFstImpl<CacheStore, Filter, StateTable>;
742
743 // The compose FST arg must match the filter and state table types.
744 // This makes a copy of the FST.
ComposeFstMatcher(const ComposeFst<Arc,CacheStore> & fst,MatchType match_type)745 ComposeFstMatcher(const ComposeFst<Arc, CacheStore> &fst,
746 MatchType match_type)
747 : owned_fst_(fst.Copy()),
748 fst_(*owned_fst_),
749 impl_(fst::down_cast<const Impl *>(fst_.GetImpl())),
750 s_(kNoStateId),
751 match_type_(match_type),
752 matcher1_(impl_->matcher1_->Copy()),
753 matcher2_(impl_->matcher2_->Copy()),
754 current_loop_(false),
755 loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
756 if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
757 }
758
759 // The compose FST arg must match the filter and state table types.
760 // This doesn't copy the FST (although it may copy components).
ComposeFstMatcher(const ComposeFst<Arc,CacheStore> * fst,MatchType match_type)761 ComposeFstMatcher(const ComposeFst<Arc, CacheStore> *fst,
762 MatchType match_type)
763 : fst_(*fst),
764 impl_(fst::down_cast<const Impl *>(fst_.GetImpl())),
765 s_(kNoStateId),
766 match_type_(match_type),
767 matcher1_(impl_->matcher1_->Copy()),
768 matcher2_(impl_->matcher2_->Copy()),
769 current_loop_(false),
770 loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
771 if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
772 }
773
774 // This makes a copy of the FST.
775 ComposeFstMatcher(
776 const ComposeFstMatcher<CacheStore, Filter, StateTable> &matcher,
777 bool safe = false)
778 : owned_fst_(matcher.fst_.Copy(safe)),
779 fst_(*owned_fst_),
780 impl_(fst::down_cast<const Impl *>(fst_.GetImpl())),
781 s_(kNoStateId),
782 match_type_(matcher.match_type_),
783 matcher1_(matcher.matcher1_->Copy(safe)),
784 matcher2_(matcher.matcher2_->Copy(safe)),
785 current_loop_(false),
786 loop_(kNoLabel, 0, Weight::One(), kNoStateId) {
787 if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel);
788 }
789
790 ComposeFstMatcher *Copy(bool safe = false) const override {
791 return new ComposeFstMatcher(*this, safe);
792 }
793
Type(bool test)794 MatchType Type(bool test) const override {
795 if ((matcher1_->Type(test) == MATCH_NONE) ||
796 (matcher2_->Type(test) == MATCH_NONE)) {
797 return MATCH_NONE;
798 }
799 if (((matcher1_->Type(test) == MATCH_UNKNOWN) &&
800 (matcher2_->Type(test) == MATCH_UNKNOWN)) ||
801 ((matcher1_->Type(test) == MATCH_UNKNOWN) &&
802 (matcher2_->Type(test) == match_type_)) ||
803 ((matcher1_->Type(test) == match_type_) &&
804 (matcher2_->Type(test) == MATCH_UNKNOWN))) {
805 return MATCH_UNKNOWN;
806 }
807 if ((matcher1_->Type(test) == match_type_) &&
808 (matcher2_->Type(test) == match_type_)) {
809 return match_type_;
810 }
811 return MATCH_NONE;
812 }
813
GetFst()814 const Fst<Arc> &GetFst() const override { return fst_; }
815
Properties(uint64 inprops)816 uint64 Properties(uint64 inprops) const override { return inprops; }
817
SetState(StateId s)818 void SetState(StateId s) final {
819 if (s_ == s) return;
820 s_ = s;
821 const auto &tuple = impl_->state_table_->Tuple(s);
822 matcher1_->SetState(tuple.StateId1());
823 matcher2_->SetState(tuple.StateId2());
824 loop_.nextstate = s_;
825 }
826
Find(Label label)827 bool Find(Label label) final {
828 bool found = false;
829 current_loop_ = false;
830 if (label == 0) {
831 current_loop_ = true;
832 found = true;
833 }
834 if (match_type_ == MATCH_INPUT) {
835 found = found || FindLabel(label, matcher1_.get(), matcher2_.get());
836 } else { // match_type_ == MATCH_OUTPUT
837 found = found || FindLabel(label, matcher2_.get(), matcher1_.get());
838 }
839 return found;
840 }
841
Done()842 bool Done() const final {
843 return !current_loop_ && matcher1_->Done() && matcher2_->Done();
844 }
845
Value()846 const Arc &Value() const final { return current_loop_ ? loop_ : arc_; }
847
Next()848 void Next() final {
849 if (current_loop_) {
850 current_loop_ = false;
851 } else if (match_type_ == MATCH_INPUT) {
852 FindNext(matcher1_.get(), matcher2_.get());
853 } else { // match_type_ == MATCH_OUTPUT
854 FindNext(matcher2_.get(), matcher1_.get());
855 }
856 }
857
Priority(StateId s)858 ssize_t Priority(StateId s) final { return fst_.NumArcs(s); }
859
860 private:
861 // Processes a match with the filter and creates resulting arc.
MatchArc(StateId s,Arc * arc1,Arc * arc2)862 bool MatchArc(StateId s, Arc *arc1, Arc *arc2) {
863 const auto &fs = impl_->filter_->FilterArc(arc1, arc2);
864 if (fs == FilterState::NoState()) return false;
865 const StateTuple tuple(arc1->nextstate, arc2->nextstate, fs);
866 arc_.ilabel = arc1->ilabel;
867 arc_.olabel = arc2->olabel;
868 arc_.weight = Times(arc1->weight, arc2->weight);
869 arc_.nextstate = impl_->state_table_->FindState(tuple);
870 return true;
871 }
872
873 // Finds the first match allowed by the filter.
874 template <class MatcherA, class MatcherB>
FindLabel(Label label,MatcherA * matchera,MatcherB * matcherb)875 bool FindLabel(Label label, MatcherA *matchera, MatcherB *matcherb) {
876 if (matchera->Find(label)) {
877 matcherb->Find(match_type_ == MATCH_INPUT ? matchera->Value().olabel
878 : matchera->Value().ilabel);
879 return FindNext(matchera, matcherb);
880 }
881 return false;
882 }
883
884 // Finds the next match allowed by the filter, returning true iff such a
885 // match is found.
886 template <class MatcherA, class MatcherB>
FindNext(MatcherA * matchera,MatcherB * matcherb)887 bool FindNext(MatcherA *matchera, MatcherB *matcherb) {
888 // State when entering this function:
889 // 'matchera' is pointed to a match x, y for label x, and a match for y was
890 // requested on 'matcherb'.
891 while (!matchera->Done() || !matcherb->Done()) {
892 if (matcherb->Done()) {
893 // If no more matches for y on 'matcherb', moves forward on 'matchera'
894 // until a match x, y' is found such that there is a match for y' on
895 // 'matcherb'.
896 matchera->Next();
897 while (!matchera->Done() &&
898 !matcherb->Find(match_type_ == MATCH_INPUT
899 ? matchera->Value().olabel
900 : matchera->Value().ilabel)) {
901 matchera->Next();
902 }
903 }
904 while (!matcherb->Done()) {
905 // 'matchera' is pointing to a match x, y' ('arca') and 'matcherb' is
906 // pointing to a match y', z' ('arcb'). If combining these two arcs is
907 // allowed by the filter (hence resulting in an arc x, z') return true.
908 // Position 'matcherb' on the next potential match for y' before
909 // returning.
910 auto arca = matchera->Value();
911 auto arcb = matcherb->Value();
912 // Position 'matcherb' on the next potential match for y'.
913 matcherb->Next();
914 // Returns true If combining these two arcs is allowed by the filter
915 // (hence resulting in an arc x, z'); otherwise consider next match
916 // for y' on 'matcherb'.
917 if (match_type_ == MATCH_INPUT) {
918 return MatchArc(s_, &arca, &arcb);
919 } else {
920 return MatchArc(s_, &arcb, &arca);
921 }
922 }
923 }
924 // Both 'matchera' and 'matcherb' are done, no more match to analyse.
925 return false;
926 }
927
928 std::unique_ptr<const ComposeFst<Arc, CacheStore>> owned_fst_;
929 const ComposeFst<Arc, CacheStore> &fst_;
930 const Impl *impl_;
931 StateId s_;
932 MatchType match_type_;
933 std::unique_ptr<Matcher1> matcher1_;
934 std::unique_ptr<Matcher2> matcher2_;
935 bool current_loop_;
936 Arc loop_;
937 Arc arc_;
938 };
939
// Convenience alias for ComposeFst instantiated with StdArc.
using StdComposeFst = ComposeFst<StdArc>;
942
// Names the pre-defined composition filters that Compose() can be asked to
// use through ComposeOptions::filter_type.
enum ComposeFilter {
  AUTO_FILTER = 0,          // Leaves the filter choice to ComposeFst.
  NULL_FILTER = 1,          // Uses NullComposeFilter.
  TRIVIAL_FILTER = 2,       // Uses TrivialComposeFilter.
  SEQUENCE_FILTER = 3,      // Uses SequenceComposeFilter.
  ALT_SEQUENCE_FILTER = 4,  // Uses AltSequenceComposeFilter.
  MATCH_FILTER = 5,         // Uses MatchComposeFilter.
  NO_MATCH_FILTER = 6       // Uses NoMatchComposeFilter.
};
952
953 struct ComposeOptions {
954 bool connect; // Connect output?
955 ComposeFilter filter_type; // Pre-defined filter to use.
956
957 explicit ComposeOptions(bool connect = true,
958 ComposeFilter filter_type = AUTO_FILTER)
connectComposeOptions959 : connect(connect), filter_type(filter_type) {}
960 };
961
962 // Computes the composition of two transducers. This version writes
963 // the composed FST into a MutableFst. If FST1 transduces string x to
964 // y with weight a and FST2 transduces y to z with weight b, then
965 // their composition transduces string x to z with weight
966 // Times(a, b).
967 //
968 // The output labels of the first transducer or the input labels of
969 // the second transducer must be sorted. The weights need to form a
970 // commutative semiring (valid for TropicalWeight and LogWeight).
971 //
972 // Complexity:
973 //
974 // Assuming the first FST is unsorted and the second is sorted:
975 //
976 // Time: O(V1 V2 D1 (log D2 + M2)),
977 // Space: O(V1 V2 D1 M2)
978 //
979 // where Vi = # of states, Di = maximum out-degree, and Mi is the maximum
980 // multiplicity, for the ith FST.
981 //
982 // Caveats:
983 //
984 // - Compose trims its output.
985 // - The efficiency of composition can be strongly affected by several factors:
986 // - the choice of which transducer is sorted - prefer sorting the FST
987 // that has the greater average out-degree.
988 // - the amount of non-determinism
989 // - the presence and location of epsilon transitions - avoid epsilon
990 // transitions on the output side of the first transducer or
991 // the input side of the second transducer or prefer placing
992 // them later in a path since they delay matching and can
993 // introduce non-coaccessible states and transitions.
994 template <class Arc>
995 void Compose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
996 MutableFst<Arc> *ofst,
997 const ComposeOptions &opts = ComposeOptions()) {
998 using M = Matcher<Fst<Arc>>;
999 // In each case, we cache only the last state for fastest copy.
1000 switch (opts.filter_type) {
1001 case AUTO_FILTER: {
1002 CacheOptions nopts;
1003 nopts.gc_limit = 0;
1004 *ofst = ComposeFst<Arc>(ifst1, ifst2, nopts);
1005 break;
1006 }
1007 case NULL_FILTER: {
1008 ComposeFstOptions<Arc, M, NullComposeFilter<M>> copts;
1009 copts.gc_limit = 0;
1010 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1011 break;
1012 }
1013 case SEQUENCE_FILTER: {
1014 ComposeFstOptions<Arc, M, SequenceComposeFilter<M>> copts;
1015 copts.gc_limit = 0;
1016 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1017 break;
1018 }
1019 case ALT_SEQUENCE_FILTER: {
1020 ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M>> copts;
1021 copts.gc_limit = 0;
1022 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1023 break;
1024 }
1025 case MATCH_FILTER: {
1026 ComposeFstOptions<Arc, M, MatchComposeFilter<M>> copts;
1027 copts.gc_limit = 0;
1028 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1029 break;
1030 }
1031 case NO_MATCH_FILTER: {
1032 ComposeFstOptions<Arc, M, NoMatchComposeFilter<M>> copts;
1033 copts.gc_limit = 0;
1034 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1035 break;
1036 }
1037 case TRIVIAL_FILTER: {
1038 ComposeFstOptions<Arc, M, TrivialComposeFilter<M>> copts;
1039 copts.gc_limit = 0;
1040 *ofst = ComposeFst<Arc>(ifst1, ifst2, copts);
1041 break;
1042 }
1043 }
1044 if (opts.connect) Connect(ofst);
1045 }
1046
1047 } // namespace fst
1048
1049 #endif // FST_COMPOSE_H_
1050