1 // reweight.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: allauzen@google.com (Cyril Allauzen)
17 //
18 // \file
19 // Function to reweight an FST.
20 
21 #ifndef FST_LIB_REWEIGHT_H__
22 #define FST_LIB_REWEIGHT_H__
23 
24 #include <vector>
25 using std::vector;
26 
27 #include <fst/mutable-fst.h>
28 
29 
30 namespace fst {
31 
// Direction in which Reweight() moves weight along the FST's paths.
enum ReweightType {
  REWEIGHT_TO_INITIAL,  // Push weight towards the initial state.
  REWEIGHT_TO_FINAL     // Push weight towards the final states.
};
33 
34 // Reweight FST according to the potentials defined by the POTENTIAL
35 // vector in the direction defined by TYPE. Weight needs to be left
36 // distributive when reweighting towards the initial state and right
37 // distributive when reweighting towards the final states.
38 //
39 // An arc of weight w, with an origin state of potential p and
40 // destination state of potential q, is reweighted by p\wq when
41 // reweighting towards the initial state and by pw/q when reweighting
42 // towards the final states.
43 template <class Arc>
Reweight(MutableFst<Arc> * fst,const vector<typename Arc::Weight> & potential,ReweightType type)44 void Reweight(MutableFst<Arc> *fst,
45               const vector<typename Arc::Weight> &potential,
46               ReweightType type) {
47   typedef typename Arc::Weight Weight;
48 
49   if (fst->NumStates() == 0)
50     return;
51 
52   if (type == REWEIGHT_TO_FINAL && !(Weight::Properties() & kRightSemiring)) {
53     FSTERROR() << "Reweight: Reweighting to the final states requires "
54                << "Weight to be right distributive: "
55                << Weight::Type();
56     fst->SetProperties(kError, kError);
57     return;
58   }
59 
60   if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) {
61     FSTERROR() << "Reweight: Reweighting to the initial state requires "
62                << "Weight to be left distributive: "
63                << Weight::Type();
64     fst->SetProperties(kError, kError);
65     return;
66   }
67 
68   StateIterator< MutableFst<Arc> > sit(*fst);
69   for (; !sit.Done(); sit.Next()) {
70     typename Arc::StateId state = sit.Value();
71     if (state == potential.size())
72       break;
73     typename Arc::Weight weight = potential[state];
74     if (weight != Weight::Zero()) {
75       for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
76            !ait.Done();
77            ait.Next()) {
78         Arc arc = ait.Value();
79         if (arc.nextstate >= potential.size())
80           continue;
81         typename Arc::Weight nextweight = potential[arc.nextstate];
82         if (nextweight == Weight::Zero())
83           continue;
84         if (type == REWEIGHT_TO_INITIAL)
85           arc.weight = Divide(Times(arc.weight, nextweight), weight,
86                               DIVIDE_LEFT);
87         if (type == REWEIGHT_TO_FINAL)
88           arc.weight = Divide(Times(weight, arc.weight), nextweight,
89                               DIVIDE_RIGHT);
90         ait.SetValue(arc);
91       }
92       if (type == REWEIGHT_TO_INITIAL)
93         fst->SetFinal(state, Divide(fst->Final(state), weight, DIVIDE_LEFT));
94     }
95     if (type == REWEIGHT_TO_FINAL)
96       fst->SetFinal(state, Times(weight, fst->Final(state)));
97   }
98 
99   // This handles elements past the end of the potentials array.
100   for (; !sit.Done(); sit.Next()) {
101     typename Arc::StateId state = sit.Value();
102     if (type == REWEIGHT_TO_FINAL)
103       fst->SetFinal(state, Times(Weight::Zero(), fst->Final(state)));
104   }
105 
106   typename Arc::Weight startweight = fst->Start() < potential.size() ?
107       potential[fst->Start()] : Weight::Zero();
108   if ((startweight != Weight::One()) && (startweight != Weight::Zero())) {
109     if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) {
110       typename Arc::StateId state = fst->Start();
111       for (MutableArcIterator< MutableFst<Arc> > ait(fst, state);
112            !ait.Done();
113            ait.Next()) {
114         Arc arc = ait.Value();
115         if (type == REWEIGHT_TO_INITIAL)
116           arc.weight = Times(startweight, arc.weight);
117         else
118           arc.weight = Times(
119               Divide(Weight::One(), startweight, DIVIDE_RIGHT),
120               arc.weight);
121         ait.SetValue(arc);
122       }
123       if (type == REWEIGHT_TO_INITIAL)
124         fst->SetFinal(state, Times(startweight, fst->Final(state)));
125       else
126         fst->SetFinal(state, Times(Divide(Weight::One(), startweight,
127                                           DIVIDE_RIGHT),
128                                    fst->Final(state)));
129     } else {
130       typename Arc::StateId state = fst->AddState();
131       Weight w = type == REWEIGHT_TO_INITIAL ?  startweight :
132                  Divide(Weight::One(), startweight, DIVIDE_RIGHT);
133       Arc arc(0, 0, w, fst->Start());
134       fst->AddArc(state, arc);
135       fst->SetStart(state);
136     }
137   }
138 
139   fst->SetProperties(ReweightProperties(
140                          fst->Properties(kFstProperties, false)),
141                      kFstProperties);
142 }
143 
144 }  // namespace fst
145 
#endif  // FST_LIB_REWEIGHT_H__
147