1 // algo_test.h
2 
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Regression test for various FST algorithms.
20 
21 #ifndef FST_TEST_ALGO_TEST_H__
22 #define FST_TEST_ALGO_TEST_H__
23 
24 #include <fst/fstlib.h>
25 #include <fst/random-weight.h>
26 
// Declares the command-line flag `repeat`; its definition lives in the test
// driver. NOTE(review): the flag is not referenced in this part of the header;
// presumably it controls how often the tests are repeated — confirm in
// algo_test.cc.
DECLARE_int32(repeat);  // defined in ./algo_test.cc
28 
29 namespace fst {
30 
31 // Mapper to change input and output label of every transition into
32 // epsilons.
33 template <class A>
34 class EpsMapper {
35  public:
EpsMapper()36   EpsMapper() {}
37 
operator()38   A operator()(const A &arc) const {
39     return A(0, 0, arc.weight, arc.nextstate);
40   }
41 
Properties(uint64 props)42   uint64 Properties(uint64 props) const {
43     props &= ~kNotAcceptor;
44     props |= kAcceptor;
45     props &= ~kNoIEpsilons & ~kNoOEpsilons &  ~kNoEpsilons;
46     props |= kIEpsilons | kOEpsilons | kEpsilons;
47     props &= ~kNotILabelSorted & ~kNotOLabelSorted;
48     props |= kILabelSorted | kOLabelSorted;
49     return props;
50   }
51 
FinalAction()52   MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
53 
54 
InputSymbolsAction()55   MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
56 
OutputSymbolsAction()57   MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
58 };
59 
60 // Generic - no lookahead.
61 template <class Arc>
LookAheadCompose(const Fst<Arc> & ifst1,const Fst<Arc> & ifst2,MutableFst<Arc> * ofst)62 void LookAheadCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
63                       MutableFst<Arc> *ofst) {
64   Compose(ifst1, ifst2, ofst);
65 }
66 
67 // Specialized and epsilon olabel acyclic - lookahead.
LookAheadCompose(const Fst<StdArc> & ifst1,const Fst<StdArc> & ifst2,MutableFst<StdArc> * ofst)68 void LookAheadCompose(const Fst<StdArc> &ifst1, const Fst<StdArc> &ifst2,
69                       MutableFst<StdArc> *ofst) {
70   vector<StdArc::StateId> order;
71   bool acyclic;
72   TopOrderVisitor<StdArc> visitor(&order, &acyclic);
73   DfsVisit(ifst1, &visitor, OutputEpsilonArcFilter<StdArc>());
74   if (acyclic) {  // no ifst1 output epsilon cycles?
75     StdOLabelLookAheadFst lfst1(ifst1);
76     StdVectorFst lfst2(ifst2);
77     LabelLookAheadRelabeler<StdArc>::Relabel(&lfst2, lfst1, true);
78     Compose(lfst1, lfst2, ofst);
79   } else {
80     Compose(ifst1, ifst2, ofst);
81   }
82 }
83 
84 // This class tests a variety of identities and properties that must
85 // hold for various algorithms on weighted FSTs.
86 template <class Arc, class WeightGenerator>
87 class WeightedTester {
88  public:
89   typedef typename Arc::Label Label;
90   typedef typename Arc::StateId StateId;
91   typedef typename Arc::Weight Weight;
92 
WeightedTester(int seed,const Fst<Arc> & zero_fst,const Fst<Arc> & one_fst,const Fst<Arc> & univ_fst,WeightGenerator * weight_generator)93   WeightedTester(int seed, const Fst<Arc> &zero_fst, const Fst<Arc> &one_fst,
94                  const Fst<Arc> &univ_fst, WeightGenerator *weight_generator)
95       : seed_(seed), zero_fst_(zero_fst), one_fst_(one_fst),
96         univ_fst_(univ_fst), weight_generator_(weight_generator) {}
97 
Test(const Fst<Arc> & T1,const Fst<Arc> & T2,const Fst<Arc> & T3)98   void Test(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
99     TestRational(T1, T2, T3);
100     TestMap(T1);
101     TestCompose(T1, T2, T3);
102     TestSort(T1);
103     TestOptimize(T1);
104     TestSearch(T1);
105   }
106 
107  private:
108   // Tests rational operations with identities
TestRational(const Fst<Arc> & T1,const Fst<Arc> & T2,const Fst<Arc> & T3)109   void TestRational(const Fst<Arc> &T1, const Fst<Arc> &T2,
110                     const Fst<Arc> &T3) {
111 
112     {
113       VLOG(1) << "Check destructive and delayed union are equivalent.";
114       VectorFst<Arc> U1(T1);
115       Union(&U1,  T2);
116       UnionFst<Arc> U2(T1, T2);
117       CHECK(Equiv(U1, U2));
118     }
119 
120 
121     {
122       VLOG(1) << "Check destructive and delayed concatenation are equivalent.";
123       VectorFst<Arc> C1(T1);
124       Concat(&C1,  T2);
125       ConcatFst<Arc> C2(T1, T2);
126       CHECK(Equiv(C1, C2));
127       VectorFst<Arc> C3(T2);
128       Concat(T1, &C3);
129       CHECK(Equiv(C3, C2));
130     }
131 
132     {
133       VLOG(1) << "Check destructive and delayed closure* are equivalent.";
134       VectorFst<Arc> C1(T1);
135       Closure(&C1, CLOSURE_STAR);
136       ClosureFst<Arc> C2(T1, CLOSURE_STAR);
137       CHECK(Equiv(C1, C2));
138     }
139 
140     {
141       VLOG(1) << "Check destructive and delayed closure+ are equivalent.";
142       VectorFst<Arc> C1(T1);
143       Closure(&C1, CLOSURE_PLUS);
144       ClosureFst<Arc> C2(T1, CLOSURE_PLUS);
145       CHECK(Equiv(C1, C2));
146     }
147 
148     {
149       VLOG(1)  << "Check union is associative (destructive).";
150       VectorFst<Arc> U1(T1);
151       Union(&U1, T2);
152       Union(&U1, T3);
153 
154       VectorFst<Arc> U3(T2);
155       Union(&U3, T3);
156       VectorFst<Arc> U4(T1);
157       Union(&U4, U3);
158 
159       CHECK(Equiv(U1, U4));
160     }
161 
162     {
163       VLOG(1) << "Check union is associative (delayed).";
164       UnionFst<Arc> U1(T1, T2);
165       UnionFst<Arc> U2(U1, T3);
166 
167       UnionFst<Arc> U3(T2, T3);
168       UnionFst<Arc> U4(T1, U3);
169 
170       CHECK(Equiv(U2, U4));
171     }
172 
173 
174     {
175       VLOG(1) << "Check union is associative (destructive delayed).";
176       UnionFst<Arc> U1(T1, T2);
177       Union(&U1, T3);
178 
179       UnionFst<Arc> U3(T2, T3);
180       UnionFst<Arc> U4(T1, U3);
181 
182       CHECK(Equiv(U1, U4));
183     }
184 
185     {
186       VLOG(1) << "Check concatenation is associative (destructive).";
187       VectorFst<Arc> C1(T1);
188       Concat(&C1, T2);
189       Concat(&C1, T3);
190 
191       VectorFst<Arc> C3(T2);
192       Concat(&C3, T3);
193       VectorFst<Arc> C4(T1);
194       Concat(&C4, C3);
195 
196       CHECK(Equiv(C1, C4));
197     }
198 
199     {
200       VLOG(1) << "Check concatenation is associative (delayed).";
201       ConcatFst<Arc> C1(T1, T2);
202       ConcatFst<Arc> C2(C1, T3);
203 
204       ConcatFst<Arc> C3(T2, T3);
205       ConcatFst<Arc> C4(T1, C3);
206 
207       CHECK(Equiv(C2, C4));
208     }
209 
210     {
211       VLOG(1) << "Check concatenation is associative (destructive delayed).";
212       ConcatFst<Arc> C1(T1, T2);
213       Concat(&C1, T3);
214 
215       ConcatFst<Arc> C3(T2, T3);
216       ConcatFst<Arc> C4(T1, C3);
217 
218       CHECK(Equiv(C1, C4));
219     }
220 
221     if (Weight::Properties() & kLeftSemiring) {
222       VLOG(1) << "Check concatenation left distributes"
223               << " over union (destructive).";
224 
225       VectorFst<Arc> U1(T1);
226       Union(&U1, T2);
227       VectorFst<Arc> C1(T3);
228       Concat(&C1, U1);
229 
230       VectorFst<Arc> C2(T3);
231       Concat(&C2, T1);
232       VectorFst<Arc> C3(T3);
233       Concat(&C3, T2);
234       VectorFst<Arc> U2(C2);
235       Union(&U2, C3);
236 
237       CHECK(Equiv(C1, U2));
238     }
239 
240     if (Weight::Properties() & kRightSemiring) {
241       VLOG(1) << "Check concatenation right distributes"
242               <<  " over union (destructive).";
243       VectorFst<Arc> U1(T1);
244       Union(&U1, T2);
245       VectorFst<Arc> C1(U1);
246       Concat(&C1, T3);
247 
248       VectorFst<Arc> C2(T1);
249       Concat(&C2, T3);
250       VectorFst<Arc> C3(T2);
251       Concat(&C3, T3);
252       VectorFst<Arc> U2(C2);
253       Union(&U2, C3);
254 
255       CHECK(Equiv(C1, U2));
256     }
257 
258     if (Weight::Properties() & kLeftSemiring) {
259       VLOG(1) << "Check concatenation left distributes over union (delayed).";
260       UnionFst<Arc> U1(T1, T2);
261       ConcatFst<Arc> C1(T3, U1);
262 
263       ConcatFst<Arc> C2(T3, T1);
264       ConcatFst<Arc> C3(T3, T2);
265       UnionFst<Arc> U2(C2, C3);
266 
267       CHECK(Equiv(C1, U2));
268     }
269 
270     if (Weight::Properties() & kRightSemiring) {
271       VLOG(1) << "Check concatenation right distributes over union (delayed).";
272       UnionFst<Arc> U1(T1, T2);
273       ConcatFst<Arc> C1(U1, T3);
274 
275       ConcatFst<Arc> C2(T1, T3);
276       ConcatFst<Arc> C3(T2, T3);
277       UnionFst<Arc> U2(C2, C3);
278 
279       CHECK(Equiv(C1, U2));
280     }
281 
282 
283     if (Weight::Properties() & kLeftSemiring) {
284       VLOG(1) << "Check T T* == T+ (destructive).";
285       VectorFst<Arc> S(T1);
286       Closure(&S, CLOSURE_STAR);
287       VectorFst<Arc> C(T1);
288       Concat(&C, S);
289 
290       VectorFst<Arc> P(T1);
291       Closure(&P, CLOSURE_PLUS);
292 
293       CHECK(Equiv(C, P));
294     }
295 
296 
297     if (Weight::Properties() & kRightSemiring) {
298       VLOG(1) << "Check T* T == T+ (destructive).";
299       VectorFst<Arc> S(T1);
300       Closure(&S, CLOSURE_STAR);
301       VectorFst<Arc> C(S);
302       Concat(&C, T1);
303 
304       VectorFst<Arc> P(T1);
305       Closure(&P, CLOSURE_PLUS);
306 
307       CHECK(Equiv(C, P));
308    }
309 
310     if (Weight::Properties() & kLeftSemiring) {
311       VLOG(1) << "Check T T* == T+ (delayed).";
312       ClosureFst<Arc> S(T1, CLOSURE_STAR);
313       ConcatFst<Arc> C(T1, S);
314 
315       ClosureFst<Arc> P(T1, CLOSURE_PLUS);
316 
317       CHECK(Equiv(C, P));
318     }
319 
320     if (Weight::Properties() & kRightSemiring) {
321       VLOG(1) << "Check T* T == T+ (delayed).";
322       ClosureFst<Arc> S(T1, CLOSURE_STAR);
323       ConcatFst<Arc> C(S, T1);
324 
325       ClosureFst<Arc> P(T1, CLOSURE_PLUS);
326 
327       CHECK(Equiv(C, P));
328     }
329   }
330 
331   // Tests map-based operations.
TestMap(const Fst<Arc> & T)332   void TestMap(const Fst<Arc> &T) {
333 
334     {
335       VLOG(1) << "Check destructive and delayed projection are equivalent.";
336       VectorFst<Arc> P1(T);
337       Project(&P1, PROJECT_INPUT);
338       ProjectFst<Arc> P2(T, PROJECT_INPUT);
339       CHECK(Equiv(P1, P2));
340     }
341 
342 
343     {
344       VLOG(1) << "Check destructive and delayed inversion are equivalent.";
345       VectorFst<Arc> I1(T);
346       Invert(&I1);
347       InvertFst<Arc> I2(T);
348       CHECK(Equiv(I1, I2));
349     }
350 
351     {
352       VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive).";
353       VectorFst<Arc> P1(T);
354       VectorFst<Arc> I1(T);
355       Project(&P1, PROJECT_INPUT);
356       Invert(&I1);
357       Project(&I1, PROJECT_OUTPUT);
358       CHECK(Equiv(P1, I1));
359     }
360 
361     {
362       VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive).";
363       VectorFst<Arc> P1(T);
364       VectorFst<Arc> I1(T);
365       Project(&P1, PROJECT_OUTPUT);
366       Invert(&I1);
367       Project(&I1, PROJECT_INPUT);
368       CHECK(Equiv(P1, I1));
369     }
370 
371     {
372       VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed).";
373       ProjectFst<Arc> P1(T, PROJECT_INPUT);
374       InvertFst<Arc> I1(T);
375       ProjectFst<Arc> P2(I1, PROJECT_OUTPUT);
376       CHECK(Equiv(P1, P2));
377     }
378 
379 
380     {
381       VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed).";
382       ProjectFst<Arc> P1(T, PROJECT_OUTPUT);
383       InvertFst<Arc> I1(T);
384       ProjectFst<Arc> P2(I1, PROJECT_INPUT);
385       CHECK(Equiv(P1, P2));
386     }
387 
388 
389     {
390       VLOG(1) << "Check destructive relabeling";
391       static const int kNumLabels = 10;
392       // set up relabeling pairs
393       vector<Label> labelset(kNumLabels);
394       for (size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
395       for (size_t i = 0; i < kNumLabels; ++i) {
396         swap(labelset[i], labelset[rand() % kNumLabels]);
397       }
398 
399       vector<pair<Label, Label> > ipairs1(kNumLabels);
400       vector<pair<Label, Label> > opairs1(kNumLabels);
401       for (size_t i = 0; i < kNumLabels; ++i) {
402         ipairs1[i] = make_pair(i, labelset[i]);
403         opairs1[i] = make_pair(labelset[i], i);
404       }
405       VectorFst<Arc> R(T);
406       Relabel(&R, ipairs1, opairs1);
407 
408       vector<pair<Label, Label> > ipairs2(kNumLabels);
409       vector<pair<Label, Label> > opairs2(kNumLabels);
410       for (size_t i = 0; i < kNumLabels; ++i) {
411         ipairs2[i] = make_pair(labelset[i], i);
412         opairs2[i] = make_pair(i, labelset[i]);
413       }
414       Relabel(&R, ipairs2, opairs2);
415       CHECK(Equiv(R, T));
416 
417       VLOG(1) << "Check on-the-fly relabeling";
418       RelabelFst<Arc> Rdelay(T, ipairs1, opairs1);
419 
420       RelabelFst<Arc> RRdelay(Rdelay, ipairs2, opairs2);
421       CHECK(Equiv(RRdelay, T));
422     }
423 
424     {
425       VLOG(1) << "Check encoding/decoding (destructive).";
426       VectorFst<Arc> D(T);
427       uint32 encode_props = 0;
428       if (rand() % 2)
429         encode_props |= kEncodeLabels;
430       if (rand() % 2)
431         encode_props |= kEncodeWeights;
432       EncodeMapper<Arc> encoder(encode_props, ENCODE);
433       Encode(&D, &encoder);
434       Decode(&D, encoder);
435       CHECK(Equiv(D, T));
436     }
437 
438     {
439       VLOG(1) << "Check encoding/decoding (delayed).";
440       uint32 encode_props = 0;
441       if (rand() % 2)
442         encode_props |= kEncodeLabels;
443       if (rand() % 2)
444         encode_props |= kEncodeWeights;
445       EncodeMapper<Arc> encoder(encode_props, ENCODE);
446       EncodeFst<Arc> E(T, &encoder);
447       VectorFst<Arc> Encoded(E);
448       DecodeFst<Arc> D(Encoded, encoder);
449       CHECK(Equiv(D, T));
450     }
451 
452     {
453       VLOG(1) << "Check gallic mappers (constructive).";
454       ToGallicMapper<Arc> to_mapper;
455       FromGallicMapper<Arc> from_mapper;
456       VectorFst< GallicArc<Arc> > G;
457       VectorFst<Arc> F;
458       ArcMap(T, &G, to_mapper);
459       ArcMap(G, &F, from_mapper);
460       CHECK(Equiv(T, F));
461     }
462 
463     {
464       VLOG(1) << "Check gallic mappers (delayed).";
465       ToGallicMapper<Arc> to_mapper;
466       FromGallicMapper<Arc> from_mapper;
467       ArcMapFst<Arc, GallicArc<Arc>, ToGallicMapper<Arc> >
468         G(T, to_mapper);
469       ArcMapFst<GallicArc<Arc>, Arc, FromGallicMapper<Arc> >
470         F(G, from_mapper);
471       CHECK(Equiv(T, F));
472     }
473   }
474 
475   // Tests compose-based operations.
TestCompose(const Fst<Arc> & T1,const Fst<Arc> & T2,const Fst<Arc> & T3)476   void TestCompose(const Fst<Arc> &T1, const Fst<Arc> &T2,
477                    const Fst<Arc> &T3) {
478     if (!(Weight::Properties() & kCommutative))
479       return;
480 
481     VectorFst<Arc> S1(T1);
482     VectorFst<Arc> S2(T2);
483     VectorFst<Arc> S3(T3);
484 
485     ILabelCompare<Arc> icomp;
486     OLabelCompare<Arc> ocomp;
487 
488     ArcSort(&S1, ocomp);
489     ArcSort(&S2, ocomp);
490     ArcSort(&S3, icomp);
491 
492     {
493       VLOG(1) << "Check composition is associative.";
494       ComposeFst<Arc> C1(S1, S2);
495       ComposeFst<Arc> C2(C1, S3);
496       ComposeFst<Arc> C3(S2, S3);
497       ComposeFst<Arc> C4(S1, C3);
498 
499       CHECK(Equiv(C2, C4));
500     }
501 
502     {
503       VLOG(1) << "Check composition left distributes over union.";
504       UnionFst<Arc> U1(S2, S3);
505       ComposeFst<Arc> C1(S1, U1);
506 
507       ComposeFst<Arc> C2(S1, S2);
508       ComposeFst<Arc> C3(S1, S3);
509       UnionFst<Arc> U2(C2, C3);
510 
511       CHECK(Equiv(C1, U2));
512     }
513 
514     {
515       VLOG(1) << "Check composition right distributes over union.";
516       UnionFst<Arc> U1(S1, S2);
517       ComposeFst<Arc> C1(U1, S3);
518 
519       ComposeFst<Arc> C2(S1, S3);
520       ComposeFst<Arc> C3(S2, S3);
521       UnionFst<Arc> U2(C2, C3);
522 
523       CHECK(Equiv(C1, U2));
524     }
525 
526     VectorFst<Arc> A1(S1);
527     VectorFst<Arc> A2(S2);
528     VectorFst<Arc> A3(S3);
529     Project(&A1, PROJECT_OUTPUT);
530     Project(&A2, PROJECT_INPUT);
531     Project(&A3, PROJECT_INPUT);
532 
533     {
534       VLOG(1) << "Check intersection is commutative.";
535       IntersectFst<Arc> I1(A1, A2);
536       IntersectFst<Arc> I2(A2, A1);
537       CHECK(Equiv(I1, I2));
538     }
539 
540     {
541       VLOG(1) << "Check all epsilon filters leads to equivalent results.";
542       typedef Matcher< Fst<Arc> > M;
543       ComposeFst<Arc> C1(S1, S2);
544       ComposeFst<Arc> C2(
545           S1, S2,
546           ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M> >());
547       ComposeFst<Arc> C3(
548           S1, S2,
549           ComposeFstOptions<Arc, M, MatchComposeFilter<M> >());
550 
551       CHECK(Equiv(C1, C2));
552       CHECK(Equiv(C1, C3));
553     }
554 
555     {
556       VLOG(1) << "Check look-ahead filters lead to equivalent results.";
557       VectorFst<Arc> C1, C2;
558       Compose(S1, S2, &C1);
559       LookAheadCompose(S1, S2, &C2);
560       CHECK(Equiv(C1, C2));
561     }
562   }
563 
564   // Tests sorting operations
TestSort(const Fst<Arc> & T)565   void TestSort(const Fst<Arc> &T) {
566     ILabelCompare<Arc> icomp;
567     OLabelCompare<Arc> ocomp;
568 
569     {
570       VLOG(1) << "Check arc sorted Fst is equivalent to its input.";
571       VectorFst<Arc> S1(T);
572       ArcSort(&S1, icomp);
573       CHECK(Equiv(T, S1));
574     }
575 
576     {
577       VLOG(1) << "Check destructive and delayed arcsort are equivalent.";
578       VectorFst<Arc> S1(T);
579       ArcSort(&S1, icomp);
580       ArcSortFst< Arc, ILabelCompare<Arc> > S2(T, icomp);
581       CHECK(Equiv(S1, S2));
582     }
583 
584     {
585       VLOG(1) << "Check ilabel sorting vs. olabel sorting with inversions.";
586       VectorFst<Arc> S1(T);
587       VectorFst<Arc> S2(T);
588       ArcSort(&S1, icomp);
589       Invert(&S2);
590       ArcSort(&S2, ocomp);
591       Invert(&S2);
592       CHECK(Equiv(S1, S2));
593     }
594 
595     {
596       VLOG(1) << "Check topologically sorted Fst is equivalent to its input.";
597       VectorFst<Arc> S1(T);
598       TopSort(&S1);
599       CHECK(Equiv(T, S1));
600     }
601 
602     {
603       VLOG(1) << "Check reverse(reverse(T)) = T";
604       for (int i = 0; i < 2; ++i) {
605         VectorFst< ReverseArc<Arc> > R1;
606         VectorFst<Arc> R2;
607         bool require_superinitial = i == 1;
608         Reverse(T, &R1, require_superinitial);
609         Reverse(R1, &R2, require_superinitial);
610         CHECK(Equiv(T, R2));
611       }
612     }
613   }
614 
615   // Tests optimization operations
TestOptimize(const Fst<Arc> & T)616   void TestOptimize(const Fst<Arc> &T) {
617     uint64 tprops = T.Properties(kFstProperties, true);
618     uint64 wprops = Weight::Properties();
619 
620     VectorFst<Arc> A(T);
621     Project(&A, PROJECT_INPUT);
622 
623     {
624       VLOG(1) << "Check connected FST is equivalent to its input.";
625       VectorFst<Arc> C1(T);
626       Connect(&C1);
627       CHECK(Equiv(T, C1));
628     }
629 
630     if ((wprops & kSemiring) == kSemiring &&
631         (tprops & kAcyclic || wprops & kIdempotent)) {
632       VLOG(1) << "Check epsilon-removed FST is equivalent to its input.";
633       VectorFst<Arc> R1(T);
634       RmEpsilon(&R1);
635       CHECK(Equiv(T, R1));
636 
637       VLOG(1) << "Check destructive and delayed epsilon removal"
638               << "are equivalent.";
639       RmEpsilonFst<Arc> R2(T);
640       CHECK(Equiv(R1, R2));
641 
642       VLOG(1) << "Check an FST with a large proportion"
643               << " of epsilon transitions:";
644       // Maps all transitions of T to epsilon-transitions and append
645       // a non-epsilon transition.
646       VectorFst<Arc> U;
647       ArcMap(T, &U, EpsMapper<Arc>());
648       VectorFst<Arc> V;
649       V.SetStart(V.AddState());
650       Arc arc(1, 1, Weight::One(), V.AddState());
651       V.AddArc(V.Start(), arc);
652       V.SetFinal(arc.nextstate, Weight::One());
653       Concat(&U, V);
654       // Check that epsilon-removal preserves the shortest-distance
655       // from the initial state to the final states.
656       vector<Weight> d;
657       ShortestDistance(U, &d, true);
658       Weight w = U.Start() < d.size() ? d[U.Start()] : Weight::Zero();
659       VectorFst<Arc> U1(U);
660       RmEpsilon(&U1);
661       ShortestDistance(U1, &d, true);
662       Weight w1 = U1.Start() < d.size() ? d[U1.Start()] : Weight::Zero();
663       CHECK(ApproxEqual(w, w1, kTestDelta));
664       RmEpsilonFst<Arc> U2(U);
665       ShortestDistance(U2, &d, true);
666       Weight w2 = U2.Start() < d.size() ? d[U2.Start()] : Weight::Zero();
667       CHECK(ApproxEqual(w, w2, kTestDelta));
668     }
669 
670     if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
671       VLOG(1) << "Check determinized FSA is equivalent to its input.";
672       DeterminizeFst<Arc> D(A);
673       CHECK(Equiv(A, D));
674 
675       if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
676         VLOG(1)  << "Check pruning in determinization";
677         VectorFst<Arc> P;
678         Weight threshold = (*weight_generator_)();
679         DeterminizeOptions<Arc> opts;
680         opts.weight_threshold = threshold;
681         Determinize(A, &P, opts);
682         CHECK(P.Properties(kIDeterministic, true));
683         CHECK(PruneEquiv(A, P, threshold));
684       }
685 
686       int n;
687       {
688         VLOG(1) << "Check size(min(det(A))) <= size(det(A))"
689                 << " and  min(det(A)) equiv det(A)";
690         VectorFst<Arc> M(D);
691         n = M.NumStates();
692         Minimize(&M);
693         CHECK(Equiv(D, M));
694         CHECK(M.NumStates() <= n);
695         n = M.NumStates();
696       }
697 
698       if (n && (wprops & kIdempotent) == kIdempotent &&
699           A.Properties(kNoEpsilons, true)) {
700         VLOG(1) << "Check that Revuz's algorithm leads to the"
701                 << " same number of states as Brozozowski's algorithm";
702 
703         // Skip test if A is the empty machine or contains epsilons or
704         // if the semiring is not idempotent (to avoid floating point
705         // errors)
706         VectorFst<Arc> R;
707         Reverse(A, &R);
708         RmEpsilon(&R);
709         DeterminizeFst<Arc> DR(R);
710         VectorFst<Arc> RD;
711         Reverse(DR, &RD);
712         DeterminizeFst<Arc> DRD(RD);
713         VectorFst<Arc> M(DRD);
714         CHECK_EQ(n + 1, M.NumStates());  // Accounts for the epsilon transition
715                                          // to the initial state
716       }
717     }
718 
719     if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
720       VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
721       VectorFst<Arc> R(A), D;
722       RmEpsilon(&R);
723       Disambiguate(R, &D);
724       CHECK(Equiv(R, D));
725       VLOG(1) << "Check disambiguated FSA is unambiguous";
726       CHECK(Unambiguous(D));
727 
728       if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
729         VLOG(1)  << "Check pruning in disambiguation";
730         VectorFst<Arc> P;
731         Weight threshold = (*weight_generator_)();
732         DisambiguateOptions<Arc> opts;
733         opts.weight_threshold = threshold;
734         Disambiguate(R, &P, opts);
735         CHECK(PruneEquiv(A, P, threshold));
736         CHECK(Unambiguous(P));
737       }
738     }
739 
740     if (Arc::Type() == LogArc::Type() || Arc::Type() == StdArc::Type()) {
741       VLOG(1) << "Check reweight(T) equiv T";
742       vector<Weight> potential;
743       VectorFst<Arc> RI(T);
744       VectorFst<Arc> RF(T);
745       while (potential.size() < RI.NumStates())
746         potential.push_back((*weight_generator_)());
747 
748       Reweight(&RI, potential, REWEIGHT_TO_INITIAL);
749       CHECK(Equiv(T, RI));
750 
751       Reweight(&RF, potential, REWEIGHT_TO_FINAL);
752       CHECK(Equiv(T, RF));
753     }
754 
755     if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
756       VLOG(1) << "Check pushed FST is equivalent to input FST.";
757       // Pushing towards the final state.
758       if (wprops & kRightSemiring) {
759         VectorFst<Arc> P1;
760         Push<Arc, REWEIGHT_TO_FINAL>(T, &P1, kPushLabels);
761         CHECK(Equiv(T, P1));
762 
763         VectorFst<Arc> P2;
764         Push<Arc, REWEIGHT_TO_FINAL>(T, &P2, kPushWeights);
765         CHECK(Equiv(T, P2));
766 
767         VectorFst<Arc> P3;
768         Push<Arc, REWEIGHT_TO_FINAL>(T, &P3, kPushLabels | kPushWeights);
769         CHECK(Equiv(T, P3));
770       }
771 
772       // Pushing towards the initial state.
773       if (wprops & kLeftSemiring) {
774         VectorFst<Arc> P1;
775         Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1, kPushLabels);
776         CHECK(Equiv(T, P1));
777 
778         VectorFst<Arc> P2;
779         Push<Arc, REWEIGHT_TO_INITIAL>(T, &P2, kPushWeights);
780         CHECK(Equiv(T, P2));
781         VectorFst<Arc> P3;
782         Push<Arc, REWEIGHT_TO_INITIAL>(T, &P3, kPushLabels | kPushWeights);
783         CHECK(Equiv(T, P3));
784       }
785     }
786 
787     if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
788       VLOG(1) << "Check pruning algorithm";
789       {
790         VLOG(1) << "Check equiv. of constructive and destructive algorithms";
791         Weight thresold = (*weight_generator_)();
792         VectorFst<Arc> P1(T);
793         Prune(&P1, thresold);
794         VectorFst<Arc> P2;
795         Prune(T, &P2, thresold);
796         CHECK(Equiv(P1,P2));
797       }
798 
799       {
800         VLOG(1) << "Check prune(reverse) equiv reverse(prune)";
801         Weight thresold = (*weight_generator_)();
802         VectorFst< ReverseArc<Arc> > R;
803         VectorFst<Arc> P1(T);
804         VectorFst<Arc> P2;
805         Prune(&P1, thresold);
806         Reverse(T, &R);
807         Prune(&R, thresold.Reverse());
808         Reverse(R, &P2);
809         CHECK(Equiv(P1, P2));
810       }
811       {
812         VLOG(1) << "Check: ShortestDistance(A - prune(A))"
813                 << " > ShortestDistance(A) times Threshold";
814         Weight threshold = (*weight_generator_)();
815         VectorFst<Arc> P;
816         Prune(A, &P, threshold);
817         CHECK(PruneEquiv(A, P, threshold));
818       }
819     }
820     if (tprops & kAcyclic) {
821       VLOG(1) << "Check synchronize(T) equiv T";
822       SynchronizeFst<Arc> S(T);
823       CHECK(Equiv(T, S));
824     }
825   }
826 
827   // Tests search operations
TestSearch(const Fst<Arc> & T)828   void TestSearch(const Fst<Arc> &T) {
829     uint64 wprops = Weight::Properties();
830 
831     VectorFst<Arc> A(T);
832     Project(&A, PROJECT_INPUT);
833 
834     if ((wprops & (kPath | kRightSemiring)) == (kPath | kRightSemiring)) {
835       VLOG(1) << "Check 1-best weight.";
836       VectorFst<Arc> path;
837       ShortestPath(T, &path);
838       Weight tsum = ShortestDistance(T);
839       Weight psum = ShortestDistance(path);
840       CHECK(ApproxEqual(tsum, psum, kTestDelta));
841     }
842 
843     if ((wprops & (kPath | kSemiring)) == (kPath | kSemiring)) {
844       VLOG(1) << "Check n-best weights";
845       VectorFst<Arc> R(A);
846       RmEpsilon(&R);
847       int nshortest = rand() % kNumRandomShortestPaths + 2;
848       VectorFst<Arc> paths;
849       ShortestPath(R, &paths, nshortest, true, false,
850                    Weight::Zero(), kNumShortestStates);
851       vector<Weight> distance;
852       ShortestDistance(paths, &distance, true);
853       StateId pstart = paths.Start();
854       if (pstart != kNoStateId) {
855         ArcIterator< Fst<Arc> > piter(paths, pstart);
856         for (; !piter.Done(); piter.Next()) {
857           StateId s = piter.Value().nextstate;
858           Weight nsum = s < distance.size() ?
859               Times(piter.Value().weight, distance[s]) : Weight::Zero();
860           VectorFst<Arc> path;
861           ShortestPath(R, &path);
862           Weight dsum = ShortestDistance(path);
863           CHECK(ApproxEqual(nsum, dsum, kTestDelta));
864           ArcMap(&path, RmWeightMapper<Arc>());
865           VectorFst<Arc> S;
866           Difference(R, path, &S);
867           R = S;
868         }
869       }
870     }
871   }
872 
873   // Tests if two FSTS are equivalent by checking if random
874   // strings from one FST are transduced the same by both FSTs.
875   template <class A>
Equiv(const Fst<A> & fst1,const Fst<A> & fst2)876   bool Equiv(const Fst<A> &fst1, const Fst<A> &fst2) {
877     VLOG(1) << "Check FSTs for sanity (including property bits).";
878     CHECK(Verify(fst1));
879     CHECK(Verify(fst2));
880 
881     UniformArcSelector<A> uniform_selector(seed_);
882     RandGenOptions< UniformArcSelector<A> >
883         opts(uniform_selector, kRandomPathLength);
884     return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts);
885   }
886 
887   // Tests FSA is unambiguous
Unambiguous(const Fst<Arc> & fst)888   bool Unambiguous(const Fst<Arc> &fst) {
889     VectorFst<StdArc> sfst, dfst;
890     VectorFst<LogArc> lfst1, lfst2;
891     Map(fst, &sfst, RmWeightMapper<Arc, StdArc>());
892     Determinize(sfst, &dfst);
893     Map(fst, &lfst1, RmWeightMapper<Arc, LogArc>());
894     Map(dfst, &lfst2, RmWeightMapper<StdArc, LogArc>());
895     return Equiv(lfst1, lfst2);
896   }
897 
898   // Tests ShortestDistance(A - P) >
899   // ShortestDistance(A) times Threshold.
900   template <class A>
PruneEquiv(const Fst<A> & fst,const Fst<A> & pfst,Weight threshold)901   bool PruneEquiv(const Fst<A> &fst, const Fst<A> &pfst,
902                    Weight threshold) {
903     VLOG(1) << "Check FSTs for sanity (including property bits).";
904     CHECK(Verify(fst));
905     CHECK(Verify(pfst));
906 
907     DifferenceFst<Arc> D(fst, DeterminizeFst<Arc>
908                          (RmEpsilonFst<Arc>
909                           (ArcMapFst<Arc, Arc,
910                                      RmWeightMapper<Arc> >
911                            (pfst, RmWeightMapper<Arc>()))));
912     Weight sum1 = Times(ShortestDistance(fst), threshold);
913     Weight sum2 = ShortestDistance(D);
914     return Plus(sum1, sum2) == sum1;
915   }
916 
917   // Random seed
918   int seed_;
919 
920   // FST with no states
921   VectorFst<Arc> zero_fst_;
922 
923   // FST with one state that accepts epsilon.
924   VectorFst<Arc> one_fst_;
925 
926   // FST with one state that accepts all strings.
927   VectorFst<Arc> univ_fst_;
928 
929   // Generates weights used in testing.
930   WeightGenerator *weight_generator_;
931 
932   // Maximum random path length.
933   static const int kRandomPathLength;
934 
935   // Number of random paths to explore.
936   static const int kNumRandomPaths;
937 
938   // Maximum number of nshortest paths.
939   static const int kNumRandomShortestPaths;
940 
941   // Maximum number of nshortest states.
942   static const int kNumShortestStates;
943 
944   // Delta for equivalence tests.
945   static const float kTestDelta;
946 
947   DISALLOW_COPY_AND_ASSIGN(WeightedTester);
948 };
949 
950 
951 template <class A, class WG>
952 const int WeightedTester<A, WG>::kRandomPathLength = 25;
953 
954 template <class A, class WG>
955 const int WeightedTester<A, WG>::kNumRandomPaths = 100;
956 
957 template <class A, class WG>
958 const int WeightedTester<A, WG>::kNumRandomShortestPaths = 100;
959 
960 template <class A, class WG>
961 const int WeightedTester<A, WG>::kNumShortestStates = 10000;
962 
963 template <class A, class WG>
964 const float WeightedTester<A, WG>::kTestDelta = .05;
965 
966 // This class tests a variety of identities and properties that must
967 // hold for various algorithms on unweighted FSAs and that are not tested
968 // by WeightedTester. Only the specialization does anything interesting.
969 template <class Arc>
970 class UnweightedTester {
971  public:
UnweightedTester(const Fst<Arc> & zero_fsa,const Fst<Arc> & one_fsa,const Fst<Arc> & univ_fsa)972   UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
973                    const Fst<Arc> &univ_fsa) {}
974 
Test(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)975   void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {}
976 };
977 
978 
979 // Specialization for StdArc. This should work for any commutative,
980 // idempotent semiring when restricted to the unweighted case
981 // (being isomorphic to the boolean semiring).
982 template <>
983 class UnweightedTester<StdArc> {
984  public:
985   typedef StdArc Arc;
986   typedef Arc::Label Label;
987   typedef Arc::StateId StateId;
988   typedef Arc::Weight Weight;
989 
UnweightedTester(const Fst<Arc> & zero_fsa,const Fst<Arc> & one_fsa,const Fst<Arc> & univ_fsa)990   UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
991                    const Fst<Arc> &univ_fsa)
992       : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {}
993 
Test(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)994   void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {
995     TestRational(A1, A2, A3);
996     TestIntersect(A1, A2, A3);
997     TestOptimize(A1);
998   }
999 
1000  private:
1001   // Tests rational operations with identities
TestRational(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)1002   void TestRational(const Fst<Arc> &A1, const Fst<Arc> &A2,
1003                     const Fst<Arc> &A3) {
1004 
1005     {
1006       VLOG(1) << "Check the union contains its arguments (destructive).";
1007       VectorFst<Arc> U(A1);
1008       Union(&U, A2);
1009 
1010       CHECK(Subset(A1, U));
1011       CHECK(Subset(A2, U));
1012     }
1013 
1014     {
1015       VLOG(1) << "Check the union contains its arguments (delayed).";
1016       UnionFst<Arc> U(A1, A2);
1017 
1018       CHECK(Subset(A1, U));
1019       CHECK(Subset(A2, U));
1020     }
1021 
1022     {
1023       VLOG(1) << "Check if A^n c A* (destructive).";
1024       VectorFst<Arc> C(one_fsa_);
1025       int n = rand() % 5;
1026       for (int i = 0; i < n; ++i)
1027         Concat(&C, A1);
1028 
1029       VectorFst<Arc> S(A1);
1030       Closure(&S, CLOSURE_STAR);
1031       CHECK(Subset(C, S));
1032     }
1033 
1034     {
1035       VLOG(1) << "Check if A^n c A* (delayed).";
1036       int n = rand() % 5;
1037       Fst<Arc> *C = new VectorFst<Arc>(one_fsa_);
1038       for (int i = 0; i < n; ++i) {
1039         ConcatFst<Arc> *F = new ConcatFst<Arc>(*C, A1);
1040         delete C;
1041         C = F;
1042       }
1043       ClosureFst<Arc> S(A1, CLOSURE_STAR);
1044       CHECK(Subset(*C, S));
1045       delete C;
1046     }
1047   }
1048 
1049   // Tests intersect-based operations.
TestIntersect(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)1050   void TestIntersect(const Fst<Arc> &A1, const Fst<Arc> &A2,
1051                    const Fst<Arc> &A3) {
1052     VectorFst<Arc> S1(A1);
1053     VectorFst<Arc> S2(A2);
1054     VectorFst<Arc> S3(A3);
1055 
1056     ILabelCompare<Arc> comp;
1057 
1058     ArcSort(&S1, comp);
1059     ArcSort(&S2, comp);
1060     ArcSort(&S3, comp);
1061 
1062     {
1063       VLOG(1) << "Check the intersection is contained in its arguments.";
1064       IntersectFst<Arc> I1(S1, S2);
1065       CHECK(Subset(I1, S1));
1066       CHECK(Subset(I1, S2));
1067     }
1068 
1069     {
1070       VLOG(1) << "Check union distributes over intersection.";
1071       IntersectFst<Arc> I1(S1, S2);
1072       UnionFst<Arc> U1(I1, S3);
1073 
1074       UnionFst<Arc> U2(S1, S3);
1075       UnionFst<Arc> U3(S2, S3);
1076       ArcSortFst< Arc, ILabelCompare<Arc> > S4(U3, comp);
1077       IntersectFst<Arc> I2(U2, S4);
1078 
1079       CHECK(Equiv(U1, I2));
1080     }
1081 
1082     VectorFst<Arc> C1;
1083     VectorFst<Arc> C2;
1084     Complement(S1, &C1);
1085     Complement(S2, &C2);
1086     ArcSort(&C1, comp);
1087     ArcSort(&C2, comp);
1088 
1089 
1090     {
1091       VLOG(1) << "Check S U S' = Sigma*";
1092       UnionFst<Arc> U(S1, C1);
1093       CHECK(Equiv(U, univ_fsa_));
1094     }
1095 
1096     {
1097       VLOG(1) << "Check S n S' = {}";
1098       IntersectFst<Arc> I(S1, C1);
1099       CHECK(Equiv(I, zero_fsa_));
1100     }
1101 
1102     {
1103       VLOG(1) << "Check (S1' U S2') == (S1 n S2)'";
1104       UnionFst<Arc> U(C1, C2);
1105 
1106       IntersectFst<Arc> I(S1, S2);
1107       VectorFst<Arc> C3;
1108       Complement(I, &C3);
1109       CHECK(Equiv(U, C3));
1110     }
1111 
1112     {
1113       VLOG(1) << "Check (S1' n S2') == (S1 U S2)'";
1114       IntersectFst<Arc> I(C1, C2);
1115 
1116       UnionFst<Arc> U(S1, S2);
1117       VectorFst<Arc> C3;
1118       Complement(U, &C3);
1119       CHECK(Equiv(I, C3));
1120     }
1121   }
1122 
1123   // Tests optimization operations
TestOptimize(const Fst<Arc> & A)1124   void TestOptimize(const Fst<Arc> &A) {
1125     {
1126       VLOG(1) << "Check determinized FSA is equivalent to its input.";
1127       DeterminizeFst<Arc> D(A);
1128       CHECK(Equiv(A, D));
1129     }
1130 
1131     {
1132       VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
1133       VectorFst<Arc> R(A), D;
1134       RmEpsilon(&R);
1135 
1136       Disambiguate(R, &D);
1137       CHECK(Equiv(R, D));
1138     }
1139 
1140     {
1141       VLOG(1) << "Check minimized FSA is equivalent to its input.";
1142       int n;
1143       {
1144         RmEpsilonFst<Arc> R(A);
1145         DeterminizeFst<Arc> D(R);
1146         VectorFst<Arc> M(D);
1147         Minimize(&M);
1148         CHECK(Equiv(A, M));
1149         n = M.NumStates();
1150       }
1151 
1152       if (n) {  // Skip test if A is the empty machine
1153         VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the"
1154                 << " same number of states as Brozozowski's algorithm";
1155         VectorFst<Arc> R;
1156         Reverse(A, &R);
1157         RmEpsilon(&R);
1158         DeterminizeFst<Arc> DR(R);
1159         VectorFst<Arc> RD;
1160         Reverse(DR, &RD);
1161         DeterminizeFst<Arc> DRD(RD);
1162         VectorFst<Arc> M(DRD);
1163         CHECK_EQ(n + 1, M.NumStates());  // Accounts for the epsilon transition
1164                                          // to the initial state
1165       }
1166     }
1167   }
1168 
1169   // Tests if two FSAS are equivalent.
Equiv(const Fst<Arc> & fsa1,const Fst<Arc> & fsa2)1170   bool Equiv(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1171     VLOG(1) << "Check FSAs for sanity (including property bits).";
1172     CHECK(Verify(fsa1));
1173     CHECK(Verify(fsa2));
1174 
1175     VectorFst<Arc> vfsa1(fsa1);
1176     VectorFst<Arc> vfsa2(fsa2);
1177     RmEpsilon(&vfsa1);
1178     RmEpsilon(&vfsa2);
1179     DeterminizeFst<Arc> dfa1(vfsa1);
1180     DeterminizeFst<Arc> dfa2(vfsa2);
1181 
1182     // Test equivalence using union-find algorithm
1183     bool equiv1 = Equivalent(dfa1, dfa2);
1184 
1185     // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty
1186     ILabelCompare<Arc> comp;
1187     VectorFst<Arc> sdfa1(dfa1);
1188     ArcSort(&sdfa1, comp);
1189     VectorFst<Arc> sdfa2(dfa2);
1190     ArcSort(&sdfa2, comp);
1191 
1192     DifferenceFst<Arc> dfsa1(sdfa1, sdfa2);
1193     DifferenceFst<Arc> dfsa2(sdfa2, sdfa1);
1194 
1195     VectorFst<Arc> ufsa(dfsa1);
1196     Union(&ufsa, dfsa2);
1197     Connect(&ufsa);
1198     bool equiv2 = ufsa.NumStates() == 0;
1199 
1200     // Check two equivalence tests match
1201     CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
1202 
1203     return equiv1;
1204   }
1205 
1206   // Tests if FSA1 is a subset of FSA2 (disregarding weights).
Subset(const Fst<Arc> & fsa1,const Fst<Arc> & fsa2)1207   bool Subset(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1208     VLOG(1) << "Check FSAs (incl. property bits) for sanity";
1209     CHECK(Verify(fsa1));
1210     CHECK(Verify(fsa2));
1211 
1212     VectorFst<StdArc> vfsa1;
1213     VectorFst<StdArc> vfsa2;
1214     RmEpsilon(&vfsa1);
1215     RmEpsilon(&vfsa2);
1216     ILabelCompare<StdArc> comp;
1217     ArcSort(&vfsa1, comp);
1218     ArcSort(&vfsa2, comp);
1219     IntersectFst<StdArc> ifsa(vfsa1, vfsa2);
1220     DeterminizeFst<StdArc> dfa1(vfsa1);
1221     DeterminizeFst<StdArc> dfa2(ifsa);
1222     return Equivalent(dfa1, dfa2);
1223   }
1224 
1225   // Returns complement Fsa
Complement(const Fst<Arc> & ifsa,MutableFst<Arc> * ofsa)1226   void Complement(const Fst<Arc> &ifsa, MutableFst<Arc> *ofsa) {
1227     RmEpsilonFst<Arc> rfsa(ifsa);
1228     DeterminizeFst<Arc> dfa(rfsa);
1229     DifferenceFst<Arc> cfsa(univ_fsa_, dfa);
1230     *ofsa = cfsa;
1231   }
1232 
1233   // FSA with no states
1234   VectorFst<Arc> zero_fsa_;
1235 
1236   // FSA with one state that accepts epsilon.
1237   VectorFst<Arc> one_fsa_;
1238 
1239   // FSA with one state that accepts all strings.
1240   VectorFst<Arc> univ_fsa_;
1241 
1242   DISALLOW_COPY_AND_ASSIGN(UnweightedTester);
1243 };
1244 
1245 
1246 // This class tests a variety of identities and properties that must
1247 // hold for various FST algorithms. It randomly generates FSTs, using
1248 // function object 'weight_generator' to select weights. 'WeightTester'
1249 // and 'UnweightedTester' are then called.
1250 template <class Arc, class WeightGenerator>
1251 class AlgoTester {
1252  public:
1253   typedef typename Arc::Label Label;
1254   typedef typename Arc::StateId StateId;
1255   typedef typename Arc::Weight Weight;
1256 
AlgoTester(WeightGenerator generator,int seed)1257   AlgoTester(WeightGenerator generator, int seed) :
1258       weight_generator_(generator), seed_(seed) {
1259       one_fst_.AddState();
1260       one_fst_.SetStart(0);
1261       one_fst_.SetFinal(0, Weight::One());
1262 
1263       univ_fst_.AddState();
1264       univ_fst_.SetStart(0);
1265       univ_fst_.SetFinal(0, Weight::One());
1266       for (int i = 0; i < kNumRandomLabels; ++i)
1267         univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0));
1268   }
1269 
Test()1270   void Test() {
1271     VLOG(1) << "weight type = " << Weight::Type();
1272 
1273     for (int i = 0; i < FLAGS_repeat; ++i) {
1274       // Random transducers
1275       VectorFst<Arc> T1;
1276       VectorFst<Arc> T2;
1277       VectorFst<Arc> T3;
1278       RandFst(&T1);
1279       RandFst(&T2);
1280       RandFst(&T3);
1281       WeightedTester<Arc, WeightGenerator>
1282         weighted_tester(seed_, zero_fst_, one_fst_,
1283                         univ_fst_, &weight_generator_);
1284       weighted_tester.Test(T1, T2, T3);
1285 
1286       VectorFst<Arc> A1(T1);
1287       VectorFst<Arc> A2(T2);
1288       VectorFst<Arc> A3(T3);
1289       Project(&A1, PROJECT_OUTPUT);
1290       Project(&A2, PROJECT_INPUT);
1291       Project(&A3, PROJECT_INPUT);
1292       ArcMap(&A1, rm_weight_mapper);
1293       ArcMap(&A2, rm_weight_mapper);
1294       ArcMap(&A3, rm_weight_mapper);
1295       UnweightedTester<Arc> unweighted_tester(zero_fst_, one_fst_, univ_fst_);
1296       unweighted_tester.Test(A1, A2, A3);
1297     }
1298   }
1299 
1300  private:
1301   // Generates a random FST.
RandFst(MutableFst<Arc> * fst)1302   void RandFst(MutableFst<Arc> *fst) {
1303     // Determines direction of the arcs wrt state numbering. This way we
1304     // can force acyclicity when desired.
1305     enum ArcDirection { ANY_DIRECTION = 0, FORWARD_DIRECTION = 1,
1306                         REVERSE_DIRECTION = 2, NUM_DIRECTIONS = 3 };
1307 
1308     ArcDirection arc_direction = ANY_DIRECTION;
1309     if (rand()/(RAND_MAX  + 1.0) < kAcyclicProb)
1310       arc_direction =  rand() % 2 ? FORWARD_DIRECTION : REVERSE_DIRECTION;
1311 
1312     fst->DeleteStates();
1313     StateId ns = rand() % kNumRandomStates;
1314 
1315     if (ns == 0)
1316       return;
1317     for (StateId s = 0; s < ns; ++s)
1318       fst->AddState();
1319 
1320     StateId start = rand() % ns;
1321     fst->SetStart(start);
1322 
1323     size_t na = rand() % kNumRandomArcs;
1324     for (size_t n = 0; n < na; ++n) {
1325       StateId s = rand() % ns;
1326       Arc arc;
1327       arc.ilabel = rand() % kNumRandomLabels;
1328       arc.olabel = rand() % kNumRandomLabels;
1329       arc.weight = weight_generator_();
1330       arc.nextstate = rand() % ns;
1331 
1332       if (arc_direction == ANY_DIRECTION ||
1333           (arc_direction == FORWARD_DIRECTION && arc.ilabel > arc.olabel) ||
1334           (arc_direction == REVERSE_DIRECTION && arc.ilabel < arc.olabel))
1335         fst->AddArc(s, arc);
1336     }
1337 
1338     StateId nf = rand() % (ns + 1);
1339     for (StateId n = 0; n < nf; ++n) {
1340       StateId s = rand() % ns;
1341       Weight final = weight_generator_();
1342       fst->SetFinal(s, final);
1343     }
1344     VLOG(1) << "Check FST for sanity (including property bits).";
1345     CHECK(Verify(*fst));
1346 
1347     // Get/compute all properties.
1348     uint64 props = fst->Properties(kFstProperties, true);
1349 
1350     // Select random set of properties to be unknown.
1351     uint64 mask = 0;
1352     for (int n = 0; n < 8; ++n) {
1353       mask |= rand() & 0xff;
1354       mask <<= 8;
1355     }
1356     mask &= ~kTrinaryProperties;
1357     fst->SetProperties(props & ~mask, mask);
1358   }
1359 
1360   // Generates weights used in testing.
1361   WeightGenerator weight_generator_;
1362 
1363   // Random seed
1364   int seed_;
1365 
1366   // FST with no states
1367   VectorFst<Arc> zero_fst_;
1368 
1369   // FST with one state that accepts epsilon.
1370   VectorFst<Arc> one_fst_;
1371 
1372   // FST with one state that accepts all strings.
1373   VectorFst<Arc> univ_fst_;
1374 
1375   // Mapper to remove weights from an Fst
1376   RmWeightMapper<Arc> rm_weight_mapper;
1377 
1378   // Maximum number of states in random test Fst.
1379   static const int kNumRandomStates;
1380 
1381   // Maximum number of arcs in random test Fst.
1382   static const int kNumRandomArcs;
1383 
1384   // Number of alternative random labels.
1385   static const int kNumRandomLabels;
1386 
1387   // Probability to force an acyclic Fst
1388   static const float kAcyclicProb;
1389 
1390   // Maximum random path length.
1391   static const int kRandomPathLength;
1392 
1393   // Number of random paths to explore.
1394   static const int kNumRandomPaths;
1395 
1396   DISALLOW_COPY_AND_ASSIGN(AlgoTester);
1397 };
1398 
1399 template <class A, class G> const int AlgoTester<A, G>::kNumRandomStates = 10;
1400 
1401 template <class A, class G> const int AlgoTester<A, G>::kNumRandomArcs = 25;
1402 
1403 template <class A, class G> const int AlgoTester<A, G>::kNumRandomLabels = 5;
1404 
1405 template <class A, class G> const float AlgoTester<A, G>::kAcyclicProb = .25;
1406 
1407 template <class A, class G> const int AlgoTester<A, G>::kRandomPathLength = 25;
1408 
1409 template <class A, class G> const int AlgoTester<A, G>::kNumRandomPaths = 100;
1410 
1411 }  // namespace fst
1412 
1413 #endif  // FST_TEST_ALGO_TEST_H__
1414