1 // replace.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes for the recursive replacement of Fsts.
20 //
21 
22 #ifndef FST_LIB_REPLACE_H__
23 #define FST_LIB_REPLACE_H__
24 
25 #include <unordered_map>
26 using std::unordered_map;
27 using std::unordered_multimap;
28 #include <set>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34 
35 #include <fst/cache.h>
36 #include <fst/expanded-fst.h>
37 #include <fst/fst.h>
38 #include <fst/matcher.h>
39 #include <fst/replace-util.h>
40 #include <fst/state-table.h>
41 #include <fst/test-properties.h>
42 
43 namespace fst {
44 
45 //
46 // REPLACE STATE TUPLES AND TABLES
47 //
48 // The replace state table has the form
49 //
50 // template <class A, class P>
51 // class ReplaceStateTable {
52 //  public:
53 //   typedef A Arc;
54 //   typedef P PrefixId;
55 //   typedef typename A::StateId StateId;
56 //   typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
57 //   typedef typename A::Label Label;
58 //   typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
59 //
60 //   // Required constructor
61 //   ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
62 //                     Label root);
63 //
64 //   // Required copy constructor that does not copy state
65 //   ReplaceStateTable(const ReplaceStateTable<A,P> &table);
66 //
67 //   // Lookup state ID by tuple. If it doesn't exist, then add it.
68 //   StateId FindState(const StateTuple &tuple);
69 //
70 //   // Lookup state tuple by ID.
71 //   const StateTuple &Tuple(StateId id) const;
72 //
73 //   // Lookup prefix ID by stack prefix. If it doesn't exist, then add it.
74 //   PrefixId FindPrefixId(const StackPrefix &stack_prefix);
75 //
76 //  // Look up stack prefix by ID.
77 //  const StackPrefix &GetStackPrefix(PrefixId id) const;
78 // };
79 
80 //
81 // Replace State Tuples
82 //
83 
84 // \struct ReplaceStateTuple
85 // \brief Tuple of information that uniquely defines a state in replace
86 template <class S, class P>
87 struct ReplaceStateTuple {
88   typedef S StateId;
89   typedef P PrefixId;
90 
ReplaceStateTupleReplaceStateTuple91   ReplaceStateTuple()
92       : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
93 
ReplaceStateTupleReplaceStateTuple94   ReplaceStateTuple(PrefixId p, StateId f, StateId s)
95       : prefix_id(p), fst_id(f), fst_state(s) {}
96 
97   PrefixId prefix_id;  // index in prefix table
98   StateId fst_id;      // current fst being walked
99   StateId fst_state;   // current state in fst being walked, not to be
100                        // confused with the state_id of the combined fst
101 };
102 
103 // Equality of replace state tuples.
104 template <class S, class P>
105 inline bool operator==(const ReplaceStateTuple<S, P>& x,
106                        const ReplaceStateTuple<S, P>& y) {
107   return x.prefix_id == y.prefix_id &&
108       x.fst_id == y.fst_id &&
109       x.fst_state == y.fst_state;
110 }
111 
112 // \class ReplaceRootSelector
113 // Functor returning true for tuples corresponding to states in the root FST
114 template <class S, class P>
115 class ReplaceRootSelector {
116  public:
operator()117   bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
118     return tuple.prefix_id == 0;
119   }
120 };
121 
122 // \class ReplaceFingerprint
123 // Fingerprint for general replace state tuples.
124 template <class S, class P>
125 class ReplaceFingerprint {
126  public:
ReplaceFingerprint(const vector<uint64> * size_array)127   explicit ReplaceFingerprint(const vector<uint64> *size_array)
128       : cumulative_size_array_(size_array) {}
129 
operator()130   uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
131     return tuple.prefix_id * (cumulative_size_array_->back()) +
132         cumulative_size_array_->at(tuple.fst_id - 1) +
133         tuple.fst_state;
134   }
135 
136  private:
137   const vector<uint64> *cumulative_size_array_;
138 };
139 
140 // \class ReplaceFstStateFingerprint
141 // Useful when the fst_state uniquely define the tuple.
142 template <class S, class P>
143 class ReplaceFstStateFingerprint {
144  public:
operator()145   uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
146     return tuple.fst_state;
147   }
148 };
149 
150 // \class ReplaceHash
151 // A generic hash function for replace state tuples.
152 template <typename S, typename P>
153 class ReplaceHash {
154  public:
operator()155   size_t operator()(const ReplaceStateTuple<S, P>& t) const {
156     return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
157   }
158  private:
159   static const size_t kPrime0;
160   static const size_t kPrime1;
161 };
162 
163 template <typename S, typename P>
164 const size_t ReplaceHash<S, P>::kPrime0 = 7853;
165 
166 template <typename S, typename P>
167 const size_t ReplaceHash<S, P>::kPrime1 = 7867;
168 
169 
170 //
171 // Replace Stack Prefix
172 //
173 
174 // \class ReplaceStackPrefix
175 // \brief Container for stack prefix.
176 template <class L, class S>
177 class ReplaceStackPrefix {
178  public:
179   typedef L Label;
180   typedef S StateId;
181 
182   // \class PrefixTuple
183   // \brief Tuple of fst_id and destination state (entry in stack prefix)
184   struct PrefixTuple {
PrefixTuplePrefixTuple185     PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
PrefixTuplePrefixTuple186     PrefixTuple() : fst_id(kNoLabel), nextstate(kNoStateId) {}
187 
188     Label   fst_id;
189     StateId nextstate;
190   };
191 
ReplaceStackPrefix()192   ReplaceStackPrefix() {}
193 
194   // copy constructor
ReplaceStackPrefix(const ReplaceStackPrefix & x)195   ReplaceStackPrefix(const ReplaceStackPrefix& x) :
196       prefix_(x.prefix_) {
197   }
198 
Push(StateId fst_id,StateId nextstate)199   void Push(StateId fst_id, StateId nextstate) {
200     prefix_.push_back(PrefixTuple(fst_id, nextstate));
201   }
202 
Pop()203   void Pop() {
204     prefix_.pop_back();
205   }
206 
Top()207   const PrefixTuple& Top() const {
208     return prefix_[prefix_.size()-1];
209   }
210 
Depth()211   size_t Depth() const {
212     return prefix_.size();
213   }
214 
215  public:
216   vector<PrefixTuple> prefix_;
217 };
218 
219 // Equality stack prefix classes
220 template <class L, class S>
221 inline bool operator==(const ReplaceStackPrefix<L, S>& x,
222                 const ReplaceStackPrefix<L, S>& y) {
223   if (x.prefix_.size() != y.prefix_.size()) return false;
224   for (size_t i = 0; i < x.prefix_.size(); ++i) {
225     if (x.prefix_[i].fst_id    != y.prefix_[i].fst_id ||
226         x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
227   }
228   return true;
229 }
230 
231 //
232 // \class ReplaceStackPrefixHash
233 // \brief Hash function for stack prefix to prefix id
234 template <class L, class S>
235 class ReplaceStackPrefixHash {
236  public:
operator()237   size_t operator()(const ReplaceStackPrefix<L, S>& x) const {
238     size_t sum = 0;
239     for (size_t i = 0; i < x.prefix_.size(); ++i) {
240       sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
241     }
242     return sum;
243   }
244 
245  private:
246   static const size_t kPrime0;
247 };
248 
249 template <class L, class S>
250 const size_t ReplaceStackPrefixHash<L, S>::kPrime0 = 7853;
251 
252 //
253 // Replace State Tables
254 //
255 
256 // \class VectorHashReplaceStateTable
257 // A two-level state table for replace.
258 // Warning: calls CountStates to compute the number of states of each
259 // component Fst.
260 template <class A, class P = ssize_t>
261 class VectorHashReplaceStateTable {
262  public:
263   typedef A Arc;
264   typedef typename A::StateId StateId;
265   typedef typename A::Label Label;
266   typedef P PrefixId;
267   typedef ReplaceStateTuple<StateId, P> StateTuple;
268   typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
269                                ReplaceRootSelector<StateId, P>,
270                                ReplaceFstStateFingerprint<StateId, P>,
271                                ReplaceFingerprint<StateId, P> > StateTable;
272   typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
273   typedef CompactHashBiTable<
274     PrefixId, StackPrefix,
275     ReplaceStackPrefixHash<Label, StateId> > StackPrefixTable;
276 
VectorHashReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)277   VectorHashReplaceStateTable(
278       const vector<pair<Label, const Fst<A>*> > &fst_tuples,
279       Label root) : root_size_(0) {
280     cumulative_size_array_.push_back(0);
281     for (size_t i = 0; i < fst_tuples.size(); ++i) {
282       if (fst_tuples[i].first == root) {
283         root_size_ = CountStates(*(fst_tuples[i].second));
284         cumulative_size_array_.push_back(cumulative_size_array_.back());
285       } else {
286         cumulative_size_array_.push_back(cumulative_size_array_.back() +
287                                          CountStates(*(fst_tuples[i].second)));
288       }
289     }
290     state_table_ = new StateTable(
291         new ReplaceRootSelector<StateId, P>,
292         new ReplaceFstStateFingerprint<StateId, P>,
293         new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
294         root_size_,
295         root_size_ + cumulative_size_array_.back());
296   }
297 
VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A,P> & table)298   VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
299       : root_size_(table.root_size_),
300         cumulative_size_array_(table.cumulative_size_array_),
301         prefix_table_(table.prefix_table_) {
302     state_table_ = new StateTable(
303         new ReplaceRootSelector<StateId, P>,
304         new ReplaceFstStateFingerprint<StateId, P>,
305         new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
306         root_size_,
307         root_size_ + cumulative_size_array_.back());
308   }
309 
~VectorHashReplaceStateTable()310   ~VectorHashReplaceStateTable() {
311     delete state_table_;
312   }
313 
FindState(const StateTuple & tuple)314   StateId FindState(const StateTuple &tuple) {
315     return state_table_->FindState(tuple);
316   }
317 
Tuple(StateId id)318   const StateTuple &Tuple(StateId id) const {
319     return state_table_->Tuple(id);
320   }
321 
FindPrefixId(const StackPrefix & prefix)322   PrefixId FindPrefixId(const StackPrefix &prefix) {
323     return prefix_table_.FindId(prefix);
324   }
325 
GetStackPrefix(PrefixId id)326   const StackPrefix &GetStackPrefix(PrefixId id) const {
327     return prefix_table_.FindEntry(id);
328   }
329 
330  private:
331   StateId root_size_;
332   vector<uint64> cumulative_size_array_;
333   StateTable *state_table_;
334   StackPrefixTable prefix_table_;
335 };
336 
337 
338 // \class DefaultReplaceStateTable
339 // Default replace state table
340 template <class A, class P = ssize_t>
341 class DefaultReplaceStateTable : public CompactHashStateTable<
342   ReplaceStateTuple<typename A::StateId, P>,
343   ReplaceHash<typename A::StateId, P> > {
344  public:
345   typedef A Arc;
346   typedef typename A::StateId StateId;
347   typedef typename A::Label Label;
348   typedef P PrefixId;
349   typedef ReplaceStateTuple<StateId, P> StateTuple;
350   typedef CompactHashStateTable<StateTuple,
351                                 ReplaceHash<StateId, PrefixId> > StateTable;
352   typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
353   typedef CompactHashBiTable<
354     PrefixId, StackPrefix,
355     ReplaceStackPrefixHash<Label, StateId> > StackPrefixTable;
356 
357   using StateTable::FindState;
358   using StateTable::Tuple;
359 
DefaultReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)360   DefaultReplaceStateTable(
361       const vector<pair<Label, const Fst<A>*> > &fst_tuples,
362       Label root) {}
363 
DefaultReplaceStateTable(const DefaultReplaceStateTable<A,P> & table)364   DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
365       : StateTable(), prefix_table_(table.prefix_table_) {}
366 
FindPrefixId(const StackPrefix & prefix)367   PrefixId FindPrefixId(const StackPrefix &prefix) {
368     return prefix_table_.FindId(prefix);
369   }
370 
GetStackPrefix(PrefixId id)371   const StackPrefix &GetStackPrefix(PrefixId id) const {
372     return prefix_table_.FindEntry(id);
373   }
374 
375  private:
376   StackPrefixTable prefix_table_;
377 };
378 
379 //
380 // REPLACE FST CLASS
381 //
382 
383 // By default ReplaceFst will copy the input label of the 'replace arc'.
384 // The call_label_type and return_label_type options specify how to manage
385 // the labels of the call arc and the return arc of the replace FST
386 template <class A, class T = DefaultReplaceStateTable<A>,
387           class C = DefaultCacheStore<A> >
388 struct ReplaceFstOptions : CacheImplOptions<C> {
389   int64 root;    // root rule for expansion
390   ReplaceLabelType call_label_type;  // how to label call arc
391   ReplaceLabelType return_label_type;  // how to label return arc
392   int64 call_output_label;  // specifies output label to put on call arc
393                             // if kNoLabel, use existing label on call arc
394                             // if 0, epsilon
395                             // otherwise, use this field as the output label
396   int64 return_label;  // specifies label to put on return arc
397   bool  take_ownership;  // take ownership of input Fst(s)
398   T*    state_table;
399 
ReplaceFstOptionsReplaceFstOptions400   ReplaceFstOptions(const CacheImplOptions<C> &opts, int64 r)
401       : CacheImplOptions<C>(opts),
402         root(r),
403         call_label_type(REPLACE_LABEL_INPUT),
404         return_label_type(REPLACE_LABEL_NEITHER),
405         call_output_label(kNoLabel),
406         return_label(0),
407         take_ownership(false),
408         state_table(0) {}
409 
ReplaceFstOptionsReplaceFstOptions410   ReplaceFstOptions(const CacheOptions &opts, int64 r)
411       : CacheImplOptions<C>(opts),
412         root(r),
413         call_label_type(REPLACE_LABEL_INPUT),
414         return_label_type(REPLACE_LABEL_NEITHER),
415         call_output_label(kNoLabel),
416         return_label(0),
417         take_ownership(false),
418         state_table(0) {}
419 
ReplaceFstOptionsReplaceFstOptions420   explicit ReplaceFstOptions(const fst::ReplaceUtilOptions<A> &opts)
421       : root(opts.root),
422         call_label_type(opts.call_label_type),
423         return_label_type(opts.return_label_type),
424         call_output_label(kNoLabel),
425         return_label(opts.return_label),
426         take_ownership(false),
427         state_table(0) {}
428 
ReplaceFstOptionsReplaceFstOptions429   explicit ReplaceFstOptions(int64 r)
430       : root(r),
431         call_label_type(REPLACE_LABEL_INPUT),
432         return_label_type(REPLACE_LABEL_NEITHER),
433         call_output_label(kNoLabel),
434         return_label(0),
435         take_ownership(false),
436         state_table(0) {}
437 
ReplaceFstOptionsReplaceFstOptions438   ReplaceFstOptions(int64 r, ReplaceLabelType call_label_type,
439                     ReplaceLabelType return_label_type, int64 return_label)
440       : root(r),
441         call_label_type(call_label_type),
442         return_label_type(return_label_type),
443         call_output_label(kNoLabel),
444         return_label(return_label),
445         take_ownership(false),
446         state_table(0) {}
447 
ReplaceFstOptionsReplaceFstOptions448   ReplaceFstOptions(int64 r, ReplaceLabelType call_label_type,
449                     ReplaceLabelType return_label_type, int64 call_output_label,
450                     int64 return_label)
451       : root(r),
452         call_label_type(call_label_type),
453         return_label_type(return_label_type),
454         call_output_label(call_output_label),
455         return_label(return_label),
456         take_ownership(false),
457         state_table(0) {}
458 
ReplaceFstOptionsReplaceFstOptions459   ReplaceFstOptions(int64 r, bool epsilon_replace_arc)  // b/w compatibility
460       : root(r),
461         call_label_type((epsilon_replace_arc) ? REPLACE_LABEL_NEITHER :
462                         REPLACE_LABEL_INPUT),
463         return_label_type(REPLACE_LABEL_NEITHER),
464         call_output_label((epsilon_replace_arc) ? 0 : kNoLabel),
465         return_label(0),
466         take_ownership(false),
467         state_table(0) {}
468 
ReplaceFstOptionsReplaceFstOptions469   ReplaceFstOptions()
470       : root(kNoLabel),
471         call_label_type(REPLACE_LABEL_INPUT),
472         return_label_type(REPLACE_LABEL_NEITHER),
473         call_output_label(kNoLabel),
474         return_label(0),
475         take_ownership(false),
476         state_table(0) {}
477 };
478 
// Forward declaration of the matcher class template.
template <class A, class T, class C>
class ReplaceFstMatcher;
481 
482 // \class ReplaceFstImpl
483 // \brief Implementation class for replace class Fst
484 //
485 // The replace implementation class supports a dynamic
486 // expansion of a recursive transition network represented as Fst
487 // with dynamic replaceable arcs.
488 //
489 template <class A, class T, class C>
490 class ReplaceFstImpl : public CacheBaseImpl<typename C::State, C> {
491   friend class ReplaceFstMatcher<A, T, C>;
492 
493  public:
494   typedef A Arc;
495   typedef typename A::Label   Label;
496   typedef typename A::Weight  Weight;
497   typedef typename A::StateId StateId;
498   typedef typename C::State State;
499   typedef CacheBaseImpl<State, C> CImpl;
500   typedef unordered_map<Label, Label> NonTerminalHash;
501   typedef T StateTable;
502   typedef typename T::PrefixId PrefixId;
503   typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
504   typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
505 
506   using FstImpl<A>::SetType;
507   using FstImpl<A>::SetProperties;
508   using FstImpl<A>::WriteHeader;
509   using FstImpl<A>::SetInputSymbols;
510   using FstImpl<A>::SetOutputSymbols;
511   using FstImpl<A>::InputSymbols;
512   using FstImpl<A>::OutputSymbols;
513 
514   using CImpl::PushArc;
515   using CImpl::HasArcs;
516   using CImpl::HasFinal;
517   using CImpl::HasStart;
518   using CImpl::SetArcs;
519   using CImpl::SetFinal;
520   using CImpl::SetStart;
521 
522   // constructor for replace class implementation.
523   // \param fst_tuples array of label/fst tuples, one for each non-terminal
ReplaceFstImpl(const vector<pair<Label,const Fst<A> * >> & fst_tuples,const ReplaceFstOptions<A,T,C> & opts)524   ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
525                  const ReplaceFstOptions<A, T, C> &opts)
526       : CImpl(opts),
527         call_label_type_(opts.call_label_type),
528         return_label_type_(opts.return_label_type),
529         call_output_label_(opts.call_output_label),
530         return_label_(opts.return_label),
531         state_table_(opts.state_table ? opts.state_table :
532                      new StateTable(fst_tuples, opts.root)) {
533     SetType("replace");
534 
535     // if the label is epsilon, then all REPLACE_LABEL_* options equivalent.
536     // Set the label_type to NEITHER for ease of setting properties later
537     if (call_output_label_ == 0)
538       call_label_type_ = REPLACE_LABEL_NEITHER;
539     if (return_label_ == 0)
540       return_label_type_ = REPLACE_LABEL_NEITHER;
541 
542     if (fst_tuples.size() > 0) {
543       SetInputSymbols(fst_tuples[0].second->InputSymbols());
544       SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
545     }
546 
547     bool all_negative = true;  // all nonterminals are negative?
548     bool dense_range = true;   // all nonterminals are positive
549                                // and form a dense range containing 1?
550     for (size_t i = 0; i < fst_tuples.size(); ++i) {
551       Label nonterminal = fst_tuples[i].first;
552       if (nonterminal >= 0)
553         all_negative = false;
554       if (nonterminal > fst_tuples.size() || nonterminal <= 0)
555         dense_range = false;
556     }
557 
558     vector<uint64> inprops;
559     bool all_ilabel_sorted = true;
560     bool all_olabel_sorted = true;
561     bool all_non_empty = true;
562     fst_array_.push_back(0);
563     for (size_t i = 0; i < fst_tuples.size(); ++i) {
564       Label label = fst_tuples[i].first;
565       const Fst<A> *fst = fst_tuples[i].second;
566       nonterminal_hash_[label] = fst_array_.size();
567       nonterminal_set_.insert(label);
568       fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
569       if (fst->Start() == kNoStateId)
570         all_non_empty = false;
571       if (!fst->Properties(kILabelSorted, false))
572         all_ilabel_sorted = false;
573       if (!fst->Properties(kOLabelSorted, false))
574         all_olabel_sorted = false;
575       inprops.push_back(fst->Properties(kCopyProperties, false));
576       if (i) {
577         if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
578           FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
579                      << " does not match input symbols of base Fst (0'th fst)";
580           SetProperties(kError, kError);
581         }
582         if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
583           FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
584                      << " does not match output symbols of base Fst "
585                      << "(0'th fst)";
586           SetProperties(kError, kError);
587         }
588       }
589     }
590     Label nonterminal = nonterminal_hash_[opts.root];
591     if ((nonterminal == 0) && (fst_array_.size() > 1)) {
592       FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
593                  << opts.root << "' in the input tuple vector";
594       SetProperties(kError, kError);
595     }
596     root_ = (nonterminal > 0) ? nonterminal : 1;
597 
598     SetProperties(ReplaceProperties(inprops, root_ - 1,
599                                     EpsilonOnInput(call_label_type_),
600                                     EpsilonOnInput(return_label_type_),
601                                     ReplaceTransducer(), all_non_empty));
602     // We assume that all terminals are positive.  The resulting
603     // ReplaceFst is known to be kILabelSorted when: (1) all sub-FSTs are
604     // kILabelSorted, (2) the input label of the return arc is epsilon,
605     // and (3) one of the 3 following conditions is satisfied:
606     //  1. the input label of the call arc is not epsilon
607     //  2. all non-terminals are negative, or
608     //  3. all non-terninals are positive and form a dense range containing 1.
609     if (all_ilabel_sorted && EpsilonOnInput(return_label_type_) &&
610         (!EpsilonOnInput(call_label_type_) || all_negative || dense_range)) {
611       SetProperties(kILabelSorted, kILabelSorted);
612     }
613     // Similarly, the resulting ReplaceFst is known to be
614     // kOLabelSorted when: (1) all sub-FSTs are kOLabelSorted, (2) the output
615     // label of the return arc is epsilon, and (3) one of the 3 following
616     // conditions is satisfied:
617     //  1. the output label of the call arc is not epsilon
618     //  2. all non-terminals are negative, or
619     //  3. all non-terninals are positive and form a dense range containing 1.
620     if (all_olabel_sorted && EpsilonOnOutput(return_label_type_) &&
621         (!EpsilonOnOutput(call_label_type_) || all_negative || dense_range))
622       SetProperties(kOLabelSorted, kOLabelSorted);
623 
624     // Enable optional caching as long as sorted and all non empty.
625     if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
626       always_cache_ = false;
627     else
628       always_cache_ = true;
629     VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
630             << (always_cache_ ? "true" : "false");
631   }
632 
ReplaceFstImpl(const ReplaceFstImpl & impl)633   ReplaceFstImpl(const ReplaceFstImpl& impl)
634       : CImpl(impl),
635         call_label_type_(impl.call_label_type_),
636         return_label_type_(impl.return_label_type_),
637         call_output_label_(impl.call_output_label_),
638         return_label_(impl.return_label_),
639         always_cache_(impl.always_cache_),
640         state_table_(new StateTable(*(impl.state_table_))),
641         nonterminal_set_(impl.nonterminal_set_),
642         nonterminal_hash_(impl.nonterminal_hash_),
643         root_(impl.root_) {
644     SetType("replace");
645     SetProperties(impl.Properties(), kCopyProperties);
646     SetInputSymbols(impl.InputSymbols());
647     SetOutputSymbols(impl.OutputSymbols());
648     fst_array_.reserve(impl.fst_array_.size());
649     fst_array_.push_back(0);
650     for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
651       fst_array_.push_back(impl.fst_array_[i]->Copy(true));
652     }
653   }
654 
// Releases the state table and every component Fst held in fst_array_
// (slot 0 is skipped).
~ReplaceFstImpl() {
  delete state_table_;
  size_t slot = 1;
  while (slot < fst_array_.size()) {
    delete fst_array_[slot];
    ++slot;
  }
}
661 
662   // Computes the dependency graph of the replace class and returns
663   // true if the dependencies are cyclic. Cyclic dependencies will result
664   // in an un-expandable replace fst.
CyclicDependencies()665   bool CyclicDependencies() const {
666     ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_,
667                                 fst::ReplaceUtilOptions<A>(root_));
668     return replace_util.CyclicDependencies();
669   }
670 
671   // Returns or computes start state of replace fst.
Start()672   StateId Start() {
673     if (!HasStart()) {
674       if (fst_array_.size() == 1) {      // no fsts defined for replace
675         SetStart(kNoStateId);
676         return kNoStateId;
677       } else {
678         const Fst<A>* fst = fst_array_[root_];
679         StateId fst_start = fst->Start();
680         if (fst_start == kNoStateId)  // root Fst is empty
681           return kNoStateId;
682 
683         PrefixId prefix = GetPrefixId(StackPrefix());
684         StateId start = state_table_->FindState(
685             StateTuple(prefix, root_, fst_start));
686         SetStart(start);
687         return start;
688       }
689     } else {
690       return CImpl::Start();
691     }
692   }
693 
694   // Returns final weight of state (Weight::Zero() means state is not final).
Final(StateId s)695   Weight Final(StateId s) {
696     if (HasFinal(s)) {
697       return CImpl::Final(s);
698     } else {
699       const StateTuple& tuple  = state_table_->Tuple(s);
700       Weight final = Weight::Zero();
701 
702       if (tuple.prefix_id == 0) {
703         const Fst<A>* fst = fst_array_[tuple.fst_id];
704         StateId fst_state = tuple.fst_state;
705         final = fst->Final(fst_state);
706       }
707 
708       if (always_cache_ || HasArcs(s))
709         SetFinal(s, final);
710       return final;
711     }
712   }
713 
NumArcs(StateId s)714   size_t NumArcs(StateId s) {
715     if (HasArcs(s)) {  // If state cached, use the cached value.
716       return CImpl::NumArcs(s);
717     } else if (always_cache_) {  // If always caching, expand and cache state.
718       Expand(s);
719       return CImpl::NumArcs(s);
720     } else {  // Otherwise compute the number of arcs without expanding.
721       StateTuple tuple  = state_table_->Tuple(s);
722       if (tuple.fst_state == kNoStateId)
723         return 0;
724 
725       const Fst<A>* fst = fst_array_[tuple.fst_id];
726       size_t num_arcs = fst->NumArcs(tuple.fst_state);
727       if (ComputeFinalArc(tuple, 0))
728         num_arcs++;
729 
730       return num_arcs;
731     }
732   }
733 
734   // Returns whether a given label is a non terminal
IsNonTerminal(Label l)735   bool IsNonTerminal(Label l) const {
736     if (l < *nonterminal_set_.begin() || l > *nonterminal_set_.rbegin())
737       return false;
738     // TODO(allauzen): be smarter and take advantage of
739     // all_dense or all_negative.
740     // Use also in ComputeArc, this would require changes to replace
741     // so that recursing into an empty fst lead to a non co-accessible
742     // state instead of deleting the arc as done currently.
743     // Current use correct, since i/olabel sorted iff all_non_empty.
744     typename NonTerminalHash::const_iterator it =
745         nonterminal_hash_.find(l);
746     return it != nonterminal_hash_.end();
747   }
748 
NumInputEpsilons(StateId s)749   size_t NumInputEpsilons(StateId s) {
750     if (HasArcs(s)) {
751       // If state cached, use the cached value.
752       return CImpl::NumInputEpsilons(s);
753     } else if (always_cache_ || !Properties(kILabelSorted)) {
754       // If always caching or if the number of input epsilons is too expensive
755       // to compute without caching (i.e. not ilabel sorted),
756       // then expand and cache state.
757       Expand(s);
758       return CImpl::NumInputEpsilons(s);
759     } else {
760       // Otherwise, compute the number of input epsilons without caching.
761       StateTuple tuple  = state_table_->Tuple(s);
762       if (tuple.fst_state == kNoStateId)
763         return 0;
764       const Fst<A>* fst = fst_array_[tuple.fst_id];
765       size_t num  = 0;
766       if (!EpsilonOnInput(call_label_type_)) {
767         // If EpsilonOnInput(c) is false, all input epsilon arcs
768         // are also input epsilons arcs in the underlying machine.
769         num = fst->NumInputEpsilons(tuple.fst_state);
770       } else {
771         // Otherwise, one need to consider that all non-terminal arcs
772         // in the underlying machine also become input epsilon arc.
773         ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
774         for (; !aiter.Done() &&
775                  ((aiter.Value().ilabel == 0) ||
776                   IsNonTerminal(aiter.Value().olabel));
777              aiter.Next())
778           ++num;
779       }
780       if (EpsilonOnInput(return_label_type_) && ComputeFinalArc(tuple, 0))
781         num++;
782       return num;
783     }
784   }
785 
NumOutputEpsilons(StateId s)786   size_t NumOutputEpsilons(StateId s) {
787     if (HasArcs(s)) {
788       // If state cached, use the cached value.
789       return CImpl::NumOutputEpsilons(s);
790     } else if (always_cache_ || !Properties(kOLabelSorted)) {
791       // If always caching or if the number of output epsilons is too expensive
792       // to compute without caching (i.e. not olabel sorted),
793       // then expand and cache state.
794       Expand(s);
795       return CImpl::NumOutputEpsilons(s);
796     } else {
797       // Otherwise, compute the number of output epsilons without caching.
798       StateTuple tuple  = state_table_->Tuple(s);
799       if (tuple.fst_state == kNoStateId)
800         return 0;
801       const Fst<A>* fst = fst_array_[tuple.fst_id];
802       size_t num  = 0;
803       if (!EpsilonOnOutput(call_label_type_)) {
804         // If EpsilonOnOutput(c) is false, all output epsilon arcs
805         // are also output epsilons arcs in the underlying machine.
806         num = fst->NumOutputEpsilons(tuple.fst_state);
807       } else {
808         // Otherwise, one need to consider that all non-terminal arcs
809         // in the underlying machine also become output epsilon arc.
810         ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
811         for (; !aiter.Done() &&
812                  ((aiter.Value().olabel == 0) ||
813                   IsNonTerminal(aiter.Value().olabel));
814              aiter.Next())
815           ++num;
816       }
817       if (EpsilonOnOutput(return_label_type_) && ComputeFinalArc(tuple, 0))
818         num++;
819       return num;
820     }
821   }
822 
  // Returns the properties selected by the kFstProperties mask, by
  // delegating to Properties(uint64 mask) below.
  uint64 Properties() const { return Properties(kFstProperties); }
824 
825   // Set error if found; return FST impl properties.
Properties(uint64 mask)826   uint64 Properties(uint64 mask) const {
827     if (mask & kError) {
828       for (size_t i = 1; i < fst_array_.size(); ++i) {
829         if (fst_array_[i]->Properties(kError, false))
830           SetProperties(kError, kError);
831       }
832     }
833     return FstImpl<Arc>::Properties(mask);
834   }
835 
836   // return the base arc iterator, if arcs have not been computed yet,
837   // extend/recurse for new arcs.
InitArcIterator(StateId s,ArcIteratorData<A> * data)838   void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
839     if (!HasArcs(s))
840       Expand(s);
841     CImpl::InitArcIterator(s, data);
842     // TODO(allauzen): Set behaviour of generic iterator
843     // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
844     // relies on current behaviour.
845   }
846 
847 
848   // Extend current state (walk arcs one level deep)
Expand(StateId s)849   void Expand(StateId s) {
850     StateTuple tuple = state_table_->Tuple(s);
851 
852     // If local fst is empty
853     if (tuple.fst_state == kNoStateId) {
854       SetArcs(s);
855       return;
856     }
857 
858     ArcIterator< Fst<A> > aiter(
859         *(fst_array_[tuple.fst_id]), tuple.fst_state);
860     Arc arc;
861 
862     // Create a final arc when needed
863     if (ComputeFinalArc(tuple, &arc))
864       PushArc(s, arc);
865 
866     // Expand all arcs leaving the state
867     for (; !aiter.Done(); aiter.Next()) {
868       if (ComputeArc(tuple, aiter.Value(), &arc))
869         PushArc(s, arc);
870     }
871 
872     SetArcs(s);
873   }
874 
Expand(StateId s,const StateTuple & tuple,const ArcIteratorData<A> & data)875   void Expand(StateId s, const StateTuple &tuple,
876               const ArcIteratorData<A> &data) {
877      // If local fst is empty
878     if (tuple.fst_state == kNoStateId) {
879       SetArcs(s);
880       return;
881     }
882 
883     ArcIterator< Fst<A> > aiter(data);
884     Arc arc;
885 
886     // Create a final arc when needed
887     if (ComputeFinalArc(tuple, &arc))
888       AddArc(s, arc);
889 
890     // Expand all arcs leaving the state
891     for (; !aiter.Done(); aiter.Next()) {
892       if (ComputeArc(tuple, aiter.Value(), &arc))
893         AddArc(s, arc);
894     }
895 
896     SetArcs(s);
897   }
898 
899   // If arcp == 0, only returns if a final arc is required, does not
900   // actually compute it.
901   bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
902                        uint32 flags = kArcValueFlags) {
903     const Fst<A>* fst = fst_array_[tuple.fst_id];
904     StateId fst_state = tuple.fst_state;
905     if (fst_state == kNoStateId)
906       return false;
907 
908     // if state is final, pop up stack
909     if (fst->Final(fst_state) != Weight::Zero() && tuple.prefix_id) {
910       if (arcp) {
911         arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
912         arcp->olabel =
913             (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
914         if (flags & kArcNextStateValue) {
915           const StackPrefix& stack = state_table_->GetStackPrefix(
916               tuple.prefix_id);
917           PrefixId prefix_id = PopPrefix(stack);
918           const typename StackPrefix::PrefixTuple& top = stack.Top();
919           arcp->nextstate = state_table_->FindState(
920               StateTuple(prefix_id, top.fst_id, top.nextstate));
921         }
922         if (flags & kArcWeightValue)
923           arcp->weight = fst->Final(fst_state);
924       }
925       return true;
926     } else {
927       return false;
928     }
929   }
930 
931   // Compute the arc in the replace fst corresponding to a given
932   // in the underlying machine. Returns false if the underlying arc
933   // corresponds to no arc in the replace.
934   bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
935                   uint32 flags = kArcValueFlags) {
936     if (!EpsilonOnInput(call_label_type_) &&
937         (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
938       *arcp = arc;
939       return true;
940     }
941 
942     if (arc.olabel == 0 ||
943         arc.olabel < *nonterminal_set_.begin() ||
944         arc.olabel > *nonterminal_set_.rbegin()) {  // expand local fst
945       StateId nextstate = flags & kArcNextStateValue
946           ? state_table_->FindState(
947               StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
948           : kNoStateId;
949       *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
950     } else {
951       // check for non terminal
952       typename NonTerminalHash::const_iterator it =
953           nonterminal_hash_.find(arc.olabel);
954       if (it != nonterminal_hash_.end()) {  // recurse into non terminal
955         Label nonterminal = it->second;
956         const Fst<A>* nt_fst = fst_array_[nonterminal];
957         PrefixId nt_prefix = PushPrefix(
958             state_table_->GetStackPrefix(tuple.prefix_id),
959             tuple.fst_id, arc.nextstate);
960 
961         // if start state is valid replace, else arc is implicitly
962         // deleted
963         StateId nt_start = nt_fst->Start();
964         if (nt_start != kNoStateId) {
965           StateId nt_nextstate =  flags & kArcNextStateValue
966               ? state_table_->FindState(
967                   StateTuple(nt_prefix, nonterminal, nt_start))
968               : kNoStateId;
969           Label ilabel = (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
970           Label olabel = (EpsilonOnOutput(call_label_type_)) ?
971               0 : ((call_output_label_ == kNoLabel) ?
972                    arc.olabel : call_output_label_);
973           *arcp = A(ilabel, olabel, arc.weight, nt_nextstate);
974         } else {
975           return false;
976         }
977       } else {
978         StateId nextstate = flags & kArcNextStateValue
979             ? state_table_->FindState(
980                 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
981             : kNoStateId;
982         *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
983       }
984     }
985     return true;
986   }
987 
988   // Returns the arc iterator flags supported by this Fst.
ArcIteratorFlags()989   uint32 ArcIteratorFlags() const {
990     uint32 flags = kArcValueFlags;
991     if (!always_cache_)
992       flags |= kArcNoCache;
993     return flags;
994   }
995 
  // Returns the state table pointer held by this implementation.
  T* GetStateTable() const {
    return state_table_;
  }
999 
  // Returns the component FST stored at index 'fst_id' of fst_array_.
  // The index is used as-is; no bounds check is performed here.
  const Fst<A>* GetFst(Label fst_id) const {
    return fst_array_[fst_id];
  }
1003 
GetFstId(Label nonterminal)1004   Label GetFstId(Label nonterminal) const {
1005     typename NonTerminalHash::const_iterator it =
1006         nonterminal_hash_.find(nonterminal);
1007     if (it == nonterminal_hash_.end()) {
1008       FSTERROR() << "ReplaceFstImpl::GetFstId: nonterminal not found: "
1009                  << nonterminal;
1010     }
1011     return it->second;
1012   }
1013 
  // Returns true if the configured call label type (call_label_type_)
  // results in an epsilon input label on call arcs. Public wrapper around
  // the private EpsilonOnInput() helper.
  bool EpsilonOnCallInput() {
    return EpsilonOnInput(call_label_type_);
  }
1018 
1019   // private methods
1020  private:
  // Maps a stack prefix to its id by delegating to the state table's
  // FindPrefixId() (a unique index into the stack prefix table).
  PrefixId GetPrefixId(const StackPrefix& prefix) {
    return state_table_->FindPrefixId(prefix);
  }
1025 
  // Returns the id of the stack prefix obtained by popping the top entry
  // of 'prefix'. The prefix is taken by value, so the pop is applied to a
  // local copy and the caller's object is left unchanged.
  PrefixId PopPrefix(StackPrefix prefix) {
    prefix.Pop();
    return GetPrefixId(prefix);
  }
1031 
  // Returns the id of the stack prefix obtained by pushing the pair
  // (fst_id, nextstate) onto 'prefix'. The prefix is taken by value, so
  // the push is applied to a local copy and the caller's object is left
  // unchanged.
  PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
    prefix.Push(fst_id, nextstate);
    return GetPrefixId(prefix);
  }
1037 
1038   // returns true if label type on arc results in epsilon input label
EpsilonOnInput(ReplaceLabelType label_type)1039   bool EpsilonOnInput(ReplaceLabelType label_type) {
1040     if (label_type == REPLACE_LABEL_NEITHER ||
1041         label_type == REPLACE_LABEL_OUTPUT) return true;
1042     return false;
1043   }
1044 
1045   // returns true if label type on arc results in epsilon input label
EpsilonOnOutput(ReplaceLabelType label_type)1046   bool EpsilonOnOutput(ReplaceLabelType label_type) {
1047     if (label_type == REPLACE_LABEL_NEITHER ||
1048         label_type == REPLACE_LABEL_INPUT) return true;
1049     return false;
1050   }
1051 
1052   // returns true if for either the call or return arc ilabel != olabel
ReplaceTransducer()1053   bool ReplaceTransducer() {
1054     if (call_label_type_ == REPLACE_LABEL_INPUT ||
1055         call_label_type_ == REPLACE_LABEL_OUTPUT ||
1056         (call_label_type_ == REPLACE_LABEL_BOTH &&
1057             call_output_label_ != kNoLabel) ||
1058         return_label_type_ == REPLACE_LABEL_INPUT ||
1059         return_label_type_ == REPLACE_LABEL_OUTPUT) return true;
1060     return false;
1061   }
1062 
1063   // private data
1064  private:
1065   // runtime options
1066   ReplaceLabelType call_label_type_;  // how to label call arc
1067   ReplaceLabelType return_label_type_;  // how to label return arc
1068   int64 call_output_label_;  // specifies output label to put on call arc
1069   int64 return_label_;  // specifies label to put on return arc
1070   bool always_cache_;  // Optionally caching arc iterator disabled when true
1071 
1072   // state table
1073   StateTable *state_table_;
1074 
1075   // replace components
1076   set<Label> nonterminal_set_;
1077   NonTerminalHash nonterminal_hash_;
1078   vector<const Fst<A>*> fst_array_;
1079   Label root_;
1080 
1081   void operator=(const ReplaceFstImpl<A, T, C> &);  // disallow
1082 };
1083 
1084 
1085 //
1086 // \class ReplaceFst
// \brief Recursively replaces arcs in the root Fst with other Fsts.
1088 // This version is a delayed Fst.
1089 //
1090 // ReplaceFst supports dynamic replacement of arcs in one Fst with
1091 // another Fst. This replacement is recursive.  ReplaceFst can be used
1092 // to support a variety of delayed constructions such as recursive
1093 // transition networks, union, or closure.  It is constructed with an
1094 // array of Fst(s). One Fst represents the root (or topology)
1095 // machine. The root Fst refers to other Fsts by recursively replacing
1096 // arcs labeled as non-terminals with the matching non-terminal
1097 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
1098 // to determine whether the arc is a non-terminal arc or not. A
1099 // non-terminal can be any label that is not a non-zero terminal label
1100 // in the output alphabet.
1101 //
1102 // Note that the constructor uses a vector of pair<>. These correspond
1103 // to the tuple of non-terminal Label and corresponding Fst. For example
1104 // to implement the closure operation we need 2 Fsts. The first root
1105 // Fst is a single Arc on the start State that self loops, it references
1106 // the particular machine for which we are performing the closure operation.
1107 //
1108 // The ReplaceFst class supports an optionally caching arc iterator:
1109 //    ArcIterator< ReplaceFst<A> >
// The ReplaceFst needs to be built such that it is known to be ilabel
1111 // or olabel sorted (see usage below).
1112 //
1113 // Observe that Matcher<Fst<A> > will use the optionally caching arc
1114 // iterator when available (Fst is ilabel sorted and matching on the
1115 // input, or Fst is olabel sorted and matching on the output).
1116 // In order to obtain the most efficient behaviour, it is recommended
1117 // to set call_label_type to REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH
1118 // and return_label_type to REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER
1119 // (this means that the call arc does not have epsilon on the input side
1120 // and the return arc has epsilon on the input side) and matching on the
1121 // input side.
1122 //
1123 // This class attaches interface to implementation and handles
1124 // reference counting, delegating most methods to ImplToFst.
1125 template <class A, class T = DefaultReplaceStateTable<A>,
1126           class C /* = DefaultCacheStore<A> */ >
1127 class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T, C> > {
1128  public:
1129   friend class ArcIterator< ReplaceFst<A, T, C> >;
1130   friend class StateIterator< ReplaceFst<A, T, C> >;
1131   friend class ReplaceFstMatcher<A, T, C>;
1132 
1133   typedef A Arc;
1134   typedef typename A::Label   Label;
1135   typedef typename A::Weight  Weight;
1136   typedef typename A::StateId StateId;
1137   typedef T StateTable;
1138   typedef C Store;
1139   typedef typename C::State State;
1140   typedef CacheBaseImpl<State, C> CImpl;
1141   typedef ReplaceFstImpl<A, T, C> Impl;
1142 
1143   using ImplToFst<Impl>::Properties;
1144 
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,Label root)1145   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
1146              Label root)
1147       : ImplToFst<Impl>(
1148             new Impl(fst_array, ReplaceFstOptions<A, T, C>(root))) {}
1149 
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,const ReplaceFstOptions<A,T,C> & opts)1150   ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
1151              const ReplaceFstOptions<A, T, C> &opts)
1152       : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
1153 
1154   // See Fst<>::Copy() for doc.
1155   ReplaceFst(const ReplaceFst<A, T, C>& fst, bool safe = false)
1156       : ImplToFst<Impl>(fst, safe) {}
1157 
1158   // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
1159   virtual ReplaceFst<A, T, C> *Copy(bool safe = false) const {
1160     return new ReplaceFst<A, T, C>(*this, safe);
1161   }
1162 
1163   virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
1164 
InitArcIterator(StateId s,ArcIteratorData<A> * data)1165   virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
1166     GetImpl()->InitArcIterator(s, data);
1167   }
1168 
InitMatcher(MatchType match_type)1169   virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
1170     if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1171         ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
1172          (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
1173       return new ReplaceFstMatcher<A, T, C>(*this, match_type);
1174     } else {
1175       VLOG(2) << "Not using replace matcher";
1176       return 0;
1177     }
1178   }
1179 
CyclicDependencies()1180   bool CyclicDependencies() const {
1181     return GetImpl()->CyclicDependencies();
1182   }
1183 
GetStateTable()1184   const StateTable& GetStateTable() const {
1185     return *GetImpl()->GetStateTable();
1186   }
1187 
GetFst(Label nonterminal)1188   const Fst<A> &GetFst(Label nonterminal) const {
1189     return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1190   }
1191 
1192  private:
1193   // Makes visible to friends.
GetImpl()1194   Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
1195 
1196   void operator=(const ReplaceFst<A, T, C> &fst);  // disallow
1197 };
1198 
1199 
// Specialization for ReplaceFst.
// All behaviour comes from CacheStateIterator, which is constructed from
// the fst and its implementation.
template<class A, class T, class C>
class StateIterator< ReplaceFst<A, T, C> >
    : public CacheStateIterator< ReplaceFst<A, T, C> > {
 public:
  // Builds a state iterator over 'fst'.
  explicit StateIterator(const ReplaceFst<A, T, C> &fst)
      : CacheStateIterator< ReplaceFst<A, T, C> >(fst, fst.GetImpl()) {}

 private:
  DISALLOW_COPY_AND_ASSIGN(StateIterator);
};
1211 
1212 
1213 // Specialization for ReplaceFst.
1214 // Implements optional caching. It can be used as follows:
1215 //
1216 //   ReplaceFst<A> replace;
1217 //   ArcIterator< ReplaceFst<A> > aiter(replace, s);
1218 //   // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
1219 //   aiter.SetFlags(kArcNoCache, kArcNoCache);
1220 //   // Use the arc iterator, no arc will be cached, no state will be expanded.
1221 //   // The varied 'kArcValueFlags' can be used to decide which part
1222 //   // of arc values needs to be computed.
1223 //   aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1224 //   // Only want the ilabel for this arc
1225 //   aiter.Value();  // Does not compute the destination state.
1226 //   aiter.Next();
1227 //   aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1228 //   // Want both ilabel and nextstate for that arc
1229 //   aiter.Value();  // Does compute the destination state and inserts it
1230 //                   // in the replace state table.
1231 //   // No Arc has been cached at that point.
1232 //
template <class A, class T, class C>
class ArcIterator< ReplaceFst<A, T, C> > {
 public:
  typedef A Arc;
  typedef typename A::StateId StateId;

  // Builds an arc iterator over the arcs leaving state 's' of 'fst'.
  // If the state is already cached (or the FST does not support
  // non-caching iteration), the cached arcs are used. Otherwise the
  // iterator is prepared for non-caching use and the decision to cache is
  // deferred until Value() or SetFlags() is called.
  ArcIterator(const ReplaceFst<A, T, C> &fst, StateId s)
      : fst_(fst), state_(s), pos_(0), offset_(0), flags_(kArcValueFlags),
        arcs_(0), data_flags_(0), final_flags_(0) {
    // No reference is held on either arc data until InitArcIterator
    // fills it in; the destructor relies on these being 0 in that case.
    cache_data_.ref_count = 0;
    local_data_.ref_count = 0;

    // If FST does not support optional caching, force caching.
    if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
        !(fst_.GetImpl()->HasArcs(state_)))
      fst_.GetImpl()->Expand(state_);

    // If state is already cached, use cached arcs array.
    if (fst_.GetImpl()->HasArcs(state_)) {
      (fst_.GetImpl())
          ->CacheBaseImpl<typename C::State, C>::InitArcIterator(
              state_, &cache_data_);
      num_arcs_ = cache_data_.narcs;
      arcs_ = cache_data_.arcs;      // 'arcs_' is a ptr to the cached arcs.
      data_flags_ = kArcValueFlags;  // All the arc member values are valid.
    } else {  // Otherwise delay decision until Value() is called.
      tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
      if (tuple_.fst_state == kNoStateId) {
        // State of an empty local fst: there is nothing to iterate over.
        num_arcs_ = 0;
      } else {
        // The decision to cache or not to cache has been deferred
        // until Value() or SetFlags() is called. However, the arc
        // iterator is set up now to be ready for non-caching in order
        // to keep the Value() method simple and efficient.
        const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
        fst->InitArcIterator(tuple_.fst_state, &local_data_);
        // 'arcs_' is a pointer to the arcs in the underlying machine.
        arcs_ = local_data_.arcs;
        // Compute the final arc (but not its destination state)
        // if a final arc is required.
        bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
            tuple_,
            &final_arc_,
            kArcValueFlags & ~kArcNextStateValue);
        // Set the arc value flags that hold for 'final_arc_'.
        final_flags_ = kArcValueFlags & ~kArcNextStateValue;
        // Compute the number of arcs.
        num_arcs_ = local_data_.narcs;
        if (has_final_arc)
          ++num_arcs_;
        // Set the offset between the underlying arc positions and
        // the positions in the arc iterator.
        offset_ = num_arcs_ - local_data_.narcs;
        // Defers the decision to cache or not until Value() or
        // SetFlags() is called.
        data_flags_ = 0;
      }
    }
  }

  // Drops the references held on the cached and local arc data, if any,
  // by decrementing their reference counts.
  ~ArcIterator() {
    if (cache_data_.ref_count)
      --(*cache_data_.ref_count);
    if (local_data_.ref_count)
      --(*local_data_.ref_count);
  }

  // Expands and caches the state, then switches the iterator over to the
  // cached arcs array (all arc values valid, no position offset).
  void ExpandAndCache() const   {
    // TODO(allauzen): revisit this
    // fst_.GetImpl()->Expand(state_, tuple_, local_data_);
    // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
    //                                               &cache_data_);
    //
    fst_.InitArcIterator(state_, &cache_data_);  // Expand and cache state.
    arcs_ = cache_data_.arcs;  // 'arcs_' is a pointer to the cached arcs.
    data_flags_ = kArcValueFlags;  // All the arc member values are valid.
    offset_ = 0;  // No offset
  }

  // Sets the iterator up according to the current flags: for non-caching
  // iteration over the underlying machine's arcs when kArcNoCache is set,
  // otherwise by expanding and caching the state.
  void Init() {
    if (flags_ & kArcNoCache) {  // If caching is disabled
      // 'arcs_' is a pointer to the arcs in the underlying machine.
      arcs_ = local_data_.arcs;
      // Set the arcs value flags that hold for 'arcs_'. The weight of an
      // underlying arc is always valid; its ilabel is valid only when
      // call arcs do not get an epsilon input label.
      data_flags_ = kArcWeightValue;
      if (!fst_.GetImpl()->EpsilonOnCallInput())
          data_flags_ |= kArcILabelValue;
      // Set the offset between the underlying arc positions and
      // the positions in the arc iterator.
      offset_ = num_arcs_ - local_data_.narcs;
    } else {  // Otherwise, expand and cache
      ExpandAndCache();
    }
  }

  // Returns true once the position is past the last arc.
  bool Done() const { return pos_ >= num_arcs_; }

  // Returns the arc at the current position. Depending on the flags, the
  // arc is read from the cached/underlying arcs array or computed on the
  // fly into 'arc_' (or 'final_arc_' for the final arc).
  const A& Value() const {
    // If 'data_flags_' was set to 0, non-caching was not requested
    if (!data_flags_) {
      // TODO(allauzen): revisit this.
      if (flags_ & kArcNoCache) {
        // Should never happen.
        FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
      }
      ExpandAndCache();  // Expand and cache.
    }

    if (pos_ - offset_ >= 0) {  // The requested arc is not the 'final' arc.
      const A& arc = arcs_[pos_ - offset_];
      if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
        // If the value flags for 'arc' match the required value flags
        // then return 'arc'.
        return arc;
      } else {
        // Otherwise, compute the corresponding arc on-the-fly.
        fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
        return arc_;
      }
    } else {  // The requested arc is the 'final' arc.
      if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
        // If the arc value flags that hold for the final arc
        // do not match the requested value flags, then
        // 'final_arc_' needs to be updated.
        fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
                                    flags_ & kArcValueFlags);
        final_flags_ = flags_ & kArcValueFlags;
      }
      return final_arc_;
    }
  }

  // Advances to the next arc.
  void Next() { ++pos_; }

  // Returns the current position.
  size_t Position() const { return pos_; }

  // Moves back to the first arc.
  void Reset() { pos_ = 0;  }

  // Moves to position 'pos'.
  void Seek(size_t pos) { pos_ = pos; }

  // Returns the current behavioral flags.
  uint32 Flags() const { return flags_; }

  // Updates the flags: bits in 'mask' are cleared, then the bits of 'f'
  // that the Fst supports are set. May switch the iterator between its
  // caching and non-caching set-ups (see the comments below).
  void SetFlags(uint32 f, uint32 mask) {
    // Update the flags taking into account what flags are supported
    // by the Fst.
    flags_ &= ~mask;
    flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
    // If non-caching is not requested (and caching has not already
    // been performed), then flush 'data_flags_' to request caching
    // during the next call to Value().
    if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
      if (!fst_.GetImpl()->HasArcs(state_))
         data_flags_ = 0;
    }
    // If 'data_flags_' has been flushed but non-caching is requested
    // before calling Value(), then set up the iterator for non-caching.
    if ((f & kArcNoCache) && (!data_flags_))
      Init();
  }

 private:
  const ReplaceFst<A, T, C> &fst_;           // Reference to the FST
  StateId state_;                         // State in the FST
  mutable typename T::StateTuple tuple_;  // Tuple corresponding to state_

  ssize_t pos_;             // Current position
  mutable ssize_t offset_;  // Offset between position in iterator and in arcs_
  ssize_t num_arcs_;        // Number of arcs at state_
  uint32 flags_;            // Behavioral flags for the arc iterator
  mutable Arc arc_;         // Memory to temporarily store computed arcs

  mutable ArcIteratorData<Arc> cache_data_;  // Arc iterator data in cache
  mutable ArcIteratorData<Arc> local_data_;  // Arc iterator data in local fst

  mutable const A* arcs_;       // Array of arcs
  mutable uint32 data_flags_;   // Arc value flags valid for data in arcs_
  mutable Arc final_arc_;       // Final arc (when required)
  mutable uint32 final_flags_;  // Arc value flags valid for final_arc_

  DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};
1414 
1415 
1416 template <class A, class T, class C>
1417 class ReplaceFstMatcher : public MatcherBase<A> {
1418  public:
1419   typedef ReplaceFst<A, T, C> FST;
1420   typedef A Arc;
1421   typedef typename A::StateId StateId;
1422   typedef typename A::Label Label;
1423   typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;
1424 
ReplaceFstMatcher(const ReplaceFst<A,T,C> & fst,fst::MatchType match_type)1425   ReplaceFstMatcher(const ReplaceFst<A, T, C> &fst,
1426                     fst::MatchType match_type)
1427       : fst_(fst),
1428         impl_(fst_.GetImpl()),
1429         s_(fst::kNoStateId),
1430         match_type_(match_type),
1431         current_loop_(false),
1432         final_arc_(false),
1433         loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1434     if (match_type_ == fst::MATCH_OUTPUT)
1435       swap(loop_.ilabel, loop_.olabel);
1436     InitMatchers();
1437   }
1438 
1439   ReplaceFstMatcher(const ReplaceFstMatcher<A, T, C> &matcher,
1440                     bool safe = false)
1441       : fst_(matcher.fst_),
1442         impl_(fst_.GetImpl()),
1443         s_(fst::kNoStateId),
1444         match_type_(matcher.match_type_),
1445         current_loop_(false),
1446         final_arc_(false),
1447         loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1448     if (match_type_ == fst::MATCH_OUTPUT)
1449       swap(loop_.ilabel, loop_.olabel);
1450     InitMatchers();
1451   }
1452 
1453   // Create a local matcher for each component Fst of replace.
1454   // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher
1455   // is used to match each non-terminal arc, since these non-terminal
1456   // turn into epsilons on recursion.
InitMatchers()1457   void InitMatchers() {
1458     const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
1459     matcher_.resize(fst_array.size(), 0);
1460     for (size_t i = 0; i < fst_array.size(); ++i) {
1461       if (fst_array[i]) {
1462         matcher_[i] =
1463             new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
1464 
1465         typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
1466         for (; it != impl_->nonterminal_set_.end(); ++it) {
1467           matcher_[i]->AddMultiEpsLabel(*it);
1468         }
1469       }
1470     }
1471   }
1472 
1473   virtual ReplaceFstMatcher<A, T, C> *Copy(bool safe = false) const {
1474     return new ReplaceFstMatcher<A, T, C>(*this, safe);
1475   }
1476 
~ReplaceFstMatcher()1477   virtual ~ReplaceFstMatcher() {
1478     for (size_t i = 0; i < matcher_.size(); ++i)
1479       delete matcher_[i];
1480   }
1481 
Type(bool test)1482   virtual MatchType Type(bool test) const {
1483     if (match_type_ == MATCH_NONE)
1484       return match_type_;
1485 
1486     uint64 true_prop =  match_type_ == MATCH_INPUT ?
1487         kILabelSorted : kOLabelSorted;
1488     uint64 false_prop = match_type_ == MATCH_INPUT ?
1489         kNotILabelSorted : kNotOLabelSorted;
1490     uint64 props = fst_.Properties(true_prop | false_prop, test);
1491 
1492     if (props & true_prop)
1493       return match_type_;
1494     else if (props & false_prop)
1495       return MATCH_NONE;
1496     else
1497       return MATCH_UNKNOWN;
1498   }
1499 
GetFst()1500   virtual const Fst<A> &GetFst() const {
1501     return fst_;
1502   }
1503 
Properties(uint64 props)1504   virtual uint64 Properties(uint64 props) const {
1505     return props;
1506   }
1507 
1508  private:
1509   // Set the sate from which our matching happens.
SetState_(StateId s)1510   virtual void SetState_(StateId s) {
1511     if (s_ == s) return;
1512 
1513     s_ = s;
1514     tuple_ = impl_->GetStateTable()->Tuple(s_);
1515     if (tuple_.fst_state == kNoStateId) {
1516       done_ = true;
1517       return;
1518     }
1519     // Get current matcher. Used for non epsilon matching
1520     current_matcher_ = matcher_[tuple_.fst_id];
1521     current_matcher_->SetState(tuple_.fst_state);
1522     loop_.nextstate = s_;
1523 
1524     final_arc_ = false;
1525   }
1526 
1527   // Search for label, from previous set state. If label == 0, first
1528   // hallucinate and epsilon loop, else use the underlying matcher to
1529   // search for the label or epsilons.
1530   // - Note since the ReplaceFST recursion on non-terminal arcs causes
1531   //   epsilon transitions to be created we use the MultiEpsilonMatcher
1532   //   to search for possible matches of non terminals.
1533   // - If the component Fst reaches a final state we also need to add
1534   //   the exiting final arc.
Find_(Label label)1535   virtual bool Find_(Label label) {
1536     bool found = false;
1537     label_ = label;
1538     if (label_ == 0 || label_ == kNoLabel) {
1539       // Compute loop directly, saving Replace::ComputeArc
1540       if (label_ == 0) {
1541         current_loop_ = true;
1542         found = true;
1543       }
1544       // Search for matching multi epsilons
1545       final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
1546       found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1547     } else {
1548       // Search on sub machine directly using sub machine matcher.
1549       found = current_matcher_->Find(label_);
1550     }
1551     return found;
1552   }
1553 
Done_()1554   virtual bool Done_() const {
1555     return !current_loop_ && !final_arc_ && current_matcher_->Done();
1556   }
1557 
Value_()1558   virtual const Arc& Value_() const {
1559     if (current_loop_) {
1560       return loop_;
1561     }
1562     if (final_arc_) {
1563       impl_->ComputeFinalArc(tuple_, &arc_);
1564       return arc_;
1565     }
1566     const Arc& component_arc = current_matcher_->Value();
1567     impl_->ComputeArc(tuple_, component_arc, &arc_);
1568     return arc_;
1569   }
1570 
Next_()1571   virtual void Next_() {
1572     if (current_loop_) {
1573       current_loop_ = false;
1574       return;
1575     }
1576     if (final_arc_) {
1577       final_arc_ = false;
1578       return;
1579     }
1580     current_matcher_->Next();
1581   }
1582 
Priority_(StateId s)1583   virtual ssize_t Priority_(StateId s) { return fst_.NumArcs(s); }
1584 
1585   const ReplaceFst<A, T, C>& fst_;   // Fst whose NumArcs() backs Priority_
1586   ReplaceFstImpl<A, T, C> *impl_;    // Computes result arcs (ComputeArc / ComputeFinalArc)
1587   LocalMatcher* current_matcher_;    // Component matcher that Find_/Done_/Value_/Next_ delegate to
1588   vector<LocalMatcher*> matcher_;    // NOTE(review): presumably one matcher per component Fst -- confirm
1589 
1590   StateId s_;                        // Current state
1591   Label label_;                      // Label passed to the most recent Find_
1592 
1593   MatchType match_type_;             // Supplied by caller
1594   mutable bool done_;                // NOTE(review): not used by Find_/Done_/Value_/Next_ -- confirm where it is read
1595   mutable bool current_loop_;        // Current arc is the implicit loop
1596   mutable bool final_arc_;           // Current arc for exiting recursion
1597   mutable typename T::StateTuple tuple_;  // Tuple handed to ComputeArc/ComputeFinalArc; presumably the tuple of s_ -- confirm
1598   mutable Arc arc_;                  // Scratch arc filled in and returned by Value_
1599   Arc loop_;                         // Arc returned by Value_ while current_loop_ is set
1600 
1601   DISALLOW_COPY_AND_ASSIGN(ReplaceFstMatcher);
1602 };
1603 
1604 template <class A, class T, class C> inline
InitStateIterator(StateIteratorData<A> * data)1605 void ReplaceFst<A, T, C>::InitStateIterator(StateIteratorData<A> *data) const {
1606   data->base = new StateIterator< ReplaceFst<A, T, C> >(*this);
1607 }
1608 
1609 typedef ReplaceFst<StdArc> StdReplaceFst;  // ReplaceFst over StdArc; remaining template arguments take their defaults
1610 
1611 // // Recursivively replaces arcs in the root Fst with other Fsts.
1612 // This version writes the result of replacement to an output MutableFst.
1613 //
1614 // Replace supports replacement of arcs in one Fst with another
1615 // Fst. This replacement is recursive.  Replace takes an array of
1616 // Fst(s). One Fst represents the root (or topology) machine. The root
1617 // Fst refers to other Fsts by recursively replacing arcs labeled as
1618 // non-terminals with the matching non-terminal Fst. Currently Replace
1619 // uses the output symbols of the arcs to determine whether the arc is
1620 // a non-terminal arc or not. A non-terminal can be any label that is
1621 // not a non-zero terminal label in the output alphabet.  Note that
1622 // input argument is a vector of pair<>. These correspond to the tuple
1623 // of non-terminal Label and corresponding Fst.
1624 template<class Arc>
1625 void Replace(const vector<pair<typename Arc::Label,
1626              const Fst<Arc>* > >& ifst_array,
1627              MutableFst<Arc> *ofst, ReplaceFstOptions<Arc> opts =
1628              ReplaceFstOptions<Arc>()) {
1629   opts.gc = true;
1630   opts.gc_limit = 0;  // Cache only the last state for fastest copy.
1631   *ofst = ReplaceFst<Arc>(ifst_array, opts);
1632 }
1633 
1634 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,fst::ReplaceUtilOptions<Arc> opts)1635 void Replace(const vector<pair<typename Arc::Label,
1636              const Fst<Arc>* > >& ifst_array,
1637              MutableFst<Arc> *ofst, fst::ReplaceUtilOptions<Arc> opts) {
1638   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1639 }
1640 
1641 // Included for backward compatibility with 'epsilon_on_replace' arguments
1642 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root,bool epsilon_on_replace)1643 void Replace(const vector<pair<typename Arc::Label,
1644              const Fst<Arc>* > >& ifst_array,
1645              MutableFst<Arc> *ofst, typename Arc::Label root,
1646              bool epsilon_on_replace) {
1647   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1648 }
1649 
1650 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root)1651 void Replace(const vector<pair<typename Arc::Label,
1652              const Fst<Arc>* > >& ifst_array,
1653              MutableFst<Arc> *ofst, typename Arc::Label root) {
1654   Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1655 }
1656 
1657 }  // namespace fst
1658 
1659 #endif  // FST_LIB_REPLACE_H__
1660