1 // mutable-fst.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Expanded FST augmented with mutators - interface class definition
20 // and mutable arc iterator interface.
21 //
22 
23 #ifndef FST_LIB_MUTABLE_FST_H__
24 #define FST_LIB_MUTABLE_FST_H__
25 
26 #include <stddef.h>
27 #include <sys/types.h>
28 #include <string>
29 #include <vector>
30 using std::vector;
31 
32 #include <fst/expanded-fst.h>
33 
34 
35 namespace fst {
36 
37 template <class A> class MutableArcIteratorData;
38 
39 // An expanded FST plus mutators (use MutableArcIterator to modify arcs).
40 template <class A>
41 class MutableFst : public ExpandedFst<A> {
42  public:
43   typedef A Arc;
44   typedef typename A::Weight Weight;
45   typedef typename A::StateId StateId;
46 
47   virtual MutableFst<A> &operator=(const Fst<A> &fst) = 0;
48 
49   MutableFst<A> &operator=(const MutableFst<A> &fst) {
50     return operator=(static_cast<const Fst<A> &>(fst));
51   }
52 
53   virtual void SetStart(StateId) = 0;           // Set the initial state
54   virtual void SetFinal(StateId, Weight) = 0;   // Set a state's final weight
55   virtual void SetProperties(uint64 props,
56                              uint64 mask) = 0;  // Set property bits wrt mask
57 
58   virtual StateId AddState() = 0;               // Add a state, return its ID
59   virtual void AddArc(StateId, const A &arc) = 0;   // Add an arc to state
60 
61   virtual void DeleteStates(const vector<StateId>&) = 0;  // Delete some states
62   virtual void DeleteStates() = 0;              // Delete all states
63   virtual void DeleteArcs(StateId, size_t n) = 0;  // Delete some arcs at state
64   virtual void DeleteArcs(StateId) = 0;         // Delete all arcs at state
65 
ReserveStates(StateId n)66   virtual void ReserveStates(StateId n) { }  // Optional, best effort only.
ReserveArcs(StateId s,size_t n)67   virtual void ReserveArcs(StateId s, size_t n) { }  // Optional, Best effort.
68 
69   // Return input label symbol table; return NULL if not specified
70   virtual const SymbolTable* InputSymbols() const = 0;
71   // Return output label symbol table; return NULL if not specified
72   virtual const SymbolTable* OutputSymbols() const = 0;
73 
74   // Return input label symbol table; return NULL if not specified
75   virtual SymbolTable* MutableInputSymbols() = 0;
76   // Return output label symbol table; return NULL if not specified
77   virtual SymbolTable* MutableOutputSymbols() = 0;
78 
79   // Set input label symbol table; NULL signifies not unspecified
80   virtual void SetInputSymbols(const SymbolTable* isyms) = 0;
81   // Set output label symbol table; NULL signifies not unspecified
82   virtual void SetOutputSymbols(const SymbolTable* osyms) = 0;
83 
84   // Get a copy of this MutableFst. See Fst<>::Copy() for further doc.
85   virtual MutableFst<A> *Copy(bool safe = false) const = 0;
86 
87   // Read an MutableFst from an input stream; return NULL on error.
Read(istream & strm,const FstReadOptions & opts)88   static MutableFst<A> *Read(istream &strm, const FstReadOptions &opts) {
89     FstReadOptions ropts(opts);
90     FstHeader hdr;
91     if (ropts.header)
92       hdr = *opts.header;
93     else {
94       if (!hdr.Read(strm, opts.source))
95         return 0;
96       ropts.header = &hdr;
97     }
98     if (!(hdr.Properties() & kMutable)) {
99       LOG(ERROR) << "MutableFst::Read: Not a MutableFst: " << ropts.source;
100       return 0;
101     }
102     FstRegister<A> *registr = FstRegister<A>::GetRegister();
103     const typename FstRegister<A>::Reader reader =
104       registr->GetReader(hdr.FstType());
105     if (!reader) {
106       LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << hdr.FstType()
107                  << "\" (arc type = \"" << A::Type()
108                  << "\"): " << ropts.source;
109       return 0;
110     }
111     Fst<A> *fst = reader(strm, ropts);
112     if (!fst) return 0;
113     return static_cast<MutableFst<A> *>(fst);
114   }
115 
116   // Read a MutableFst from a file; return NULL on error.
117   // Empty filename reads from standard input. If 'convert' is true,
118   // convert to a mutable FST of type 'convert_type' if file is
119   // a non-mutable FST.
120   static MutableFst<A> *Read(const string &filename, bool convert = false,
121                              const string &convert_type = "vector") {
122     if (convert == false) {
123       if (!filename.empty()) {
124         ifstream strm(filename.c_str(), ifstream::in | ifstream::binary);
125         if (!strm) {
126           LOG(ERROR) << "MutableFst::Read: Can't open file: " << filename;
127           return 0;
128         }
129         return Read(strm, FstReadOptions(filename));
130       } else {
131         return Read(cin, FstReadOptions("standard input"));
132       }
133     } else {  // Converts to 'convert_type' if not mutable.
134       Fst<A> *ifst = Fst<A>::Read(filename);
135       if (!ifst) return 0;
136       if (ifst->Properties(kMutable, false)) {
137         return static_cast<MutableFst *>(ifst);
138       } else {
139         Fst<A> *ofst = Convert(*ifst, convert_type);
140         delete ifst;
141         if (!ofst) return 0;
142         if (!ofst->Properties(kMutable, false))
143           LOG(ERROR) << "MutableFst: bad convert type: " << convert_type;
144         return static_cast<MutableFst *>(ofst);
145       }
146     }
147   }
148 
149   // For generic mutuble arc iterator construction; not normally called
150   // directly by users.
151   virtual void InitMutableArcIterator(StateId s,
152                                       MutableArcIteratorData<A> *) = 0;
153 };
154 
155 // Mutable arc iterator interface, templated on the Arc definition; used
156 // for mutable Arc iterator specializations that are returned by
157 // the InitMutableArcIterator MutableFst method.
158 template <class A>
159 class MutableArcIteratorBase : public ArcIteratorBase<A> {
160  public:
161   typedef A Arc;
162 
SetValue(const A & arc)163   void SetValue(const A &arc) { SetValue_(arc); }  // Set current arc's content
164 
165  private:
166   virtual void SetValue_(const A &arc) = 0;
167 };
168 
169 template <class A>
170 struct MutableArcIteratorData {
171   MutableArcIteratorBase<A> *base;  // Specific iterator
172 };
173 
174 // Generic mutable arc iterator, templated on the FST definition
175 // - a wrapper around pointer to specific one.
176 // Here is a typical use: \code
177 //   for (MutableArcIterator<StdFst> aiter(&fst, s);
178 //        !aiter.Done();
179 //         aiter.Next()) {
180 //     StdArc arc = aiter.Value();
181 //     arc.ilabel = 7;
182 //     aiter.SetValue(arc);
183 //     ...
184 //   } \endcode
185 // This version requires function calls.
186 template <class F>
187 class MutableArcIterator {
188  public:
189   typedef F FST;
190   typedef typename F::Arc Arc;
191   typedef typename Arc::StateId StateId;
192 
MutableArcIterator(F * fst,StateId s)193   MutableArcIterator(F *fst, StateId s) {
194     fst->InitMutableArcIterator(s, &data_);
195   }
~MutableArcIterator()196   ~MutableArcIterator() { delete data_.base; }
197 
Done()198   bool Done() const { return data_.base->Done(); }
Value()199   const Arc& Value() const { return data_.base->Value(); }
Next()200   void Next() { data_.base->Next(); }
Position()201   size_t Position() const { return data_.base->Position(); }
Reset()202   void Reset() { data_.base->Reset(); }
Seek(size_t a)203   void Seek(size_t a) { data_.base->Seek(a); }
SetValue(const Arc & a)204   void SetValue(const Arc &a) { data_.base->SetValue(a); }
Flags()205   uint32 Flags() const { return data_.base->Flags(); }
SetFlags(uint32 f,uint32 m)206   void SetFlags(uint32 f, uint32 m) {
207     return data_.base->SetFlags(f, m);
208   }
209 
210  private:
211   MutableArcIteratorData<Arc> data_;
212   DISALLOW_COPY_AND_ASSIGN(MutableArcIterator);
213 };
214 
215 
216 namespace internal {
217 
218 //  MutableFst<A> case - abstract methods.
219 template <class A> inline
Final(const MutableFst<A> & fst,typename A::StateId s)220 typename A::Weight Final(const MutableFst<A> &fst, typename A::StateId s) {
221   return fst.Final(s);
222 }
223 
224 template <class A> inline
NumArcs(const MutableFst<A> & fst,typename A::StateId s)225 ssize_t NumArcs(const MutableFst<A> &fst, typename A::StateId s) {
226   return fst.NumArcs(s);
227 }
228 
229 template <class A> inline
NumInputEpsilons(const MutableFst<A> & fst,typename A::StateId s)230 ssize_t NumInputEpsilons(const MutableFst<A> &fst, typename A::StateId s) {
231   return fst.NumInputEpsilons(s);
232 }
233 
234 template <class A> inline
NumOutputEpsilons(const MutableFst<A> & fst,typename A::StateId s)235 ssize_t NumOutputEpsilons(const MutableFst<A> &fst, typename A::StateId s) {
236   return fst.NumOutputEpsilons(s);
237 }
238 
239 }  // namespace internal
240 
241 
// A useful alias when using StdArc: the MutableFst interface instantiated
// with StdArc as its arc type.
typedef MutableFst<StdArc> StdMutableFst;
244 
245 
246 // This is a helper class template useful for attaching a MutableFst
247 // interface to its implementation, handling reference counting and
248 // copy-on-write.
249 template <class I, class F = MutableFst<typename I::Arc> >
250 class ImplToMutableFst : public ImplToExpandedFst<I, F> {
251  public:
252   typedef typename I::Arc Arc;
253   typedef typename Arc::Weight Weight;
254   typedef typename Arc::StateId StateId;
255 
256   using ImplToFst<I, F>::GetImpl;
257   using ImplToFst<I, F>::SetImpl;
258 
SetStart(StateId s)259   virtual void SetStart(StateId s) {
260     MutateCheck();
261     GetImpl()->SetStart(s);
262   }
263 
SetFinal(StateId s,Weight w)264   virtual void SetFinal(StateId s, Weight w) {
265     MutateCheck();
266     GetImpl()->SetFinal(s, w);
267   }
268 
SetProperties(uint64 props,uint64 mask)269   virtual void SetProperties(uint64 props, uint64 mask) {
270     // Can skip mutate check if extrinsic properties don't change,
271     // since it is then safe to update all (shallow) copies
272     uint64 exprops = kExtrinsicProperties & mask;
273     if (GetImpl()->Properties(exprops) != (props & exprops))
274       MutateCheck();
275     GetImpl()->SetProperties(props, mask);
276   }
277 
AddState()278   virtual StateId AddState() {
279     MutateCheck();
280     return GetImpl()->AddState();
281   }
282 
AddArc(StateId s,const Arc & arc)283   virtual void AddArc(StateId s, const Arc &arc) {
284     MutateCheck();
285     GetImpl()->AddArc(s, arc);
286   }
287 
DeleteStates(const vector<StateId> & dstates)288   virtual void DeleteStates(const vector<StateId> &dstates) {
289     MutateCheck();
290     GetImpl()->DeleteStates(dstates);
291   }
292 
DeleteStates()293   virtual void DeleteStates() {
294     MutateCheck();
295     GetImpl()->DeleteStates();
296   }
297 
DeleteArcs(StateId s,size_t n)298   virtual void DeleteArcs(StateId s, size_t n) {
299     MutateCheck();
300     GetImpl()->DeleteArcs(s, n);
301   }
302 
DeleteArcs(StateId s)303   virtual void DeleteArcs(StateId s) {
304     MutateCheck();
305     GetImpl()->DeleteArcs(s);
306   }
307 
ReserveStates(StateId s)308   virtual void ReserveStates(StateId s) {
309     MutateCheck();
310     GetImpl()->ReserveStates(s);
311   }
312 
ReserveArcs(StateId s,size_t n)313   virtual void ReserveArcs(StateId s, size_t n) {
314     MutateCheck();
315     GetImpl()->ReserveArcs(s, n);
316   }
317 
InputSymbols()318   virtual const SymbolTable* InputSymbols() const {
319     return GetImpl()->InputSymbols();
320   }
321 
OutputSymbols()322   virtual const SymbolTable* OutputSymbols() const {
323     return GetImpl()->OutputSymbols();
324   }
325 
MutableInputSymbols()326   virtual SymbolTable* MutableInputSymbols() {
327     MutateCheck();
328     return GetImpl()->InputSymbols();
329   }
330 
MutableOutputSymbols()331   virtual SymbolTable* MutableOutputSymbols() {
332     MutateCheck();
333     return GetImpl()->OutputSymbols();
334   }
335 
SetInputSymbols(const SymbolTable * isyms)336   virtual void SetInputSymbols(const SymbolTable* isyms) {
337     MutateCheck();
338     GetImpl()->SetInputSymbols(isyms);
339   }
340 
SetOutputSymbols(const SymbolTable * osyms)341   virtual void SetOutputSymbols(const SymbolTable* osyms) {
342     MutateCheck();
343     GetImpl()->SetOutputSymbols(osyms);
344   }
345 
346  protected:
ImplToMutableFst()347   ImplToMutableFst() : ImplToExpandedFst<I, F>() {}
348 
ImplToMutableFst(I * impl)349   ImplToMutableFst(I *impl) : ImplToExpandedFst<I, F>(impl) {}
350 
351 
ImplToMutableFst(const ImplToMutableFst<I,F> & fst)352   ImplToMutableFst(const ImplToMutableFst<I, F> &fst)
353       : ImplToExpandedFst<I, F>(fst) {}
354 
ImplToMutableFst(const ImplToMutableFst<I,F> & fst,bool safe)355   ImplToMutableFst(const ImplToMutableFst<I, F> &fst, bool safe)
356       : ImplToExpandedFst<I, F>(fst, safe) {}
357 
MutateCheck()358   void MutateCheck() {
359     // Copy on write
360     if (GetImpl()->RefCount() > 1)
361       SetImpl(new I(*this));
362   }
363 
364  private:
365   // Disallow
366   ImplToMutableFst<I, F>  &operator=(const ImplToMutableFst<I, F> &fst);
367 
368   ImplToMutableFst<I, F> &operator=(const Fst<Arc> &fst) {
369     FSTERROR() << "ImplToMutableFst: Assignment operator disallowed";
370     GetImpl()->SetProperties(kError, kError);
371     return *this;
372   }
373 };
374 
375 
376 }  // namespace fst
377 
378 #endif  // FST_LIB_MUTABLE_FST_H__
379