1 // replace.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
19 // Functions and classes for the recursive replacement of Fsts.
20 //
21
22 #ifndef FST_LIB_REPLACE_H__
23 #define FST_LIB_REPLACE_H__
24
25 #include <unordered_map>
26 using std::unordered_map;
27 using std::unordered_multimap;
28 #include <set>
29 #include <string>
30 #include <utility>
31 using std::pair; using std::make_pair;
32 #include <vector>
33 using std::vector;
34
35 #include <fst/cache.h>
36 #include <fst/expanded-fst.h>
37 #include <fst/fst.h>
38 #include <fst/matcher.h>
39 #include <fst/replace-util.h>
40 #include <fst/state-table.h>
41 #include <fst/test-properties.h>
42
43 namespace fst {
44
45 //
46 // REPLACE STATE TUPLES AND TABLES
47 //
48 // The replace state table has the form
49 //
50 // template <class A, class P>
51 // class ReplaceStateTable {
52 // public:
53 // typedef A Arc;
54 // typedef P PrefixId;
55 // typedef typename A::StateId StateId;
56 // typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
57 // typedef typename A::Label Label;
58 // typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
59 //
// // Required constructor
61 // ReplaceStateTable(const vector<pair<Label, const Fst<A>*> > &fst_tuples,
62 // Label root);
63 //
64 // // Required copy constructor that does not copy state
65 // ReplaceStateTable(const ReplaceStateTable<A,P> &table);
66 //
67 // // Lookup state ID by tuple. If it doesn't exist, then add it.
68 // StateId FindState(const StateTuple &tuple);
69 //
70 // // Lookup state tuple by ID.
71 // const StateTuple &Tuple(StateId id) const;
72 //
73 // // Lookup prefix ID by stack prefix. If it doesn't exist, then add it.
74 // PrefixId FindPrefixId(const StackPrefix &stack_prefix);
75 //
// // Look up stack prefix by ID.
77 // const StackPrefix &GetStackPrefix(PrefixId id) const;
78 // };
79
80 //
81 // Replace State Tuples
82 //
83
84 // \struct ReplaceStateTuple
85 // \brief Tuple of information that uniquely defines a state in replace
86 template <class S, class P>
87 struct ReplaceStateTuple {
88 typedef S StateId;
89 typedef P PrefixId;
90
ReplaceStateTupleReplaceStateTuple91 ReplaceStateTuple()
92 : prefix_id(-1), fst_id(kNoStateId), fst_state(kNoStateId) {}
93
ReplaceStateTupleReplaceStateTuple94 ReplaceStateTuple(PrefixId p, StateId f, StateId s)
95 : prefix_id(p), fst_id(f), fst_state(s) {}
96
97 PrefixId prefix_id; // index in prefix table
98 StateId fst_id; // current fst being walked
99 StateId fst_state; // current state in fst being walked, not to be
100 // confused with the state_id of the combined fst
101 };
102
103 // Equality of replace state tuples.
104 template <class S, class P>
105 inline bool operator==(const ReplaceStateTuple<S, P>& x,
106 const ReplaceStateTuple<S, P>& y) {
107 return x.prefix_id == y.prefix_id &&
108 x.fst_id == y.fst_id &&
109 x.fst_state == y.fst_state;
110 }
111
112 // \class ReplaceRootSelector
113 // Functor returning true for tuples corresponding to states in the root FST
114 template <class S, class P>
115 class ReplaceRootSelector {
116 public:
operator()117 bool operator()(const ReplaceStateTuple<S, P> &tuple) const {
118 return tuple.prefix_id == 0;
119 }
120 };
121
122 // \class ReplaceFingerprint
123 // Fingerprint for general replace state tuples.
124 template <class S, class P>
125 class ReplaceFingerprint {
126 public:
ReplaceFingerprint(const vector<uint64> * size_array)127 explicit ReplaceFingerprint(const vector<uint64> *size_array)
128 : cumulative_size_array_(size_array) {}
129
operator()130 uint64 operator()(const ReplaceStateTuple<S, P> &tuple) const {
131 return tuple.prefix_id * (cumulative_size_array_->back()) +
132 cumulative_size_array_->at(tuple.fst_id - 1) +
133 tuple.fst_state;
134 }
135
136 private:
137 const vector<uint64> *cumulative_size_array_;
138 };
139
140 // \class ReplaceFstStateFingerprint
141 // Useful when the fst_state uniquely define the tuple.
142 template <class S, class P>
143 class ReplaceFstStateFingerprint {
144 public:
operator()145 uint64 operator()(const ReplaceStateTuple<S, P>& tuple) const {
146 return tuple.fst_state;
147 }
148 };
149
150 // \class ReplaceHash
151 // A generic hash function for replace state tuples.
152 template <typename S, typename P>
153 class ReplaceHash {
154 public:
operator()155 size_t operator()(const ReplaceStateTuple<S, P>& t) const {
156 return t.prefix_id + t.fst_id * kPrime0 + t.fst_state * kPrime1;
157 }
158 private:
159 static const size_t kPrime0;
160 static const size_t kPrime1;
161 };
162
163 template <typename S, typename P>
164 const size_t ReplaceHash<S, P>::kPrime0 = 7853;
165
166 template <typename S, typename P>
167 const size_t ReplaceHash<S, P>::kPrime1 = 7867;
168
169
170 //
171 // Replace Stack Prefix
172 //
173
174 // \class ReplaceStackPrefix
175 // \brief Container for stack prefix.
176 template <class L, class S>
177 class ReplaceStackPrefix {
178 public:
179 typedef L Label;
180 typedef S StateId;
181
182 // \class PrefixTuple
183 // \brief Tuple of fst_id and destination state (entry in stack prefix)
184 struct PrefixTuple {
PrefixTuplePrefixTuple185 PrefixTuple(Label f, StateId s) : fst_id(f), nextstate(s) {}
PrefixTuplePrefixTuple186 PrefixTuple() : fst_id(kNoLabel), nextstate(kNoStateId) {}
187
188 Label fst_id;
189 StateId nextstate;
190 };
191
ReplaceStackPrefix()192 ReplaceStackPrefix() {}
193
194 // copy constructor
ReplaceStackPrefix(const ReplaceStackPrefix & x)195 ReplaceStackPrefix(const ReplaceStackPrefix& x) :
196 prefix_(x.prefix_) {
197 }
198
Push(StateId fst_id,StateId nextstate)199 void Push(StateId fst_id, StateId nextstate) {
200 prefix_.push_back(PrefixTuple(fst_id, nextstate));
201 }
202
Pop()203 void Pop() {
204 prefix_.pop_back();
205 }
206
Top()207 const PrefixTuple& Top() const {
208 return prefix_[prefix_.size()-1];
209 }
210
Depth()211 size_t Depth() const {
212 return prefix_.size();
213 }
214
215 public:
216 vector<PrefixTuple> prefix_;
217 };
218
219 // Equality stack prefix classes
220 template <class L, class S>
221 inline bool operator==(const ReplaceStackPrefix<L, S>& x,
222 const ReplaceStackPrefix<L, S>& y) {
223 if (x.prefix_.size() != y.prefix_.size()) return false;
224 for (size_t i = 0; i < x.prefix_.size(); ++i) {
225 if (x.prefix_[i].fst_id != y.prefix_[i].fst_id ||
226 x.prefix_[i].nextstate != y.prefix_[i].nextstate) return false;
227 }
228 return true;
229 }
230
231 //
232 // \class ReplaceStackPrefixHash
233 // \brief Hash function for stack prefix to prefix id
234 template <class L, class S>
235 class ReplaceStackPrefixHash {
236 public:
operator()237 size_t operator()(const ReplaceStackPrefix<L, S>& x) const {
238 size_t sum = 0;
239 for (size_t i = 0; i < x.prefix_.size(); ++i) {
240 sum += x.prefix_[i].fst_id + x.prefix_[i].nextstate*kPrime0;
241 }
242 return sum;
243 }
244
245 private:
246 static const size_t kPrime0;
247 };
248
249 template <class L, class S>
250 const size_t ReplaceStackPrefixHash<L, S>::kPrime0 = 7853;
251
252 //
253 // Replace State Tables
254 //
255
256 // \class VectorHashReplaceStateTable
257 // A two-level state table for replace.
258 // Warning: calls CountStates to compute the number of states of each
259 // component Fst.
260 template <class A, class P = ssize_t>
261 class VectorHashReplaceStateTable {
262 public:
263 typedef A Arc;
264 typedef typename A::StateId StateId;
265 typedef typename A::Label Label;
266 typedef P PrefixId;
267 typedef ReplaceStateTuple<StateId, P> StateTuple;
268 typedef VectorHashStateTable<ReplaceStateTuple<StateId, P>,
269 ReplaceRootSelector<StateId, P>,
270 ReplaceFstStateFingerprint<StateId, P>,
271 ReplaceFingerprint<StateId, P> > StateTable;
272 typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
273 typedef CompactHashBiTable<
274 PrefixId, StackPrefix,
275 ReplaceStackPrefixHash<Label, StateId> > StackPrefixTable;
276
VectorHashReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)277 VectorHashReplaceStateTable(
278 const vector<pair<Label, const Fst<A>*> > &fst_tuples,
279 Label root) : root_size_(0) {
280 cumulative_size_array_.push_back(0);
281 for (size_t i = 0; i < fst_tuples.size(); ++i) {
282 if (fst_tuples[i].first == root) {
283 root_size_ = CountStates(*(fst_tuples[i].second));
284 cumulative_size_array_.push_back(cumulative_size_array_.back());
285 } else {
286 cumulative_size_array_.push_back(cumulative_size_array_.back() +
287 CountStates(*(fst_tuples[i].second)));
288 }
289 }
290 state_table_ = new StateTable(
291 new ReplaceRootSelector<StateId, P>,
292 new ReplaceFstStateFingerprint<StateId, P>,
293 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
294 root_size_,
295 root_size_ + cumulative_size_array_.back());
296 }
297
VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A,P> & table)298 VectorHashReplaceStateTable(const VectorHashReplaceStateTable<A, P> &table)
299 : root_size_(table.root_size_),
300 cumulative_size_array_(table.cumulative_size_array_),
301 prefix_table_(table.prefix_table_) {
302 state_table_ = new StateTable(
303 new ReplaceRootSelector<StateId, P>,
304 new ReplaceFstStateFingerprint<StateId, P>,
305 new ReplaceFingerprint<StateId, P>(&cumulative_size_array_),
306 root_size_,
307 root_size_ + cumulative_size_array_.back());
308 }
309
~VectorHashReplaceStateTable()310 ~VectorHashReplaceStateTable() {
311 delete state_table_;
312 }
313
FindState(const StateTuple & tuple)314 StateId FindState(const StateTuple &tuple) {
315 return state_table_->FindState(tuple);
316 }
317
Tuple(StateId id)318 const StateTuple &Tuple(StateId id) const {
319 return state_table_->Tuple(id);
320 }
321
FindPrefixId(const StackPrefix & prefix)322 PrefixId FindPrefixId(const StackPrefix &prefix) {
323 return prefix_table_.FindId(prefix);
324 }
325
GetStackPrefix(PrefixId id)326 const StackPrefix &GetStackPrefix(PrefixId id) const {
327 return prefix_table_.FindEntry(id);
328 }
329
330 private:
331 StateId root_size_;
332 vector<uint64> cumulative_size_array_;
333 StateTable *state_table_;
334 StackPrefixTable prefix_table_;
335 };
336
337
338 // \class DefaultReplaceStateTable
339 // Default replace state table
340 template <class A, class P = ssize_t>
341 class DefaultReplaceStateTable : public CompactHashStateTable<
342 ReplaceStateTuple<typename A::StateId, P>,
343 ReplaceHash<typename A::StateId, P> > {
344 public:
345 typedef A Arc;
346 typedef typename A::StateId StateId;
347 typedef typename A::Label Label;
348 typedef P PrefixId;
349 typedef ReplaceStateTuple<StateId, P> StateTuple;
350 typedef CompactHashStateTable<StateTuple,
351 ReplaceHash<StateId, PrefixId> > StateTable;
352 typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
353 typedef CompactHashBiTable<
354 PrefixId, StackPrefix,
355 ReplaceStackPrefixHash<Label, StateId> > StackPrefixTable;
356
357 using StateTable::FindState;
358 using StateTable::Tuple;
359
DefaultReplaceStateTable(const vector<pair<Label,const Fst<A> * >> & fst_tuples,Label root)360 DefaultReplaceStateTable(
361 const vector<pair<Label, const Fst<A>*> > &fst_tuples,
362 Label root) {}
363
DefaultReplaceStateTable(const DefaultReplaceStateTable<A,P> & table)364 DefaultReplaceStateTable(const DefaultReplaceStateTable<A, P> &table)
365 : StateTable(), prefix_table_(table.prefix_table_) {}
366
FindPrefixId(const StackPrefix & prefix)367 PrefixId FindPrefixId(const StackPrefix &prefix) {
368 return prefix_table_.FindId(prefix);
369 }
370
GetStackPrefix(PrefixId id)371 const StackPrefix &GetStackPrefix(PrefixId id) const {
372 return prefix_table_.FindEntry(id);
373 }
374
375 private:
376 StackPrefixTable prefix_table_;
377 };
378
379 //
380 // REPLACE FST CLASS
381 //
382
383 // By default ReplaceFst will copy the input label of the 'replace arc'.
384 // The call_label_type and return_label_type options specify how to manage
385 // the labels of the call arc and the return arc of the replace FST
386 template <class A, class T = DefaultReplaceStateTable<A>,
387 class C = DefaultCacheStore<A> >
388 struct ReplaceFstOptions : CacheImplOptions<C> {
389 int64 root; // root rule for expansion
390 ReplaceLabelType call_label_type; // how to label call arc
391 ReplaceLabelType return_label_type; // how to label return arc
392 int64 call_output_label; // specifies output label to put on call arc
393 // if kNoLabel, use existing label on call arc
394 // if 0, epsilon
395 // otherwise, use this field as the output label
396 int64 return_label; // specifies label to put on return arc
397 bool take_ownership; // take ownership of input Fst(s)
398 T* state_table;
399
ReplaceFstOptionsReplaceFstOptions400 ReplaceFstOptions(const CacheImplOptions<C> &opts, int64 r)
401 : CacheImplOptions<C>(opts),
402 root(r),
403 call_label_type(REPLACE_LABEL_INPUT),
404 return_label_type(REPLACE_LABEL_NEITHER),
405 call_output_label(kNoLabel),
406 return_label(0),
407 take_ownership(false),
408 state_table(0) {}
409
ReplaceFstOptionsReplaceFstOptions410 ReplaceFstOptions(const CacheOptions &opts, int64 r)
411 : CacheImplOptions<C>(opts),
412 root(r),
413 call_label_type(REPLACE_LABEL_INPUT),
414 return_label_type(REPLACE_LABEL_NEITHER),
415 call_output_label(kNoLabel),
416 return_label(0),
417 take_ownership(false),
418 state_table(0) {}
419
ReplaceFstOptionsReplaceFstOptions420 explicit ReplaceFstOptions(const fst::ReplaceUtilOptions<A> &opts)
421 : root(opts.root),
422 call_label_type(opts.call_label_type),
423 return_label_type(opts.return_label_type),
424 call_output_label(kNoLabel),
425 return_label(opts.return_label),
426 take_ownership(false),
427 state_table(0) {}
428
ReplaceFstOptionsReplaceFstOptions429 explicit ReplaceFstOptions(int64 r)
430 : root(r),
431 call_label_type(REPLACE_LABEL_INPUT),
432 return_label_type(REPLACE_LABEL_NEITHER),
433 call_output_label(kNoLabel),
434 return_label(0),
435 take_ownership(false),
436 state_table(0) {}
437
ReplaceFstOptionsReplaceFstOptions438 ReplaceFstOptions(int64 r, ReplaceLabelType call_label_type,
439 ReplaceLabelType return_label_type, int64 return_label)
440 : root(r),
441 call_label_type(call_label_type),
442 return_label_type(return_label_type),
443 call_output_label(kNoLabel),
444 return_label(return_label),
445 take_ownership(false),
446 state_table(0) {}
447
ReplaceFstOptionsReplaceFstOptions448 ReplaceFstOptions(int64 r, ReplaceLabelType call_label_type,
449 ReplaceLabelType return_label_type, int64 call_output_label,
450 int64 return_label)
451 : root(r),
452 call_label_type(call_label_type),
453 return_label_type(return_label_type),
454 call_output_label(call_output_label),
455 return_label(return_label),
456 take_ownership(false),
457 state_table(0) {}
458
ReplaceFstOptionsReplaceFstOptions459 ReplaceFstOptions(int64 r, bool epsilon_replace_arc) // b/w compatibility
460 : root(r),
461 call_label_type((epsilon_replace_arc) ? REPLACE_LABEL_NEITHER :
462 REPLACE_LABEL_INPUT),
463 return_label_type(REPLACE_LABEL_NEITHER),
464 call_output_label((epsilon_replace_arc) ? 0 : kNoLabel),
465 return_label(0),
466 take_ownership(false),
467 state_table(0) {}
468
ReplaceFstOptionsReplaceFstOptions469 ReplaceFstOptions()
470 : root(kNoLabel),
471 call_label_type(REPLACE_LABEL_INPUT),
472 return_label_type(REPLACE_LABEL_NEITHER),
473 call_output_label(kNoLabel),
474 return_label(0),
475 take_ownership(false),
476 state_table(0) {}
477 };
478
479 // Forward declaration
480 template <class A, class T, class C> class ReplaceFstMatcher;
481
482 // \class ReplaceFstImpl
483 // \brief Implementation class for replace class Fst
484 //
// The replace implementation class supports the dynamic
// expansion of a recursive transition network represented as an Fst
// with dynamically replaceable arcs.
488 //
489 template <class A, class T, class C>
490 class ReplaceFstImpl : public CacheBaseImpl<typename C::State, C> {
491 friend class ReplaceFstMatcher<A, T, C>;
492
493 public:
494 typedef A Arc;
495 typedef typename A::Label Label;
496 typedef typename A::Weight Weight;
497 typedef typename A::StateId StateId;
498 typedef typename C::State State;
499 typedef CacheBaseImpl<State, C> CImpl;
500 typedef unordered_map<Label, Label> NonTerminalHash;
501 typedef T StateTable;
502 typedef typename T::PrefixId PrefixId;
503 typedef ReplaceStateTuple<StateId, PrefixId> StateTuple;
504 typedef ReplaceStackPrefix<Label, StateId> StackPrefix;
505
506 using FstImpl<A>::SetType;
507 using FstImpl<A>::SetProperties;
508 using FstImpl<A>::WriteHeader;
509 using FstImpl<A>::SetInputSymbols;
510 using FstImpl<A>::SetOutputSymbols;
511 using FstImpl<A>::InputSymbols;
512 using FstImpl<A>::OutputSymbols;
513
514 using CImpl::PushArc;
515 using CImpl::HasArcs;
516 using CImpl::HasFinal;
517 using CImpl::HasStart;
518 using CImpl::SetArcs;
519 using CImpl::SetFinal;
520 using CImpl::SetStart;
521
522 // constructor for replace class implementation.
523 // \param fst_tuples array of label/fst tuples, one for each non-terminal
ReplaceFstImpl(const vector<pair<Label,const Fst<A> * >> & fst_tuples,const ReplaceFstOptions<A,T,C> & opts)524 ReplaceFstImpl(const vector< pair<Label, const Fst<A>* > >& fst_tuples,
525 const ReplaceFstOptions<A, T, C> &opts)
526 : CImpl(opts),
527 call_label_type_(opts.call_label_type),
528 return_label_type_(opts.return_label_type),
529 call_output_label_(opts.call_output_label),
530 return_label_(opts.return_label),
531 state_table_(opts.state_table ? opts.state_table :
532 new StateTable(fst_tuples, opts.root)) {
533 SetType("replace");
534
535 // if the label is epsilon, then all REPLACE_LABEL_* options equivalent.
536 // Set the label_type to NEITHER for ease of setting properties later
537 if (call_output_label_ == 0)
538 call_label_type_ = REPLACE_LABEL_NEITHER;
539 if (return_label_ == 0)
540 return_label_type_ = REPLACE_LABEL_NEITHER;
541
542 if (fst_tuples.size() > 0) {
543 SetInputSymbols(fst_tuples[0].second->InputSymbols());
544 SetOutputSymbols(fst_tuples[0].second->OutputSymbols());
545 }
546
547 bool all_negative = true; // all nonterminals are negative?
548 bool dense_range = true; // all nonterminals are positive
549 // and form a dense range containing 1?
550 for (size_t i = 0; i < fst_tuples.size(); ++i) {
551 Label nonterminal = fst_tuples[i].first;
552 if (nonterminal >= 0)
553 all_negative = false;
554 if (nonterminal > fst_tuples.size() || nonterminal <= 0)
555 dense_range = false;
556 }
557
558 vector<uint64> inprops;
559 bool all_ilabel_sorted = true;
560 bool all_olabel_sorted = true;
561 bool all_non_empty = true;
562 fst_array_.push_back(0);
563 for (size_t i = 0; i < fst_tuples.size(); ++i) {
564 Label label = fst_tuples[i].first;
565 const Fst<A> *fst = fst_tuples[i].second;
566 nonterminal_hash_[label] = fst_array_.size();
567 nonterminal_set_.insert(label);
568 fst_array_.push_back(opts.take_ownership ? fst : fst->Copy());
569 if (fst->Start() == kNoStateId)
570 all_non_empty = false;
571 if (!fst->Properties(kILabelSorted, false))
572 all_ilabel_sorted = false;
573 if (!fst->Properties(kOLabelSorted, false))
574 all_olabel_sorted = false;
575 inprops.push_back(fst->Properties(kCopyProperties, false));
576 if (i) {
577 if (!CompatSymbols(InputSymbols(), fst->InputSymbols())) {
578 FSTERROR() << "ReplaceFstImpl: input symbols of Fst " << i
579 << " does not match input symbols of base Fst (0'th fst)";
580 SetProperties(kError, kError);
581 }
582 if (!CompatSymbols(OutputSymbols(), fst->OutputSymbols())) {
583 FSTERROR() << "ReplaceFstImpl: output symbols of Fst " << i
584 << " does not match output symbols of base Fst "
585 << "(0'th fst)";
586 SetProperties(kError, kError);
587 }
588 }
589 }
590 Label nonterminal = nonterminal_hash_[opts.root];
591 if ((nonterminal == 0) && (fst_array_.size() > 1)) {
592 FSTERROR() << "ReplaceFstImpl: no Fst corresponding to root label '"
593 << opts.root << "' in the input tuple vector";
594 SetProperties(kError, kError);
595 }
596 root_ = (nonterminal > 0) ? nonterminal : 1;
597
598 SetProperties(ReplaceProperties(inprops, root_ - 1,
599 EpsilonOnInput(call_label_type_),
600 EpsilonOnInput(return_label_type_),
601 ReplaceTransducer(), all_non_empty));
602 // We assume that all terminals are positive. The resulting
603 // ReplaceFst is known to be kILabelSorted when: (1) all sub-FSTs are
604 // kILabelSorted, (2) the input label of the return arc is epsilon,
605 // and (3) one of the 3 following conditions is satisfied:
606 // 1. the input label of the call arc is not epsilon
607 // 2. all non-terminals are negative, or
608 // 3. all non-terninals are positive and form a dense range containing 1.
609 if (all_ilabel_sorted && EpsilonOnInput(return_label_type_) &&
610 (!EpsilonOnInput(call_label_type_) || all_negative || dense_range)) {
611 SetProperties(kILabelSorted, kILabelSorted);
612 }
613 // Similarly, the resulting ReplaceFst is known to be
614 // kOLabelSorted when: (1) all sub-FSTs are kOLabelSorted, (2) the output
615 // label of the return arc is epsilon, and (3) one of the 3 following
616 // conditions is satisfied:
617 // 1. the output label of the call arc is not epsilon
618 // 2. all non-terminals are negative, or
619 // 3. all non-terninals are positive and form a dense range containing 1.
620 if (all_olabel_sorted && EpsilonOnOutput(return_label_type_) &&
621 (!EpsilonOnOutput(call_label_type_) || all_negative || dense_range))
622 SetProperties(kOLabelSorted, kOLabelSorted);
623
624 // Enable optional caching as long as sorted and all non empty.
625 if (Properties(kILabelSorted | kOLabelSorted) && all_non_empty)
626 always_cache_ = false;
627 else
628 always_cache_ = true;
629 VLOG(2) << "ReplaceFstImpl::ReplaceFstImpl: always_cache = "
630 << (always_cache_ ? "true" : "false");
631 }
632
ReplaceFstImpl(const ReplaceFstImpl & impl)633 ReplaceFstImpl(const ReplaceFstImpl& impl)
634 : CImpl(impl),
635 call_label_type_(impl.call_label_type_),
636 return_label_type_(impl.return_label_type_),
637 call_output_label_(impl.call_output_label_),
638 return_label_(impl.return_label_),
639 always_cache_(impl.always_cache_),
640 state_table_(new StateTable(*(impl.state_table_))),
641 nonterminal_set_(impl.nonterminal_set_),
642 nonterminal_hash_(impl.nonterminal_hash_),
643 root_(impl.root_) {
644 SetType("replace");
645 SetProperties(impl.Properties(), kCopyProperties);
646 SetInputSymbols(impl.InputSymbols());
647 SetOutputSymbols(impl.OutputSymbols());
648 fst_array_.reserve(impl.fst_array_.size());
649 fst_array_.push_back(0);
650 for (size_t i = 1; i < impl.fst_array_.size(); ++i) {
651 fst_array_.push_back(impl.fst_array_[i]->Copy(true));
652 }
653 }
654
// Releases the state table and every component FST. Slot 0 of fst_array_
// is a null placeholder and is skipped.
~ReplaceFstImpl() {
  delete state_table_;
  const size_t num_slots = fst_array_.size();
  for (size_t slot = 1; slot < num_slots; ++slot)
    delete fst_array_[slot];
}
661
662 // Computes the dependency graph of the replace class and returns
663 // true if the dependencies are cyclic. Cyclic dependencies will result
664 // in an un-expandable replace fst.
CyclicDependencies()665 bool CyclicDependencies() const {
666 ReplaceUtil<A> replace_util(fst_array_, nonterminal_hash_,
667 fst::ReplaceUtilOptions<A>(root_));
668 return replace_util.CyclicDependencies();
669 }
670
671 // Returns or computes start state of replace fst.
Start()672 StateId Start() {
673 if (!HasStart()) {
674 if (fst_array_.size() == 1) { // no fsts defined for replace
675 SetStart(kNoStateId);
676 return kNoStateId;
677 } else {
678 const Fst<A>* fst = fst_array_[root_];
679 StateId fst_start = fst->Start();
680 if (fst_start == kNoStateId) // root Fst is empty
681 return kNoStateId;
682
683 PrefixId prefix = GetPrefixId(StackPrefix());
684 StateId start = state_table_->FindState(
685 StateTuple(prefix, root_, fst_start));
686 SetStart(start);
687 return start;
688 }
689 } else {
690 return CImpl::Start();
691 }
692 }
693
694 // Returns final weight of state (Weight::Zero() means state is not final).
Final(StateId s)695 Weight Final(StateId s) {
696 if (HasFinal(s)) {
697 return CImpl::Final(s);
698 } else {
699 const StateTuple& tuple = state_table_->Tuple(s);
700 Weight final = Weight::Zero();
701
702 if (tuple.prefix_id == 0) {
703 const Fst<A>* fst = fst_array_[tuple.fst_id];
704 StateId fst_state = tuple.fst_state;
705 final = fst->Final(fst_state);
706 }
707
708 if (always_cache_ || HasArcs(s))
709 SetFinal(s, final);
710 return final;
711 }
712 }
713
NumArcs(StateId s)714 size_t NumArcs(StateId s) {
715 if (HasArcs(s)) { // If state cached, use the cached value.
716 return CImpl::NumArcs(s);
717 } else if (always_cache_) { // If always caching, expand and cache state.
718 Expand(s);
719 return CImpl::NumArcs(s);
720 } else { // Otherwise compute the number of arcs without expanding.
721 StateTuple tuple = state_table_->Tuple(s);
722 if (tuple.fst_state == kNoStateId)
723 return 0;
724
725 const Fst<A>* fst = fst_array_[tuple.fst_id];
726 size_t num_arcs = fst->NumArcs(tuple.fst_state);
727 if (ComputeFinalArc(tuple, 0))
728 num_arcs++;
729
730 return num_arcs;
731 }
732 }
733
734 // Returns whether a given label is a non terminal
IsNonTerminal(Label l)735 bool IsNonTerminal(Label l) const {
736 if (l < *nonterminal_set_.begin() || l > *nonterminal_set_.rbegin())
737 return false;
738 // TODO(allauzen): be smarter and take advantage of
739 // all_dense or all_negative.
740 // Use also in ComputeArc, this would require changes to replace
741 // so that recursing into an empty fst lead to a non co-accessible
742 // state instead of deleting the arc as done currently.
743 // Current use correct, since i/olabel sorted iff all_non_empty.
744 typename NonTerminalHash::const_iterator it =
745 nonterminal_hash_.find(l);
746 return it != nonterminal_hash_.end();
747 }
748
NumInputEpsilons(StateId s)749 size_t NumInputEpsilons(StateId s) {
750 if (HasArcs(s)) {
751 // If state cached, use the cached value.
752 return CImpl::NumInputEpsilons(s);
753 } else if (always_cache_ || !Properties(kILabelSorted)) {
754 // If always caching or if the number of input epsilons is too expensive
755 // to compute without caching (i.e. not ilabel sorted),
756 // then expand and cache state.
757 Expand(s);
758 return CImpl::NumInputEpsilons(s);
759 } else {
760 // Otherwise, compute the number of input epsilons without caching.
761 StateTuple tuple = state_table_->Tuple(s);
762 if (tuple.fst_state == kNoStateId)
763 return 0;
764 const Fst<A>* fst = fst_array_[tuple.fst_id];
765 size_t num = 0;
766 if (!EpsilonOnInput(call_label_type_)) {
767 // If EpsilonOnInput(c) is false, all input epsilon arcs
768 // are also input epsilons arcs in the underlying machine.
769 num = fst->NumInputEpsilons(tuple.fst_state);
770 } else {
771 // Otherwise, one need to consider that all non-terminal arcs
772 // in the underlying machine also become input epsilon arc.
773 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
774 for (; !aiter.Done() &&
775 ((aiter.Value().ilabel == 0) ||
776 IsNonTerminal(aiter.Value().olabel));
777 aiter.Next())
778 ++num;
779 }
780 if (EpsilonOnInput(return_label_type_) && ComputeFinalArc(tuple, 0))
781 num++;
782 return num;
783 }
784 }
785
NumOutputEpsilons(StateId s)786 size_t NumOutputEpsilons(StateId s) {
787 if (HasArcs(s)) {
788 // If state cached, use the cached value.
789 return CImpl::NumOutputEpsilons(s);
790 } else if (always_cache_ || !Properties(kOLabelSorted)) {
791 // If always caching or if the number of output epsilons is too expensive
792 // to compute without caching (i.e. not olabel sorted),
793 // then expand and cache state.
794 Expand(s);
795 return CImpl::NumOutputEpsilons(s);
796 } else {
797 // Otherwise, compute the number of output epsilons without caching.
798 StateTuple tuple = state_table_->Tuple(s);
799 if (tuple.fst_state == kNoStateId)
800 return 0;
801 const Fst<A>* fst = fst_array_[tuple.fst_id];
802 size_t num = 0;
803 if (!EpsilonOnOutput(call_label_type_)) {
804 // If EpsilonOnOutput(c) is false, all output epsilon arcs
805 // are also output epsilons arcs in the underlying machine.
806 num = fst->NumOutputEpsilons(tuple.fst_state);
807 } else {
808 // Otherwise, one need to consider that all non-terminal arcs
809 // in the underlying machine also become output epsilon arc.
810 ArcIterator<Fst<A> > aiter(*fst, tuple.fst_state);
811 for (; !aiter.Done() &&
812 ((aiter.Value().olabel == 0) ||
813 IsNonTerminal(aiter.Value().olabel));
814 aiter.Next())
815 ++num;
816 }
817 if (EpsilonOnOutput(return_label_type_) && ComputeFinalArc(tuple, 0))
818 num++;
819 return num;
820 }
821 }
822
Properties()823 uint64 Properties() const { return Properties(kFstProperties); }
824
825 // Set error if found; return FST impl properties.
Properties(uint64 mask)826 uint64 Properties(uint64 mask) const {
827 if (mask & kError) {
828 for (size_t i = 1; i < fst_array_.size(); ++i) {
829 if (fst_array_[i]->Properties(kError, false))
830 SetProperties(kError, kError);
831 }
832 }
833 return FstImpl<Arc>::Properties(mask);
834 }
835
836 // return the base arc iterator, if arcs have not been computed yet,
837 // extend/recurse for new arcs.
InitArcIterator(StateId s,ArcIteratorData<A> * data)838 void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
839 if (!HasArcs(s))
840 Expand(s);
841 CImpl::InitArcIterator(s, data);
842 // TODO(allauzen): Set behaviour of generic iterator
843 // Warning: ArcIterator<ReplaceFst<A> >::InitCache()
844 // relies on current behaviour.
845 }
846
847
848 // Extend current state (walk arcs one level deep)
Expand(StateId s)849 void Expand(StateId s) {
850 StateTuple tuple = state_table_->Tuple(s);
851
852 // If local fst is empty
853 if (tuple.fst_state == kNoStateId) {
854 SetArcs(s);
855 return;
856 }
857
858 ArcIterator< Fst<A> > aiter(
859 *(fst_array_[tuple.fst_id]), tuple.fst_state);
860 Arc arc;
861
862 // Create a final arc when needed
863 if (ComputeFinalArc(tuple, &arc))
864 PushArc(s, arc);
865
866 // Expand all arcs leaving the state
867 for (; !aiter.Done(); aiter.Next()) {
868 if (ComputeArc(tuple, aiter.Value(), &arc))
869 PushArc(s, arc);
870 }
871
872 SetArcs(s);
873 }
874
Expand(StateId s,const StateTuple & tuple,const ArcIteratorData<A> & data)875 void Expand(StateId s, const StateTuple &tuple,
876 const ArcIteratorData<A> &data) {
877 // If local fst is empty
878 if (tuple.fst_state == kNoStateId) {
879 SetArcs(s);
880 return;
881 }
882
883 ArcIterator< Fst<A> > aiter(data);
884 Arc arc;
885
886 // Create a final arc when needed
887 if (ComputeFinalArc(tuple, &arc))
888 AddArc(s, arc);
889
890 // Expand all arcs leaving the state
891 for (; !aiter.Done(); aiter.Next()) {
892 if (ComputeArc(tuple, aiter.Value(), &arc))
893 AddArc(s, arc);
894 }
895
896 SetArcs(s);
897 }
898
899 // If arcp == 0, only returns if a final arc is required, does not
900 // actually compute it.
901 bool ComputeFinalArc(const StateTuple &tuple, A* arcp,
902 uint32 flags = kArcValueFlags) {
903 const Fst<A>* fst = fst_array_[tuple.fst_id];
904 StateId fst_state = tuple.fst_state;
905 if (fst_state == kNoStateId)
906 return false;
907
908 // if state is final, pop up stack
909 if (fst->Final(fst_state) != Weight::Zero() && tuple.prefix_id) {
910 if (arcp) {
911 arcp->ilabel = (EpsilonOnInput(return_label_type_)) ? 0 : return_label_;
912 arcp->olabel =
913 (EpsilonOnOutput(return_label_type_)) ? 0 : return_label_;
914 if (flags & kArcNextStateValue) {
915 const StackPrefix& stack = state_table_->GetStackPrefix(
916 tuple.prefix_id);
917 PrefixId prefix_id = PopPrefix(stack);
918 const typename StackPrefix::PrefixTuple& top = stack.Top();
919 arcp->nextstate = state_table_->FindState(
920 StateTuple(prefix_id, top.fst_id, top.nextstate));
921 }
922 if (flags & kArcWeightValue)
923 arcp->weight = fst->Final(fst_state);
924 }
925 return true;
926 } else {
927 return false;
928 }
929 }
930
// Compute the arc in the replace fst corresponding to a given arc
// in the underlying machine. Returns false if the underlying arc
// corresponds to no arc in the replace fst.
934 bool ComputeArc(const StateTuple &tuple, const A &arc, A* arcp,
935 uint32 flags = kArcValueFlags) {
936 if (!EpsilonOnInput(call_label_type_) &&
937 (flags == (flags & (kArcILabelValue | kArcWeightValue)))) {
938 *arcp = arc;
939 return true;
940 }
941
942 if (arc.olabel == 0 ||
943 arc.olabel < *nonterminal_set_.begin() ||
944 arc.olabel > *nonterminal_set_.rbegin()) { // expand local fst
945 StateId nextstate = flags & kArcNextStateValue
946 ? state_table_->FindState(
947 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
948 : kNoStateId;
949 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
950 } else {
951 // check for non terminal
952 typename NonTerminalHash::const_iterator it =
953 nonterminal_hash_.find(arc.olabel);
954 if (it != nonterminal_hash_.end()) { // recurse into non terminal
955 Label nonterminal = it->second;
956 const Fst<A>* nt_fst = fst_array_[nonterminal];
957 PrefixId nt_prefix = PushPrefix(
958 state_table_->GetStackPrefix(tuple.prefix_id),
959 tuple.fst_id, arc.nextstate);
960
961 // if start state is valid replace, else arc is implicitly
962 // deleted
963 StateId nt_start = nt_fst->Start();
964 if (nt_start != kNoStateId) {
965 StateId nt_nextstate = flags & kArcNextStateValue
966 ? state_table_->FindState(
967 StateTuple(nt_prefix, nonterminal, nt_start))
968 : kNoStateId;
969 Label ilabel = (EpsilonOnInput(call_label_type_)) ? 0 : arc.ilabel;
970 Label olabel = (EpsilonOnOutput(call_label_type_)) ?
971 0 : ((call_output_label_ == kNoLabel) ?
972 arc.olabel : call_output_label_);
973 *arcp = A(ilabel, olabel, arc.weight, nt_nextstate);
974 } else {
975 return false;
976 }
977 } else {
978 StateId nextstate = flags & kArcNextStateValue
979 ? state_table_->FindState(
980 StateTuple(tuple.prefix_id, tuple.fst_id, arc.nextstate))
981 : kNoStateId;
982 *arcp = A(arc.ilabel, arc.olabel, arc.weight, nextstate);
983 }
984 }
985 return true;
986 }
987
988 // Returns the arc iterator flags supported by this Fst.
ArcIteratorFlags()989 uint32 ArcIteratorFlags() const {
990 uint32 flags = kArcValueFlags;
991 if (!always_cache_)
992 flags |= kArcNoCache;
993 return flags;
994 }
995
GetStateTable()996 T* GetStateTable() const {
997 return state_table_;
998 }
999
GetFst(Label fst_id)1000 const Fst<A>* GetFst(Label fst_id) const {
1001 return fst_array_[fst_id];
1002 }
1003
GetFstId(Label nonterminal)1004 Label GetFstId(Label nonterminal) const {
1005 typename NonTerminalHash::const_iterator it =
1006 nonterminal_hash_.find(nonterminal);
1007 if (it == nonterminal_hash_.end()) {
1008 FSTERROR() << "ReplaceFstImpl::GetFstId: nonterminal not found: "
1009 << nonterminal;
1010 }
1011 return it->second;
1012 }
1013
1014 // returns true if label type on call arc results in epsilon input label
EpsilonOnCallInput()1015 bool EpsilonOnCallInput() {
1016 return EpsilonOnInput(call_label_type_);
1017 }
1018
1019 // private methods
1020 private:
1021 // hash stack prefix (return unique index into stackprefix table)
GetPrefixId(const StackPrefix & prefix)1022 PrefixId GetPrefixId(const StackPrefix& prefix) {
1023 return state_table_->FindPrefixId(prefix);
1024 }
1025
1026 // prefix id after a stack pop
PopPrefix(StackPrefix prefix)1027 PrefixId PopPrefix(StackPrefix prefix) {
1028 prefix.Pop();
1029 return GetPrefixId(prefix);
1030 }
1031
1032 // prefix id after a stack push
PushPrefix(StackPrefix prefix,Label fst_id,StateId nextstate)1033 PrefixId PushPrefix(StackPrefix prefix, Label fst_id, StateId nextstate) {
1034 prefix.Push(fst_id, nextstate);
1035 return GetPrefixId(prefix);
1036 }
1037
1038 // returns true if label type on arc results in epsilon input label
EpsilonOnInput(ReplaceLabelType label_type)1039 bool EpsilonOnInput(ReplaceLabelType label_type) {
1040 if (label_type == REPLACE_LABEL_NEITHER ||
1041 label_type == REPLACE_LABEL_OUTPUT) return true;
1042 return false;
1043 }
1044
// returns true if label type on arc results in epsilon output label
EpsilonOnOutput(ReplaceLabelType label_type)1046 bool EpsilonOnOutput(ReplaceLabelType label_type) {
1047 if (label_type == REPLACE_LABEL_NEITHER ||
1048 label_type == REPLACE_LABEL_INPUT) return true;
1049 return false;
1050 }
1051
1052 // returns true if for either the call or return arc ilabel != olabel
ReplaceTransducer()1053 bool ReplaceTransducer() {
1054 if (call_label_type_ == REPLACE_LABEL_INPUT ||
1055 call_label_type_ == REPLACE_LABEL_OUTPUT ||
1056 (call_label_type_ == REPLACE_LABEL_BOTH &&
1057 call_output_label_ != kNoLabel) ||
1058 return_label_type_ == REPLACE_LABEL_INPUT ||
1059 return_label_type_ == REPLACE_LABEL_OUTPUT) return true;
1060 return false;
1061 }
1062
1063 // private data
1064 private:
1065 // runtime options
1066 ReplaceLabelType call_label_type_; // how to label call arc
1067 ReplaceLabelType return_label_type_; // how to label return arc
1068 int64 call_output_label_; // specifies output label to put on call arc
1069 int64 return_label_; // specifies label to put on return arc
1070 bool always_cache_; // Optionally caching arc iterator disabled when true
1071
1072 // state table
1073 StateTable *state_table_;
1074
1075 // replace components
1076 set<Label> nonterminal_set_;
1077 NonTerminalHash nonterminal_hash_;
1078 vector<const Fst<A>*> fst_array_;
1079 Label root_;
1080
1081 void operator=(const ReplaceFstImpl<A, T, C> &); // disallow
1082 };
1083
1084
1085 //
1086 // \class ReplaceFst
// \brief Recursively replaces arcs in the root Fst with other Fsts.
1088 // This version is a delayed Fst.
1089 //
1090 // ReplaceFst supports dynamic replacement of arcs in one Fst with
1091 // another Fst. This replacement is recursive. ReplaceFst can be used
1092 // to support a variety of delayed constructions such as recursive
1093 // transition networks, union, or closure. It is constructed with an
1094 // array of Fst(s). One Fst represents the root (or topology)
1095 // machine. The root Fst refers to other Fsts by recursively replacing
1096 // arcs labeled as non-terminals with the matching non-terminal
1097 // Fst. Currently the ReplaceFst uses the output symbols of the arcs
1098 // to determine whether the arc is a non-terminal arc or not. A
1099 // non-terminal can be any label that is not a non-zero terminal label
1100 // in the output alphabet.
1101 //
1102 // Note that the constructor uses a vector of pair<>. These correspond
1103 // to the tuple of non-terminal Label and corresponding Fst. For example
1104 // to implement the closure operation we need 2 Fsts. The first root
1105 // Fst is a single Arc on the start State that self loops, it references
1106 // the particular machine for which we are performing the closure operation.
1107 //
1108 // The ReplaceFst class supports an optionally caching arc iterator:
1109 // ArcIterator< ReplaceFst<A> >
// The ReplaceFst needs to be built such that it is known to be ilabel
// or olabel sorted (see usage below).
1112 //
1113 // Observe that Matcher<Fst<A> > will use the optionally caching arc
1114 // iterator when available (Fst is ilabel sorted and matching on the
1115 // input, or Fst is olabel sorted and matching on the output).
1116 // In order to obtain the most efficient behaviour, it is recommended
1117 // to set call_label_type to REPLACE_LABEL_INPUT or REPLACE_LABEL_BOTH
1118 // and return_label_type to REPLACE_LABEL_OUTPUT or REPLACE_LABEL_NEITHER
1119 // (this means that the call arc does not have epsilon on the input side
1120 // and the return arc has epsilon on the input side) and matching on the
1121 // input side.
1122 //
1123 // This class attaches interface to implementation and handles
1124 // reference counting, delegating most methods to ImplToFst.
1125 template <class A, class T = DefaultReplaceStateTable<A>,
1126 class C /* = DefaultCacheStore<A> */ >
1127 class ReplaceFst : public ImplToFst< ReplaceFstImpl<A, T, C> > {
1128 public:
1129 friend class ArcIterator< ReplaceFst<A, T, C> >;
1130 friend class StateIterator< ReplaceFst<A, T, C> >;
1131 friend class ReplaceFstMatcher<A, T, C>;
1132
1133 typedef A Arc;
1134 typedef typename A::Label Label;
1135 typedef typename A::Weight Weight;
1136 typedef typename A::StateId StateId;
1137 typedef T StateTable;
1138 typedef C Store;
1139 typedef typename C::State State;
1140 typedef CacheBaseImpl<State, C> CImpl;
1141 typedef ReplaceFstImpl<A, T, C> Impl;
1142
1143 using ImplToFst<Impl>::Properties;
1144
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,Label root)1145 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
1146 Label root)
1147 : ImplToFst<Impl>(
1148 new Impl(fst_array, ReplaceFstOptions<A, T, C>(root))) {}
1149
ReplaceFst(const vector<pair<Label,const Fst<A> * >> & fst_array,const ReplaceFstOptions<A,T,C> & opts)1150 ReplaceFst(const vector<pair<Label, const Fst<A>* > >& fst_array,
1151 const ReplaceFstOptions<A, T, C> &opts)
1152 : ImplToFst<Impl>(new Impl(fst_array, opts)) {}
1153
1154 // See Fst<>::Copy() for doc.
1155 ReplaceFst(const ReplaceFst<A, T, C>& fst, bool safe = false)
1156 : ImplToFst<Impl>(fst, safe) {}
1157
1158 // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc.
1159 virtual ReplaceFst<A, T, C> *Copy(bool safe = false) const {
1160 return new ReplaceFst<A, T, C>(*this, safe);
1161 }
1162
1163 virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
1164
InitArcIterator(StateId s,ArcIteratorData<A> * data)1165 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
1166 GetImpl()->InitArcIterator(s, data);
1167 }
1168
InitMatcher(MatchType match_type)1169 virtual MatcherBase<A> *InitMatcher(MatchType match_type) const {
1170 if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1171 ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) ||
1172 (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) {
1173 return new ReplaceFstMatcher<A, T, C>(*this, match_type);
1174 } else {
1175 VLOG(2) << "Not using replace matcher";
1176 return 0;
1177 }
1178 }
1179
CyclicDependencies()1180 bool CyclicDependencies() const {
1181 return GetImpl()->CyclicDependencies();
1182 }
1183
GetStateTable()1184 const StateTable& GetStateTable() const {
1185 return *GetImpl()->GetStateTable();
1186 }
1187
GetFst(Label nonterminal)1188 const Fst<A> &GetFst(Label nonterminal) const {
1189 return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal));
1190 }
1191
1192 private:
1193 // Makes visible to friends.
GetImpl()1194 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
1195
1196 void operator=(const ReplaceFst<A, T, C> &fst); // disallow
1197 };
1198
1199
1200 // Specialization for ReplaceFst.
1201 template<class A, class T, class C>
1202 class StateIterator< ReplaceFst<A, T, C> >
1203 : public CacheStateIterator< ReplaceFst<A, T, C> > {
1204 public:
StateIterator(const ReplaceFst<A,T,C> & fst)1205 explicit StateIterator(const ReplaceFst<A, T, C> &fst)
1206 : CacheStateIterator< ReplaceFst<A, T, C> >(fst, fst.GetImpl()) {}
1207
1208 private:
1209 DISALLOW_COPY_AND_ASSIGN(StateIterator);
1210 };
1211
1212
1213 // Specialization for ReplaceFst.
1214 // Implements optional caching. It can be used as follows:
1215 //
1216 // ReplaceFst<A> replace;
1217 // ArcIterator< ReplaceFst<A> > aiter(replace, s);
1218 // // Note: ArcIterator< Fst<A> > is always a caching arc iterator.
1219 // aiter.SetFlags(kArcNoCache, kArcNoCache);
1220 // // Use the arc iterator, no arc will be cached, no state will be expanded.
1221 // // The varied 'kArcValueFlags' can be used to decide which part
1222 // // of arc values needs to be computed.
1223 // aiter.SetFlags(kArcILabelValue, kArcValueFlags);
1224 // // Only want the ilabel for this arc
1225 // aiter.Value(); // Does not compute the destination state.
1226 // aiter.Next();
1227 // aiter.SetFlags(kArcNextStateValue, kArcNextStateValue);
1228 // // Want both ilabel and nextstate for that arc
1229 // aiter.Value(); // Does compute the destination state and inserts it
1230 // // in the replace state table.
1231 // // No Arc has been cached at that point.
1232 //
1233 template <class A, class T, class C>
1234 class ArcIterator< ReplaceFst<A, T, C> > {
1235 public:
1236 typedef A Arc;
1237 typedef typename A::StateId StateId;
1238
ArcIterator(const ReplaceFst<A,T,C> & fst,StateId s)1239 ArcIterator(const ReplaceFst<A, T, C> &fst, StateId s)
1240 : fst_(fst), state_(s), pos_(0), offset_(0), flags_(kArcValueFlags),
1241 arcs_(0), data_flags_(0), final_flags_(0) {
1242 cache_data_.ref_count = 0;
1243 local_data_.ref_count = 0;
1244
1245 // If FST does not support optional caching, force caching.
1246 if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) &&
1247 !(fst_.GetImpl()->HasArcs(state_)))
1248 fst_.GetImpl()->Expand(state_);
1249
1250 // If state is already cached, use cached arcs array.
1251 if (fst_.GetImpl()->HasArcs(state_)) {
1252 (fst_.GetImpl())
1253 ->template CacheBaseImpl<typename C::State, C>::InitArcIterator(
1254 state_, &cache_data_);
1255 num_arcs_ = cache_data_.narcs;
1256 arcs_ = cache_data_.arcs; // 'arcs_' is a ptr to the cached arcs.
1257 data_flags_ = kArcValueFlags; // All the arc member values are valid.
1258 } else { // Otherwise delay decision until Value() is called.
1259 tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(state_);
1260 if (tuple_.fst_state == kNoStateId) {
1261 num_arcs_ = 0;
1262 } else {
1263 // The decision to cache or not to cache has been defered
1264 // until Value() or SetFlags() is called. However, the arc
1265 // iterator is set up now to be ready for non-caching in order
1266 // to keep the Value() method simple and efficient.
1267 const Fst<A>* fst = fst_.GetImpl()->GetFst(tuple_.fst_id);
1268 fst->InitArcIterator(tuple_.fst_state, &local_data_);
1269 // 'arcs_' is a pointer to the arcs in the underlying machine.
1270 arcs_ = local_data_.arcs;
1271 // Compute the final arc (but not its destination state)
1272 // if a final arc is required.
1273 bool has_final_arc = fst_.GetImpl()->ComputeFinalArc(
1274 tuple_,
1275 &final_arc_,
1276 kArcValueFlags & ~kArcNextStateValue);
1277 // Set the arc value flags that hold for 'final_arc_'.
1278 final_flags_ = kArcValueFlags & ~kArcNextStateValue;
1279 // Compute the number of arcs.
1280 num_arcs_ = local_data_.narcs;
1281 if (has_final_arc)
1282 ++num_arcs_;
1283 // Set the offset between the underlying arc positions and
1284 // the positions in the arc iterator.
1285 offset_ = num_arcs_ - local_data_.narcs;
1286 // Defers the decision to cache or not until Value() or
1287 // SetFlags() is called.
1288 data_flags_ = 0;
1289 }
1290 }
1291 }
1292
~ArcIterator()1293 ~ArcIterator() {
1294 if (cache_data_.ref_count)
1295 --(*cache_data_.ref_count);
1296 if (local_data_.ref_count)
1297 --(*local_data_.ref_count);
1298 }
1299
ExpandAndCache()1300 void ExpandAndCache() const {
1301 // TODO(allauzen): revisit this
1302 // fst_.GetImpl()->Expand(state_, tuple_, local_data_);
1303 // (fst_.GetImpl())->CacheImpl<A>*>::InitArcIterator(state_,
1304 // &cache_data_);
1305 //
1306 fst_.InitArcIterator(state_, &cache_data_); // Expand and cache state.
1307 arcs_ = cache_data_.arcs; // 'arcs_' is a pointer to the cached arcs.
1308 data_flags_ = kArcValueFlags; // All the arc member values are valid.
1309 offset_ = 0; // No offset
1310 }
1311
Init()1312 void Init() {
1313 if (flags_ & kArcNoCache) { // If caching is disabled
1314 // 'arcs_' is a pointer to the arcs in the underlying machine.
1315 arcs_ = local_data_.arcs;
1316 // Set the arcs value flags that hold for 'arcs_'.
1317 data_flags_ = kArcWeightValue;
1318 if (!fst_.GetImpl()->EpsilonOnCallInput())
1319 data_flags_ |= kArcILabelValue;
1320 // Set the offset between the underlying arc positions and
1321 // the positions in the arc iterator.
1322 offset_ = num_arcs_ - local_data_.narcs;
1323 } else { // Otherwise, expand and cache
1324 ExpandAndCache();
1325 }
1326 }
1327
Done()1328 bool Done() const { return pos_ >= num_arcs_; }
1329
Value()1330 const A& Value() const {
1331 // If 'data_flags_' was set to 0, non-caching was not requested
1332 if (!data_flags_) {
1333 // TODO(allauzen): revisit this.
1334 if (flags_ & kArcNoCache) {
1335 // Should never happen.
1336 FSTERROR() << "ReplaceFst: inconsistent arc iterator flags";
1337 }
1338 ExpandAndCache(); // Expand and cache.
1339 }
1340
1341 if (pos_ - offset_ >= 0) { // The requested arc is not the 'final' arc.
1342 const A& arc = arcs_[pos_ - offset_];
1343 if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) {
1344 // If the value flags for 'arc' match the recquired value flags
1345 // then return 'arc'.
1346 return arc;
1347 } else {
1348 // Otherwise, compute the corresponding arc on-the-fly.
1349 fst_.GetImpl()->ComputeArc(tuple_, arc, &arc_, flags_ & kArcValueFlags);
1350 return arc_;
1351 }
1352 } else { // The requested arc is the 'final' arc.
1353 if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) {
1354 // If the arc value flags that hold for the final arc
1355 // do not match the requested value flags, then
1356 // 'final_arc_' needs to be updated.
1357 fst_.GetImpl()->ComputeFinalArc(tuple_, &final_arc_,
1358 flags_ & kArcValueFlags);
1359 final_flags_ = flags_ & kArcValueFlags;
1360 }
1361 return final_arc_;
1362 }
1363 }
1364
Next()1365 void Next() { ++pos_; }
1366
Position()1367 size_t Position() const { return pos_; }
1368
Reset()1369 void Reset() { pos_ = 0; }
1370
Seek(size_t pos)1371 void Seek(size_t pos) { pos_ = pos; }
1372
Flags()1373 uint32 Flags() const { return flags_; }
1374
SetFlags(uint32 f,uint32 mask)1375 void SetFlags(uint32 f, uint32 mask) {
1376 // Update the flags taking into account what flags are supported
1377 // by the Fst.
1378 flags_ &= ~mask;
1379 flags_ |= (f & fst_.GetImpl()->ArcIteratorFlags());
1380 // If non-caching is not requested (and caching has not already
1381 // been performed), then flush 'data_flags_' to request caching
1382 // during the next call to Value().
1383 if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) {
1384 if (!fst_.GetImpl()->HasArcs(state_))
1385 data_flags_ = 0;
1386 }
1387 // If 'data_flags_' has been flushed but non-caching is requested
1388 // before calling Value(), then set up the iterator for non-caching.
1389 if ((f & kArcNoCache) && (!data_flags_))
1390 Init();
1391 }
1392
1393 private:
1394 const ReplaceFst<A, T, C> &fst_; // Reference to the FST
1395 StateId state_; // State in the FST
1396 mutable typename T::StateTuple tuple_; // Tuple corresponding to state_
1397
1398 ssize_t pos_; // Current position
1399 mutable ssize_t offset_; // Offset between position in iterator and in arcs_
1400 ssize_t num_arcs_; // Number of arcs at state_
1401 uint32 flags_; // Behavorial flags for the arc iterator
1402 mutable Arc arc_; // Memory to temporarily store computed arcs
1403
1404 mutable ArcIteratorData<Arc> cache_data_; // Arc iterator data in cache
1405 mutable ArcIteratorData<Arc> local_data_; // Arc iterator data in local fst
1406
1407 mutable const A* arcs_; // Array of arcs
1408 mutable uint32 data_flags_; // Arc value flags valid for data in arcs_
1409 mutable Arc final_arc_; // Final arc (when required)
1410 mutable uint32 final_flags_; // Arc value flags valid for final_arc_
1411
1412 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
1413 };
1414
1415
1416 template <class A, class T, class C>
1417 class ReplaceFstMatcher : public MatcherBase<A> {
1418 public:
1419 typedef ReplaceFst<A, T, C> FST;
1420 typedef A Arc;
1421 typedef typename A::StateId StateId;
1422 typedef typename A::Label Label;
1423 typedef MultiEpsMatcher<Matcher<Fst<A> > > LocalMatcher;
1424
ReplaceFstMatcher(const ReplaceFst<A,T,C> & fst,fst::MatchType match_type)1425 ReplaceFstMatcher(const ReplaceFst<A, T, C> &fst,
1426 fst::MatchType match_type)
1427 : fst_(fst),
1428 impl_(fst_.GetImpl()),
1429 s_(fst::kNoStateId),
1430 match_type_(match_type),
1431 current_loop_(false),
1432 final_arc_(false),
1433 loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1434 if (match_type_ == fst::MATCH_OUTPUT)
1435 swap(loop_.ilabel, loop_.olabel);
1436 InitMatchers();
1437 }
1438
1439 ReplaceFstMatcher(const ReplaceFstMatcher<A, T, C> &matcher,
1440 bool safe = false)
1441 : fst_(matcher.fst_),
1442 impl_(fst_.GetImpl()),
1443 s_(fst::kNoStateId),
1444 match_type_(matcher.match_type_),
1445 current_loop_(false),
1446 final_arc_(false),
1447 loop_(fst::kNoLabel, 0, A::Weight::One(), fst::kNoStateId) {
1448 if (match_type_ == fst::MATCH_OUTPUT)
1449 swap(loop_.ilabel, loop_.olabel);
1450 InitMatchers();
1451 }
1452
1453 // Create a local matcher for each component Fst of replace.
1454 // LocalMatcher is a multi epsilon wrapper matcher. MultiEpsilonMatcher
1455 // is used to match each non-terminal arc, since these non-terminal
1456 // turn into epsilons on recursion.
InitMatchers()1457 void InitMatchers() {
1458 const vector<const Fst<A>*>& fst_array = impl_->fst_array_;
1459 matcher_.resize(fst_array.size(), 0);
1460 for (size_t i = 0; i < fst_array.size(); ++i) {
1461 if (fst_array[i]) {
1462 matcher_[i] =
1463 new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList);
1464
1465 typename set<Label>::iterator it = impl_->nonterminal_set_.begin();
1466 for (; it != impl_->nonterminal_set_.end(); ++it) {
1467 matcher_[i]->AddMultiEpsLabel(*it);
1468 }
1469 }
1470 }
1471 }
1472
1473 virtual ReplaceFstMatcher<A, T, C> *Copy(bool safe = false) const {
1474 return new ReplaceFstMatcher<A, T, C>(*this, safe);
1475 }
1476
~ReplaceFstMatcher()1477 virtual ~ReplaceFstMatcher() {
1478 for (size_t i = 0; i < matcher_.size(); ++i)
1479 delete matcher_[i];
1480 }
1481
Type(bool test)1482 virtual MatchType Type(bool test) const {
1483 if (match_type_ == MATCH_NONE)
1484 return match_type_;
1485
1486 uint64 true_prop = match_type_ == MATCH_INPUT ?
1487 kILabelSorted : kOLabelSorted;
1488 uint64 false_prop = match_type_ == MATCH_INPUT ?
1489 kNotILabelSorted : kNotOLabelSorted;
1490 uint64 props = fst_.Properties(true_prop | false_prop, test);
1491
1492 if (props & true_prop)
1493 return match_type_;
1494 else if (props & false_prop)
1495 return MATCH_NONE;
1496 else
1497 return MATCH_UNKNOWN;
1498 }
1499
GetFst()1500 virtual const Fst<A> &GetFst() const {
1501 return fst_;
1502 }
1503
Properties(uint64 props)1504 virtual uint64 Properties(uint64 props) const {
1505 return props;
1506 }
1507
1508 private:
1509 // Set the sate from which our matching happens.
SetState_(StateId s)1510 virtual void SetState_(StateId s) {
1511 if (s_ == s) return;
1512
1513 s_ = s;
1514 tuple_ = impl_->GetStateTable()->Tuple(s_);
1515 if (tuple_.fst_state == kNoStateId) {
1516 done_ = true;
1517 return;
1518 }
1519 // Get current matcher. Used for non epsilon matching
1520 current_matcher_ = matcher_[tuple_.fst_id];
1521 current_matcher_->SetState(tuple_.fst_state);
1522 loop_.nextstate = s_;
1523
1524 final_arc_ = false;
1525 }
1526
1527 // Search for label, from previous set state. If label == 0, first
1528 // hallucinate and epsilon loop, else use the underlying matcher to
1529 // search for the label or epsilons.
1530 // - Note since the ReplaceFST recursion on non-terminal arcs causes
1531 // epsilon transitions to be created we use the MultiEpsilonMatcher
1532 // to search for possible matches of non terminals.
1533 // - If the component Fst reaches a final state we also need to add
1534 // the exiting final arc.
Find_(Label label)1535 virtual bool Find_(Label label) {
1536 bool found = false;
1537 label_ = label;
1538 if (label_ == 0 || label_ == kNoLabel) {
1539 // Compute loop directly, saving Replace::ComputeArc
1540 if (label_ == 0) {
1541 current_loop_ = true;
1542 found = true;
1543 }
1544 // Search for matching multi epsilons
1545 final_arc_ = impl_->ComputeFinalArc(tuple_, 0);
1546 found = current_matcher_->Find(kNoLabel) || final_arc_ || found;
1547 } else {
1548 // Search on sub machine directly using sub machine matcher.
1549 found = current_matcher_->Find(label_);
1550 }
1551 return found;
1552 }
1553
Done_()1554 virtual bool Done_() const {
1555 return !current_loop_ && !final_arc_ && current_matcher_->Done();
1556 }
1557
Value_()1558 virtual const Arc& Value_() const {
1559 if (current_loop_) {
1560 return loop_;
1561 }
1562 if (final_arc_) {
1563 impl_->ComputeFinalArc(tuple_, &arc_);
1564 return arc_;
1565 }
1566 const Arc& component_arc = current_matcher_->Value();
1567 impl_->ComputeArc(tuple_, component_arc, &arc_);
1568 return arc_;
1569 }
1570
Next_()1571 virtual void Next_() {
1572 if (current_loop_) {
1573 current_loop_ = false;
1574 return;
1575 }
1576 if (final_arc_) {
1577 final_arc_ = false;
1578 return;
1579 }
1580 current_matcher_->Next();
1581 }
1582
Priority_(StateId s)1583 virtual ssize_t Priority_(StateId s) { return fst_.NumArcs(s); }
1584
1585 const ReplaceFst<A, T, C>& fst_;
1586 ReplaceFstImpl<A, T, C> *impl_;
1587 LocalMatcher* current_matcher_;
1588 vector<LocalMatcher*> matcher_;
1589
1590 StateId s_; // Current state
1591 Label label_; // Current label
1592
1593 MatchType match_type_; // Supplied by caller
1594 mutable bool done_;
1595 mutable bool current_loop_; // Current arc is the implicit loop
1596 mutable bool final_arc_; // Current arc for exiting recursion
1597 mutable typename T::StateTuple tuple_; // Tuple corresponding to state_
1598 mutable Arc arc_;
1599 Arc loop_;
1600
1601 DISALLOW_COPY_AND_ASSIGN(ReplaceFstMatcher);
1602 };
1603
1604 template <class A, class T, class C> inline
InitStateIterator(StateIteratorData<A> * data)1605 void ReplaceFst<A, T, C>::InitStateIterator(StateIteratorData<A> *data) const {
1606 data->base = new StateIterator< ReplaceFst<A, T, C> >(*this);
1607 }
1608
1609 typedef ReplaceFst<StdArc> StdReplaceFst;
1610
// Recursively replaces arcs in the root Fst with other Fsts.
1612 // This version writes the result of replacement to an output MutableFst.
1613 //
1614 // Replace supports replacement of arcs in one Fst with another
1615 // Fst. This replacement is recursive. Replace takes an array of
1616 // Fst(s). One Fst represents the root (or topology) machine. The root
1617 // Fst refers to other Fsts by recursively replacing arcs labeled as
1618 // non-terminals with the matching non-terminal Fst. Currently Replace
1619 // uses the output symbols of the arcs to determine whether the arc is
1620 // a non-terminal arc or not. A non-terminal can be any label that is
1621 // not a non-zero terminal label in the output alphabet. Note that
1622 // input argument is a vector of pair<>. These correspond to the tuple
1623 // of non-terminal Label and corresponding Fst.
1624 template<class Arc>
1625 void Replace(const vector<pair<typename Arc::Label,
1626 const Fst<Arc>* > >& ifst_array,
1627 MutableFst<Arc> *ofst, ReplaceFstOptions<Arc> opts =
1628 ReplaceFstOptions<Arc>()) {
1629 opts.gc = true;
1630 opts.gc_limit = 0; // Cache only the last state for fastest copy.
1631 *ofst = ReplaceFst<Arc>(ifst_array, opts);
1632 }
1633
1634 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,fst::ReplaceUtilOptions<Arc> opts)1635 void Replace(const vector<pair<typename Arc::Label,
1636 const Fst<Arc>* > >& ifst_array,
1637 MutableFst<Arc> *ofst, fst::ReplaceUtilOptions<Arc> opts) {
1638 Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(opts));
1639 }
1640
1641 // Included for backward compatibility with 'epsilon_on_replace' arguments
1642 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root,bool epsilon_on_replace)1643 void Replace(const vector<pair<typename Arc::Label,
1644 const Fst<Arc>* > >& ifst_array,
1645 MutableFst<Arc> *ofst, typename Arc::Label root,
1646 bool epsilon_on_replace) {
1647 Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root, epsilon_on_replace));
1648 }
1649
1650 template<class Arc>
Replace(const vector<pair<typename Arc::Label,const Fst<Arc> * >> & ifst_array,MutableFst<Arc> * ofst,typename Arc::Label root)1651 void Replace(const vector<pair<typename Arc::Label,
1652 const Fst<Arc>* > >& ifst_array,
1653 MutableFst<Arc> *ofst, typename Arc::Label root) {
1654 Replace(ifst_array, ofst, ReplaceFstOptions<Arc>(root));
1655 }
1656
1657 } // namespace fst
1658
1659 #endif // FST_LIB_REPLACE_H__
1660