1 // relabel.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: johans@google.com (Johan Schalkwyk)
17 //
18 // \file
// Functions and classes to relabel an FST on its input side, its output side,
// or both.
20 //
21 #ifndef FST_LIB_RELABEL_H__
22 #define FST_LIB_RELABEL_H__
23
24 #include <unordered_map>
25 using std::unordered_map;
26 using std::unordered_multimap;
27 #include <string>
28 #include <utility>
29 using std::pair; using std::make_pair;
30 #include <vector>
31 using std::vector;
32
33 #include <fst/cache.h>
34 #include <fst/test-properties.h>
35
36
37 #include <unordered_map>
38 using std::unordered_map;
39 using std::unordered_multimap;
40
41 namespace fst {
42
43 //
44 // Relabels either the input labels or output labels. The old to
45 // new labels are specified using a vector of pair<Label,Label>.
46 // Any label associations not specified are assumed to be identity
47 // mapping.
48 //
49 // \param fst input fst, must be mutable
50 // \param ipairs vector of input label pairs indicating old to new mapping
51 // \param opairs vector of output label pairs indicating old to new mapping
52 //
53 template <class A>
Relabel(MutableFst<A> * fst,const vector<pair<typename A::Label,typename A::Label>> & ipairs,const vector<pair<typename A::Label,typename A::Label>> & opairs)54 void Relabel(
55 MutableFst<A> *fst,
56 const vector<pair<typename A::Label, typename A::Label> >& ipairs,
57 const vector<pair<typename A::Label, typename A::Label> >& opairs) {
58 typedef typename A::StateId StateId;
59 typedef typename A::Label Label;
60
61 uint64 props = fst->Properties(kFstProperties, false);
62
63 // construct label to label hash.
64 unordered_map<Label, Label> input_map;
65 for (size_t i = 0; i < ipairs.size(); ++i) {
66 input_map[ipairs[i].first] = ipairs[i].second;
67 }
68
69 unordered_map<Label, Label> output_map;
70 for (size_t i = 0; i < opairs.size(); ++i) {
71 output_map[opairs[i].first] = opairs[i].second;
72 }
73
74 for (StateIterator<MutableFst<A> > siter(*fst);
75 !siter.Done(); siter.Next()) {
76 StateId s = siter.Value();
77 for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
78 !aiter.Done(); aiter.Next()) {
79 A arc = aiter.Value();
80
81 // relabel input
82 // only relabel if relabel pair defined
83 typename unordered_map<Label, Label>::iterator it =
84 input_map.find(arc.ilabel);
85 if (it != input_map.end()) {
86 if (it->second == kNoLabel) {
87 FSTERROR() << "Input symbol id " << arc.ilabel
88 << " missing from target vocabulary";
89 fst->SetProperties(kError, kError);
90 return;
91 }
92 arc.ilabel = it->second;
93 }
94
95 // relabel output
96 it = output_map.find(arc.olabel);
97 if (it != output_map.end()) {
98 if (it->second == kNoLabel) {
99 FSTERROR() << "Output symbol id " << arc.olabel
100 << " missing from target vocabulary";
101 fst->SetProperties(kError, kError);
102 return;
103 }
104 arc.olabel = it->second;
105 }
106
107 aiter.SetValue(arc);
108 }
109 }
110
111 fst->SetProperties(RelabelProperties(props), kFstProperties);
112 }
113
114 //
115 // Relabels either the input labels or output labels. The old to
116 // new labels mappings are specified using an input Symbol set.
117 // Any label associations not specified are assumed to be identity
118 // mapping.
119 //
120 // \param fst input fst, must be mutable
121 // \param new_isymbols symbol set indicating new mapping of input symbols
122 // \param new_osymbols symbol set indicating new mapping of output symbols
123 //
124 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)125 void Relabel(MutableFst<A> *fst,
126 const SymbolTable* new_isymbols,
127 const SymbolTable* new_osymbols) {
128 Relabel(fst,
129 fst->InputSymbols(), new_isymbols, true,
130 fst->OutputSymbols(), new_osymbols, true);
131 }
132
133 template<class A>
Relabel(MutableFst<A> * fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,bool attach_new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,bool attach_new_osymbols)134 void Relabel(MutableFst<A> *fst,
135 const SymbolTable* old_isymbols,
136 const SymbolTable* new_isymbols,
137 bool attach_new_isymbols,
138 const SymbolTable* old_osymbols,
139 const SymbolTable* new_osymbols,
140 bool attach_new_osymbols) {
141 typedef typename A::StateId StateId;
142 typedef typename A::Label Label;
143
144 vector<pair<Label, Label> > ipairs;
145 if (old_isymbols && new_isymbols) {
146 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
147 syms_iter.Next()) {
148 string isymbol = syms_iter.Symbol();
149 int isymbol_val = syms_iter.Value();
150 int new_isymbol_val = new_isymbols->Find(isymbol);
151 ipairs.push_back(make_pair(isymbol_val, new_isymbol_val));
152 }
153 if (attach_new_isymbols)
154 fst->SetInputSymbols(new_isymbols);
155 }
156
157 vector<pair<Label, Label> > opairs;
158 if (old_osymbols && new_osymbols) {
159 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
160 syms_iter.Next()) {
161 string osymbol = syms_iter.Symbol();
162 int osymbol_val = syms_iter.Value();
163 int new_osymbol_val = new_osymbols->Find(osymbol);
164 opairs.push_back(make_pair(osymbol_val, new_osymbol_val));
165 }
166 if (attach_new_osymbols)
167 fst->SetOutputSymbols(new_osymbols);
168 }
169
170 // call relabel using vector of relabel pairs.
171 Relabel(fst, ipairs, opairs);
172 }
173
174
175 typedef CacheOptions RelabelFstOptions;
176
177 template <class A> class RelabelFst;
178
179 //
180 // \class RelabelFstImpl
181 // \brief Implementation for delayed relabeling
182 //
183 // Relabels an FST from one symbol set to another. Relabeling
184 // can either be on input or output space. RelabelFst implements
185 // a delayed version of the relabel. Arcs are relabeled on the fly
186 // and not cached. I.e each request is recomputed.
187 //
188 template<class A>
189 class RelabelFstImpl : public CacheImpl<A> {
190 friend class StateIterator< RelabelFst<A> >;
191 public:
192 using FstImpl<A>::SetType;
193 using FstImpl<A>::SetProperties;
194 using FstImpl<A>::WriteHeader;
195 using FstImpl<A>::SetInputSymbols;
196 using FstImpl<A>::SetOutputSymbols;
197
198 using CacheImpl<A>::PushArc;
199 using CacheImpl<A>::HasArcs;
200 using CacheImpl<A>::HasFinal;
201 using CacheImpl<A>::HasStart;
202 using CacheImpl<A>::SetArcs;
203 using CacheImpl<A>::SetFinal;
204 using CacheImpl<A>::SetStart;
205
206 typedef A Arc;
207 typedef typename A::Label Label;
208 typedef typename A::Weight Weight;
209 typedef typename A::StateId StateId;
210 typedef DefaultCacheStore<A> Store;
211 typedef typename Store::State State;
212
213
RelabelFstImpl(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)214 RelabelFstImpl(const Fst<A>& fst,
215 const vector<pair<Label, Label> >& ipairs,
216 const vector<pair<Label, Label> >& opairs,
217 const RelabelFstOptions &opts)
218 : CacheImpl<A>(opts), fst_(fst.Copy()),
219 relabel_input_(false), relabel_output_(false) {
220 uint64 props = fst.Properties(kCopyProperties, false);
221 SetProperties(RelabelProperties(props));
222 SetType("relabel");
223
224 // create input label map
225 if (ipairs.size() > 0) {
226 for (size_t i = 0; i < ipairs.size(); ++i) {
227 input_map_[ipairs[i].first] = ipairs[i].second;
228 }
229 relabel_input_ = true;
230 }
231
232 // create output label map
233 if (opairs.size() > 0) {
234 for (size_t i = 0; i < opairs.size(); ++i) {
235 output_map_[opairs[i].first] = opairs[i].second;
236 }
237 relabel_output_ = true;
238 }
239 }
240
RelabelFstImpl(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)241 RelabelFstImpl(const Fst<A>& fst,
242 const SymbolTable* old_isymbols,
243 const SymbolTable* new_isymbols,
244 const SymbolTable* old_osymbols,
245 const SymbolTable* new_osymbols,
246 const RelabelFstOptions &opts)
247 : CacheImpl<A>(opts), fst_(fst.Copy()),
248 relabel_input_(false), relabel_output_(false) {
249 SetType("relabel");
250
251 uint64 props = fst.Properties(kCopyProperties, false);
252 SetProperties(RelabelProperties(props));
253 SetInputSymbols(old_isymbols);
254 SetOutputSymbols(old_osymbols);
255
256 if (old_isymbols && new_isymbols &&
257 old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) {
258 for (SymbolTableIterator syms_iter(*old_isymbols); !syms_iter.Done();
259 syms_iter.Next()) {
260 input_map_[syms_iter.Value()] = new_isymbols->Find(syms_iter.Symbol());
261 }
262 SetInputSymbols(new_isymbols);
263 relabel_input_ = true;
264 }
265
266 if (old_osymbols && new_osymbols &&
267 old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) {
268 for (SymbolTableIterator syms_iter(*old_osymbols); !syms_iter.Done();
269 syms_iter.Next()) {
270 output_map_[syms_iter.Value()] =
271 new_osymbols->Find(syms_iter.Symbol());
272 }
273 SetOutputSymbols(new_osymbols);
274 relabel_output_ = true;
275 }
276 }
277
RelabelFstImpl(const RelabelFstImpl<A> & impl)278 RelabelFstImpl(const RelabelFstImpl<A>& impl)
279 : CacheImpl<A>(impl),
280 fst_(impl.fst_->Copy(true)),
281 input_map_(impl.input_map_),
282 output_map_(impl.output_map_),
283 relabel_input_(impl.relabel_input_),
284 relabel_output_(impl.relabel_output_) {
285 SetType("relabel");
286 SetProperties(impl.Properties(), kCopyProperties);
287 SetInputSymbols(impl.InputSymbols());
288 SetOutputSymbols(impl.OutputSymbols());
289 }
290
~RelabelFstImpl()291 ~RelabelFstImpl() { delete fst_; }
292
Start()293 StateId Start() {
294 if (!HasStart()) {
295 StateId s = fst_->Start();
296 SetStart(s);
297 }
298 return CacheImpl<A>::Start();
299 }
300
Final(StateId s)301 Weight Final(StateId s) {
302 if (!HasFinal(s)) {
303 SetFinal(s, fst_->Final(s));
304 }
305 return CacheImpl<A>::Final(s);
306 }
307
NumArcs(StateId s)308 size_t NumArcs(StateId s) {
309 if (!HasArcs(s)) {
310 Expand(s);
311 }
312 return CacheImpl<A>::NumArcs(s);
313 }
314
NumInputEpsilons(StateId s)315 size_t NumInputEpsilons(StateId s) {
316 if (!HasArcs(s)) {
317 Expand(s);
318 }
319 return CacheImpl<A>::NumInputEpsilons(s);
320 }
321
NumOutputEpsilons(StateId s)322 size_t NumOutputEpsilons(StateId s) {
323 if (!HasArcs(s)) {
324 Expand(s);
325 }
326 return CacheImpl<A>::NumOutputEpsilons(s);
327 }
328
Properties()329 uint64 Properties() const { return Properties(kFstProperties); }
330
331 // Set error if found; return FST impl properties.
Properties(uint64 mask)332 uint64 Properties(uint64 mask) const {
333 if ((mask & kError) && fst_->Properties(kError, false))
334 SetProperties(kError, kError);
335 return FstImpl<Arc>::Properties(mask);
336 }
337
InitArcIterator(StateId s,ArcIteratorData<A> * data)338 void InitArcIterator(StateId s, ArcIteratorData<A>* data) {
339 if (!HasArcs(s)) {
340 Expand(s);
341 }
342 CacheImpl<A>::InitArcIterator(s, data);
343 }
344
Expand(StateId s)345 void Expand(StateId s) {
346 for (ArcIterator<Fst<A> > aiter(*fst_, s); !aiter.Done(); aiter.Next()) {
347 A arc = aiter.Value();
348
349 // relabel input
350 if (relabel_input_) {
351 typename unordered_map<Label, Label>::iterator it =
352 input_map_.find(arc.ilabel);
353 if (it != input_map_.end()) { arc.ilabel = it->second; }
354 }
355
356 // relabel output
357 if (relabel_output_) {
358 typename unordered_map<Label, Label>::iterator it =
359 output_map_.find(arc.olabel);
360 if (it != output_map_.end()) { arc.olabel = it->second; }
361 }
362
363 PushArc(s, arc);
364 }
365 SetArcs(s);
366 }
367
368
369 private:
370 const Fst<A> *fst_;
371
372 unordered_map<Label, Label> input_map_;
373 unordered_map<Label, Label> output_map_;
374 bool relabel_input_;
375 bool relabel_output_;
376
377 void operator=(const RelabelFstImpl<A> &); // disallow
378 };
379
380
381 //
382 // \class RelabelFst
383 // \brief Delayed implementation of arc relabeling
384 //
385 // This class attaches interface to implementation and handles
386 // reference counting, delegating most methods to ImplToFst.
387 template <class A>
388 class RelabelFst : public ImplToFst< RelabelFstImpl<A> > {
389 public:
390 friend class ArcIterator< RelabelFst<A> >;
391 friend class StateIterator< RelabelFst<A> >;
392
393 typedef A Arc;
394 typedef typename A::Label Label;
395 typedef typename A::Weight Weight;
396 typedef typename A::StateId StateId;
397 typedef DefaultCacheStore<A> Store;
398 typedef typename Store::State State;
399 typedef RelabelFstImpl<A> Impl;
400
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs)401 RelabelFst(const Fst<A>& fst,
402 const vector<pair<Label, Label> >& ipairs,
403 const vector<pair<Label, Label> >& opairs)
404 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, RelabelFstOptions())) {}
405
RelabelFst(const Fst<A> & fst,const vector<pair<Label,Label>> & ipairs,const vector<pair<Label,Label>> & opairs,const RelabelFstOptions & opts)406 RelabelFst(const Fst<A>& fst,
407 const vector<pair<Label, Label> >& ipairs,
408 const vector<pair<Label, Label> >& opairs,
409 const RelabelFstOptions &opts)
410 : ImplToFst<Impl>(new Impl(fst, ipairs, opairs, opts)) {}
411
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols)412 RelabelFst(const Fst<A>& fst,
413 const SymbolTable* new_isymbols,
414 const SymbolTable* new_osymbols)
415 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
416 fst.OutputSymbols(), new_osymbols,
417 RelabelFstOptions())) {}
418
RelabelFst(const Fst<A> & fst,const SymbolTable * new_isymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)419 RelabelFst(const Fst<A>& fst,
420 const SymbolTable* new_isymbols,
421 const SymbolTable* new_osymbols,
422 const RelabelFstOptions &opts)
423 : ImplToFst<Impl>(new Impl(fst, fst.InputSymbols(), new_isymbols,
424 fst.OutputSymbols(), new_osymbols, opts)) {}
425
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols)426 RelabelFst(const Fst<A>& fst,
427 const SymbolTable* old_isymbols,
428 const SymbolTable* new_isymbols,
429 const SymbolTable* old_osymbols,
430 const SymbolTable* new_osymbols)
431 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
432 new_osymbols, RelabelFstOptions())) {}
433
RelabelFst(const Fst<A> & fst,const SymbolTable * old_isymbols,const SymbolTable * new_isymbols,const SymbolTable * old_osymbols,const SymbolTable * new_osymbols,const RelabelFstOptions & opts)434 RelabelFst(const Fst<A>& fst,
435 const SymbolTable* old_isymbols,
436 const SymbolTable* new_isymbols,
437 const SymbolTable* old_osymbols,
438 const SymbolTable* new_osymbols,
439 const RelabelFstOptions &opts)
440 : ImplToFst<Impl>(new Impl(fst, old_isymbols, new_isymbols, old_osymbols,
441 new_osymbols, opts)) {}
442
443 // See Fst<>::Copy() for doc.
444 RelabelFst(const RelabelFst<A> &fst, bool safe = false)
445 : ImplToFst<Impl>(fst, safe) {}
446
447 // Get a copy of this RelabelFst. See Fst<>::Copy() for further doc.
448 virtual RelabelFst<A> *Copy(bool safe = false) const {
449 return new RelabelFst<A>(*this, safe);
450 }
451
452 virtual void InitStateIterator(StateIteratorData<A> *data) const;
453
InitArcIterator(StateId s,ArcIteratorData<A> * data)454 virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
455 return GetImpl()->InitArcIterator(s, data);
456 }
457
458 private:
459 // Makes visible to friends.
GetImpl()460 Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
461
462 void operator=(const RelabelFst<A> &fst); // disallow
463 };
464
465 // Specialization for RelabelFst.
466 template<class A>
467 class StateIterator< RelabelFst<A> > : public StateIteratorBase<A> {
468 public:
469 typedef typename A::StateId StateId;
470
StateIterator(const RelabelFst<A> & fst)471 explicit StateIterator(const RelabelFst<A> &fst)
472 : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {}
473
Done()474 bool Done() const { return siter_.Done(); }
475
Value()476 StateId Value() const { return s_; }
477
Next()478 void Next() {
479 if (!siter_.Done()) {
480 ++s_;
481 siter_.Next();
482 }
483 }
484
Reset()485 void Reset() {
486 s_ = 0;
487 siter_.Reset();
488 }
489
490 private:
Done_()491 bool Done_() const { return Done(); }
Value_()492 StateId Value_() const { return Value(); }
Next_()493 void Next_() { Next(); }
Reset_()494 void Reset_() { Reset(); }
495
496 const RelabelFstImpl<A> *impl_;
497 StateIterator< Fst<A> > siter_;
498 StateId s_;
499
500 DISALLOW_COPY_AND_ASSIGN(StateIterator);
501 };
502
503
504 // Specialization for RelabelFst.
505 template <class A>
506 class ArcIterator< RelabelFst<A> >
507 : public CacheArcIterator< RelabelFst<A> > {
508 public:
509 typedef typename A::StateId StateId;
510
ArcIterator(const RelabelFst<A> & fst,StateId s)511 ArcIterator(const RelabelFst<A> &fst, StateId s)
512 : CacheArcIterator< RelabelFst<A> >(fst.GetImpl(), s) {
513 if (!fst.GetImpl()->HasArcs(s))
514 fst.GetImpl()->Expand(s);
515 }
516
517 private:
518 DISALLOW_COPY_AND_ASSIGN(ArcIterator);
519 };
520
521 template <class A> inline
InitStateIterator(StateIteratorData<A> * data)522 void RelabelFst<A>::InitStateIterator(StateIteratorData<A> *data) const {
523 data->base = new StateIterator< RelabelFst<A> >(*this);
524 }
525
526 // Useful alias when using StdArc.
527 typedef RelabelFst<StdArc> StdRelabelFst;
528
529 } // namespace fst
530
531 #endif // FST_LIB_RELABEL_H__
532