1 // algo_test.h
2
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 //
15 // Copyright 2005-2010 Google, Inc.
16 // Author: riley@google.com (Michael Riley)
17 //
18 // \file
19 // Regression test for various FST algorithms.
20
21 #ifndef FST_TEST_ALGO_TEST_H__
22 #define FST_TEST_ALGO_TEST_H__
23
24 #include <fst/fstlib.h>
25 #include <fst/random-weight.h>
26
27 DECLARE_int32(repeat); // defined in ./algo_test.cc
28
29 namespace fst {
30
31 // Mapper to change input and output label of every transition into
32 // epsilons.
33 template <class A>
34 class EpsMapper {
35 public:
EpsMapper()36 EpsMapper() {}
37
operator()38 A operator()(const A &arc) const {
39 return A(0, 0, arc.weight, arc.nextstate);
40 }
41
Properties(uint64 props)42 uint64 Properties(uint64 props) const {
43 props &= ~kNotAcceptor;
44 props |= kAcceptor;
45 props &= ~kNoIEpsilons & ~kNoOEpsilons & ~kNoEpsilons;
46 props |= kIEpsilons | kOEpsilons | kEpsilons;
47 props &= ~kNotILabelSorted & ~kNotOLabelSorted;
48 props |= kILabelSorted | kOLabelSorted;
49 return props;
50 }
51
FinalAction()52 MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; }
53
54
InputSymbolsAction()55 MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
56
OutputSymbolsAction()57 MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS;}
58 };
59
60 // Generic - no lookahead.
61 template <class Arc>
LookAheadCompose(const Fst<Arc> & ifst1,const Fst<Arc> & ifst2,MutableFst<Arc> * ofst)62 void LookAheadCompose(const Fst<Arc> &ifst1, const Fst<Arc> &ifst2,
63 MutableFst<Arc> *ofst) {
64 Compose(ifst1, ifst2, ofst);
65 }
66
67 // Specialized and epsilon olabel acyclic - lookahead.
LookAheadCompose(const Fst<StdArc> & ifst1,const Fst<StdArc> & ifst2,MutableFst<StdArc> * ofst)68 void LookAheadCompose(const Fst<StdArc> &ifst1, const Fst<StdArc> &ifst2,
69 MutableFst<StdArc> *ofst) {
70 vector<StdArc::StateId> order;
71 bool acyclic;
72 TopOrderVisitor<StdArc> visitor(&order, &acyclic);
73 DfsVisit(ifst1, &visitor, OutputEpsilonArcFilter<StdArc>());
74 if (acyclic) { // no ifst1 output epsilon cycles?
75 StdOLabelLookAheadFst lfst1(ifst1);
76 StdVectorFst lfst2(ifst2);
77 LabelLookAheadRelabeler<StdArc>::Relabel(&lfst2, lfst1, true);
78 Compose(lfst1, lfst2, ofst);
79 } else {
80 Compose(ifst1, ifst2, ofst);
81 }
82 }
83
84 // This class tests a variety of identities and properties that must
85 // hold for various algorithms on weighted FSTs.
86 template <class Arc, class WeightGenerator>
87 class WeightedTester {
88 public:
89 typedef typename Arc::Label Label;
90 typedef typename Arc::StateId StateId;
91 typedef typename Arc::Weight Weight;
92
WeightedTester(int seed,const Fst<Arc> & zero_fst,const Fst<Arc> & one_fst,const Fst<Arc> & univ_fst,WeightGenerator * weight_generator)93 WeightedTester(int seed, const Fst<Arc> &zero_fst, const Fst<Arc> &one_fst,
94 const Fst<Arc> &univ_fst, WeightGenerator *weight_generator)
95 : seed_(seed), zero_fst_(zero_fst), one_fst_(one_fst),
96 univ_fst_(univ_fst), weight_generator_(weight_generator) {}
97
Test(const Fst<Arc> & T1,const Fst<Arc> & T2,const Fst<Arc> & T3)98 void Test(const Fst<Arc> &T1, const Fst<Arc> &T2, const Fst<Arc> &T3) {
99 TestRational(T1, T2, T3);
100 TestMap(T1);
101 TestCompose(T1, T2, T3);
102 TestSort(T1);
103 TestOptimize(T1);
104 TestSearch(T1);
105 }
106
107 private:
108 // Tests rational operations with identities
TestRational(const Fst<Arc> & T1,const Fst<Arc> & T2,const Fst<Arc> & T3)109 void TestRational(const Fst<Arc> &T1, const Fst<Arc> &T2,
110 const Fst<Arc> &T3) {
111
112 {
113 VLOG(1) << "Check destructive and delayed union are equivalent.";
114 VectorFst<Arc> U1(T1);
115 Union(&U1, T2);
116 UnionFst<Arc> U2(T1, T2);
117 CHECK(Equiv(U1, U2));
118 }
119
120
121 {
122 VLOG(1) << "Check destructive and delayed concatenation are equivalent.";
123 VectorFst<Arc> C1(T1);
124 Concat(&C1, T2);
125 ConcatFst<Arc> C2(T1, T2);
126 CHECK(Equiv(C1, C2));
127 VectorFst<Arc> C3(T2);
128 Concat(T1, &C3);
129 CHECK(Equiv(C3, C2));
130 }
131
132 {
133 VLOG(1) << "Check destructive and delayed closure* are equivalent.";
134 VectorFst<Arc> C1(T1);
135 Closure(&C1, CLOSURE_STAR);
136 ClosureFst<Arc> C2(T1, CLOSURE_STAR);
137 CHECK(Equiv(C1, C2));
138 }
139
140 {
141 VLOG(1) << "Check destructive and delayed closure+ are equivalent.";
142 VectorFst<Arc> C1(T1);
143 Closure(&C1, CLOSURE_PLUS);
144 ClosureFst<Arc> C2(T1, CLOSURE_PLUS);
145 CHECK(Equiv(C1, C2));
146 }
147
148 {
149 VLOG(1) << "Check union is associative (destructive).";
150 VectorFst<Arc> U1(T1);
151 Union(&U1, T2);
152 Union(&U1, T3);
153
154 VectorFst<Arc> U3(T2);
155 Union(&U3, T3);
156 VectorFst<Arc> U4(T1);
157 Union(&U4, U3);
158
159 CHECK(Equiv(U1, U4));
160 }
161
162 {
163 VLOG(1) << "Check union is associative (delayed).";
164 UnionFst<Arc> U1(T1, T2);
165 UnionFst<Arc> U2(U1, T3);
166
167 UnionFst<Arc> U3(T2, T3);
168 UnionFst<Arc> U4(T1, U3);
169
170 CHECK(Equiv(U2, U4));
171 }
172
173
174 {
175 VLOG(1) << "Check union is associative (destructive delayed).";
176 UnionFst<Arc> U1(T1, T2);
177 Union(&U1, T3);
178
179 UnionFst<Arc> U3(T2, T3);
180 UnionFst<Arc> U4(T1, U3);
181
182 CHECK(Equiv(U1, U4));
183 }
184
185 {
186 VLOG(1) << "Check concatenation is associative (destructive).";
187 VectorFst<Arc> C1(T1);
188 Concat(&C1, T2);
189 Concat(&C1, T3);
190
191 VectorFst<Arc> C3(T2);
192 Concat(&C3, T3);
193 VectorFst<Arc> C4(T1);
194 Concat(&C4, C3);
195
196 CHECK(Equiv(C1, C4));
197 }
198
199 {
200 VLOG(1) << "Check concatenation is associative (delayed).";
201 ConcatFst<Arc> C1(T1, T2);
202 ConcatFst<Arc> C2(C1, T3);
203
204 ConcatFst<Arc> C3(T2, T3);
205 ConcatFst<Arc> C4(T1, C3);
206
207 CHECK(Equiv(C2, C4));
208 }
209
210 {
211 VLOG(1) << "Check concatenation is associative (destructive delayed).";
212 ConcatFst<Arc> C1(T1, T2);
213 Concat(&C1, T3);
214
215 ConcatFst<Arc> C3(T2, T3);
216 ConcatFst<Arc> C4(T1, C3);
217
218 CHECK(Equiv(C1, C4));
219 }
220
221 if (Weight::Properties() & kLeftSemiring) {
222 VLOG(1) << "Check concatenation left distributes"
223 << " over union (destructive).";
224
225 VectorFst<Arc> U1(T1);
226 Union(&U1, T2);
227 VectorFst<Arc> C1(T3);
228 Concat(&C1, U1);
229
230 VectorFst<Arc> C2(T3);
231 Concat(&C2, T1);
232 VectorFst<Arc> C3(T3);
233 Concat(&C3, T2);
234 VectorFst<Arc> U2(C2);
235 Union(&U2, C3);
236
237 CHECK(Equiv(C1, U2));
238 }
239
240 if (Weight::Properties() & kRightSemiring) {
241 VLOG(1) << "Check concatenation right distributes"
242 << " over union (destructive).";
243 VectorFst<Arc> U1(T1);
244 Union(&U1, T2);
245 VectorFst<Arc> C1(U1);
246 Concat(&C1, T3);
247
248 VectorFst<Arc> C2(T1);
249 Concat(&C2, T3);
250 VectorFst<Arc> C3(T2);
251 Concat(&C3, T3);
252 VectorFst<Arc> U2(C2);
253 Union(&U2, C3);
254
255 CHECK(Equiv(C1, U2));
256 }
257
258 if (Weight::Properties() & kLeftSemiring) {
259 VLOG(1) << "Check concatenation left distributes over union (delayed).";
260 UnionFst<Arc> U1(T1, T2);
261 ConcatFst<Arc> C1(T3, U1);
262
263 ConcatFst<Arc> C2(T3, T1);
264 ConcatFst<Arc> C3(T3, T2);
265 UnionFst<Arc> U2(C2, C3);
266
267 CHECK(Equiv(C1, U2));
268 }
269
270 if (Weight::Properties() & kRightSemiring) {
271 VLOG(1) << "Check concatenation right distributes over union (delayed).";
272 UnionFst<Arc> U1(T1, T2);
273 ConcatFst<Arc> C1(U1, T3);
274
275 ConcatFst<Arc> C2(T1, T3);
276 ConcatFst<Arc> C3(T2, T3);
277 UnionFst<Arc> U2(C2, C3);
278
279 CHECK(Equiv(C1, U2));
280 }
281
282
283 if (Weight::Properties() & kLeftSemiring) {
284 VLOG(1) << "Check T T* == T+ (destructive).";
285 VectorFst<Arc> S(T1);
286 Closure(&S, CLOSURE_STAR);
287 VectorFst<Arc> C(T1);
288 Concat(&C, S);
289
290 VectorFst<Arc> P(T1);
291 Closure(&P, CLOSURE_PLUS);
292
293 CHECK(Equiv(C, P));
294 }
295
296
297 if (Weight::Properties() & kRightSemiring) {
298 VLOG(1) << "Check T* T == T+ (destructive).";
299 VectorFst<Arc> S(T1);
300 Closure(&S, CLOSURE_STAR);
301 VectorFst<Arc> C(S);
302 Concat(&C, T1);
303
304 VectorFst<Arc> P(T1);
305 Closure(&P, CLOSURE_PLUS);
306
307 CHECK(Equiv(C, P));
308 }
309
310 if (Weight::Properties() & kLeftSemiring) {
311 VLOG(1) << "Check T T* == T+ (delayed).";
312 ClosureFst<Arc> S(T1, CLOSURE_STAR);
313 ConcatFst<Arc> C(T1, S);
314
315 ClosureFst<Arc> P(T1, CLOSURE_PLUS);
316
317 CHECK(Equiv(C, P));
318 }
319
320 if (Weight::Properties() & kRightSemiring) {
321 VLOG(1) << "Check T* T == T+ (delayed).";
322 ClosureFst<Arc> S(T1, CLOSURE_STAR);
323 ConcatFst<Arc> C(S, T1);
324
325 ClosureFst<Arc> P(T1, CLOSURE_PLUS);
326
327 CHECK(Equiv(C, P));
328 }
329 }
330
331 // Tests map-based operations.
TestMap(const Fst<Arc> & T)332 void TestMap(const Fst<Arc> &T) {
333
334 {
335 VLOG(1) << "Check destructive and delayed projection are equivalent.";
336 VectorFst<Arc> P1(T);
337 Project(&P1, PROJECT_INPUT);
338 ProjectFst<Arc> P2(T, PROJECT_INPUT);
339 CHECK(Equiv(P1, P2));
340 }
341
342
343 {
344 VLOG(1) << "Check destructive and delayed inversion are equivalent.";
345 VectorFst<Arc> I1(T);
346 Invert(&I1);
347 InvertFst<Arc> I2(T);
348 CHECK(Equiv(I1, I2));
349 }
350
351 {
352 VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (destructive).";
353 VectorFst<Arc> P1(T);
354 VectorFst<Arc> I1(T);
355 Project(&P1, PROJECT_INPUT);
356 Invert(&I1);
357 Project(&I1, PROJECT_OUTPUT);
358 CHECK(Equiv(P1, I1));
359 }
360
361 {
362 VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (destructive).";
363 VectorFst<Arc> P1(T);
364 VectorFst<Arc> I1(T);
365 Project(&P1, PROJECT_OUTPUT);
366 Invert(&I1);
367 Project(&I1, PROJECT_INPUT);
368 CHECK(Equiv(P1, I1));
369 }
370
371 {
372 VLOG(1) << "Check Pi_1(T) = Pi_2(T^-1) (delayed).";
373 ProjectFst<Arc> P1(T, PROJECT_INPUT);
374 InvertFst<Arc> I1(T);
375 ProjectFst<Arc> P2(I1, PROJECT_OUTPUT);
376 CHECK(Equiv(P1, P2));
377 }
378
379
380 {
381 VLOG(1) << "Check Pi_2(T) = Pi_1(T^-1) (delayed).";
382 ProjectFst<Arc> P1(T, PROJECT_OUTPUT);
383 InvertFst<Arc> I1(T);
384 ProjectFst<Arc> P2(I1, PROJECT_INPUT);
385 CHECK(Equiv(P1, P2));
386 }
387
388
389 {
390 VLOG(1) << "Check destructive relabeling";
391 static const int kNumLabels = 10;
392 // set up relabeling pairs
393 vector<Label> labelset(kNumLabels);
394 for (size_t i = 0; i < kNumLabels; ++i) labelset[i] = i;
395 for (size_t i = 0; i < kNumLabels; ++i) {
396 swap(labelset[i], labelset[rand() % kNumLabels]);
397 }
398
399 vector<pair<Label, Label> > ipairs1(kNumLabels);
400 vector<pair<Label, Label> > opairs1(kNumLabels);
401 for (size_t i = 0; i < kNumLabels; ++i) {
402 ipairs1[i] = make_pair(i, labelset[i]);
403 opairs1[i] = make_pair(labelset[i], i);
404 }
405 VectorFst<Arc> R(T);
406 Relabel(&R, ipairs1, opairs1);
407
408 vector<pair<Label, Label> > ipairs2(kNumLabels);
409 vector<pair<Label, Label> > opairs2(kNumLabels);
410 for (size_t i = 0; i < kNumLabels; ++i) {
411 ipairs2[i] = make_pair(labelset[i], i);
412 opairs2[i] = make_pair(i, labelset[i]);
413 }
414 Relabel(&R, ipairs2, opairs2);
415 CHECK(Equiv(R, T));
416
417 VLOG(1) << "Check on-the-fly relabeling";
418 RelabelFst<Arc> Rdelay(T, ipairs1, opairs1);
419
420 RelabelFst<Arc> RRdelay(Rdelay, ipairs2, opairs2);
421 CHECK(Equiv(RRdelay, T));
422 }
423
424 {
425 VLOG(1) << "Check encoding/decoding (destructive).";
426 VectorFst<Arc> D(T);
427 uint32 encode_props = 0;
428 if (rand() % 2)
429 encode_props |= kEncodeLabels;
430 if (rand() % 2)
431 encode_props |= kEncodeWeights;
432 EncodeMapper<Arc> encoder(encode_props, ENCODE);
433 Encode(&D, &encoder);
434 Decode(&D, encoder);
435 CHECK(Equiv(D, T));
436 }
437
438 {
439 VLOG(1) << "Check encoding/decoding (delayed).";
440 uint32 encode_props = 0;
441 if (rand() % 2)
442 encode_props |= kEncodeLabels;
443 if (rand() % 2)
444 encode_props |= kEncodeWeights;
445 EncodeMapper<Arc> encoder(encode_props, ENCODE);
446 EncodeFst<Arc> E(T, &encoder);
447 VectorFst<Arc> Encoded(E);
448 DecodeFst<Arc> D(Encoded, encoder);
449 CHECK(Equiv(D, T));
450 }
451
452 {
453 VLOG(1) << "Check gallic mappers (constructive).";
454 ToGallicMapper<Arc> to_mapper;
455 FromGallicMapper<Arc> from_mapper;
456 VectorFst< GallicArc<Arc> > G;
457 VectorFst<Arc> F;
458 ArcMap(T, &G, to_mapper);
459 ArcMap(G, &F, from_mapper);
460 CHECK(Equiv(T, F));
461 }
462
463 {
464 VLOG(1) << "Check gallic mappers (delayed).";
465 ToGallicMapper<Arc> to_mapper;
466 FromGallicMapper<Arc> from_mapper;
467 ArcMapFst<Arc, GallicArc<Arc>, ToGallicMapper<Arc> >
468 G(T, to_mapper);
469 ArcMapFst<GallicArc<Arc>, Arc, FromGallicMapper<Arc> >
470 F(G, from_mapper);
471 CHECK(Equiv(T, F));
472 }
473 }
474
475 // Tests compose-based operations.
TestCompose(const Fst<Arc> & T1,const Fst<Arc> & T2,const Fst<Arc> & T3)476 void TestCompose(const Fst<Arc> &T1, const Fst<Arc> &T2,
477 const Fst<Arc> &T3) {
478 if (!(Weight::Properties() & kCommutative))
479 return;
480
481 VectorFst<Arc> S1(T1);
482 VectorFst<Arc> S2(T2);
483 VectorFst<Arc> S3(T3);
484
485 ILabelCompare<Arc> icomp;
486 OLabelCompare<Arc> ocomp;
487
488 ArcSort(&S1, ocomp);
489 ArcSort(&S2, ocomp);
490 ArcSort(&S3, icomp);
491
492 {
493 VLOG(1) << "Check composition is associative.";
494 ComposeFst<Arc> C1(S1, S2);
495 ComposeFst<Arc> C2(C1, S3);
496 ComposeFst<Arc> C3(S2, S3);
497 ComposeFst<Arc> C4(S1, C3);
498
499 CHECK(Equiv(C2, C4));
500 }
501
502 {
503 VLOG(1) << "Check composition left distributes over union.";
504 UnionFst<Arc> U1(S2, S3);
505 ComposeFst<Arc> C1(S1, U1);
506
507 ComposeFst<Arc> C2(S1, S2);
508 ComposeFst<Arc> C3(S1, S3);
509 UnionFst<Arc> U2(C2, C3);
510
511 CHECK(Equiv(C1, U2));
512 }
513
514 {
515 VLOG(1) << "Check composition right distributes over union.";
516 UnionFst<Arc> U1(S1, S2);
517 ComposeFst<Arc> C1(U1, S3);
518
519 ComposeFst<Arc> C2(S1, S3);
520 ComposeFst<Arc> C3(S2, S3);
521 UnionFst<Arc> U2(C2, C3);
522
523 CHECK(Equiv(C1, U2));
524 }
525
526 VectorFst<Arc> A1(S1);
527 VectorFst<Arc> A2(S2);
528 VectorFst<Arc> A3(S3);
529 Project(&A1, PROJECT_OUTPUT);
530 Project(&A2, PROJECT_INPUT);
531 Project(&A3, PROJECT_INPUT);
532
533 {
534 VLOG(1) << "Check intersection is commutative.";
535 IntersectFst<Arc> I1(A1, A2);
536 IntersectFst<Arc> I2(A2, A1);
537 CHECK(Equiv(I1, I2));
538 }
539
540 {
541 VLOG(1) << "Check all epsilon filters leads to equivalent results.";
542 typedef Matcher< Fst<Arc> > M;
543 ComposeFst<Arc> C1(S1, S2);
544 ComposeFst<Arc> C2(
545 S1, S2,
546 ComposeFstOptions<Arc, M, AltSequenceComposeFilter<M> >());
547 ComposeFst<Arc> C3(
548 S1, S2,
549 ComposeFstOptions<Arc, M, MatchComposeFilter<M> >());
550
551 CHECK(Equiv(C1, C2));
552 CHECK(Equiv(C1, C3));
553 }
554
555 {
556 VLOG(1) << "Check look-ahead filters lead to equivalent results.";
557 VectorFst<Arc> C1, C2;
558 Compose(S1, S2, &C1);
559 LookAheadCompose(S1, S2, &C2);
560 CHECK(Equiv(C1, C2));
561 }
562 }
563
564 // Tests sorting operations
TestSort(const Fst<Arc> & T)565 void TestSort(const Fst<Arc> &T) {
566 ILabelCompare<Arc> icomp;
567 OLabelCompare<Arc> ocomp;
568
569 {
570 VLOG(1) << "Check arc sorted Fst is equivalent to its input.";
571 VectorFst<Arc> S1(T);
572 ArcSort(&S1, icomp);
573 CHECK(Equiv(T, S1));
574 }
575
576 {
577 VLOG(1) << "Check destructive and delayed arcsort are equivalent.";
578 VectorFst<Arc> S1(T);
579 ArcSort(&S1, icomp);
580 ArcSortFst< Arc, ILabelCompare<Arc> > S2(T, icomp);
581 CHECK(Equiv(S1, S2));
582 }
583
584 {
585 VLOG(1) << "Check ilabel sorting vs. olabel sorting with inversions.";
586 VectorFst<Arc> S1(T);
587 VectorFst<Arc> S2(T);
588 ArcSort(&S1, icomp);
589 Invert(&S2);
590 ArcSort(&S2, ocomp);
591 Invert(&S2);
592 CHECK(Equiv(S1, S2));
593 }
594
595 {
596 VLOG(1) << "Check topologically sorted Fst is equivalent to its input.";
597 VectorFst<Arc> S1(T);
598 TopSort(&S1);
599 CHECK(Equiv(T, S1));
600 }
601
602 {
603 VLOG(1) << "Check reverse(reverse(T)) = T";
604 for (int i = 0; i < 2; ++i) {
605 VectorFst< ReverseArc<Arc> > R1;
606 VectorFst<Arc> R2;
607 bool require_superinitial = i == 1;
608 Reverse(T, &R1, require_superinitial);
609 Reverse(R1, &R2, require_superinitial);
610 CHECK(Equiv(T, R2));
611 }
612 }
613 }
614
615 // Tests optimization operations
TestOptimize(const Fst<Arc> & T)616 void TestOptimize(const Fst<Arc> &T) {
617 uint64 tprops = T.Properties(kFstProperties, true);
618 uint64 wprops = Weight::Properties();
619
620 VectorFst<Arc> A(T);
621 Project(&A, PROJECT_INPUT);
622
623 {
624 VLOG(1) << "Check connected FST is equivalent to its input.";
625 VectorFst<Arc> C1(T);
626 Connect(&C1);
627 CHECK(Equiv(T, C1));
628 }
629
630 if ((wprops & kSemiring) == kSemiring &&
631 (tprops & kAcyclic || wprops & kIdempotent)) {
632 VLOG(1) << "Check epsilon-removed FST is equivalent to its input.";
633 VectorFst<Arc> R1(T);
634 RmEpsilon(&R1);
635 CHECK(Equiv(T, R1));
636
637 VLOG(1) << "Check destructive and delayed epsilon removal"
638 << "are equivalent.";
639 RmEpsilonFst<Arc> R2(T);
640 CHECK(Equiv(R1, R2));
641
642 VLOG(1) << "Check an FST with a large proportion"
643 << " of epsilon transitions:";
644 // Maps all transitions of T to epsilon-transitions and append
645 // a non-epsilon transition.
646 VectorFst<Arc> U;
647 ArcMap(T, &U, EpsMapper<Arc>());
648 VectorFst<Arc> V;
649 V.SetStart(V.AddState());
650 Arc arc(1, 1, Weight::One(), V.AddState());
651 V.AddArc(V.Start(), arc);
652 V.SetFinal(arc.nextstate, Weight::One());
653 Concat(&U, V);
654 // Check that epsilon-removal preserves the shortest-distance
655 // from the initial state to the final states.
656 vector<Weight> d;
657 ShortestDistance(U, &d, true);
658 Weight w = U.Start() < d.size() ? d[U.Start()] : Weight::Zero();
659 VectorFst<Arc> U1(U);
660 RmEpsilon(&U1);
661 ShortestDistance(U1, &d, true);
662 Weight w1 = U1.Start() < d.size() ? d[U1.Start()] : Weight::Zero();
663 CHECK(ApproxEqual(w, w1, kTestDelta));
664 RmEpsilonFst<Arc> U2(U);
665 ShortestDistance(U2, &d, true);
666 Weight w2 = U2.Start() < d.size() ? d[U2.Start()] : Weight::Zero();
667 CHECK(ApproxEqual(w, w2, kTestDelta));
668 }
669
670 if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
671 VLOG(1) << "Check determinized FSA is equivalent to its input.";
672 DeterminizeFst<Arc> D(A);
673 CHECK(Equiv(A, D));
674
675 if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
676 VLOG(1) << "Check pruning in determinization";
677 VectorFst<Arc> P;
678 Weight threshold = (*weight_generator_)();
679 DeterminizeOptions<Arc> opts;
680 opts.weight_threshold = threshold;
681 Determinize(A, &P, opts);
682 CHECK(P.Properties(kIDeterministic, true));
683 CHECK(PruneEquiv(A, P, threshold));
684 }
685
686 int n;
687 {
688 VLOG(1) << "Check size(min(det(A))) <= size(det(A))"
689 << " and min(det(A)) equiv det(A)";
690 VectorFst<Arc> M(D);
691 n = M.NumStates();
692 Minimize(&M);
693 CHECK(Equiv(D, M));
694 CHECK(M.NumStates() <= n);
695 n = M.NumStates();
696 }
697
698 if (n && (wprops & kIdempotent) == kIdempotent &&
699 A.Properties(kNoEpsilons, true)) {
700 VLOG(1) << "Check that Revuz's algorithm leads to the"
701 << " same number of states as Brozozowski's algorithm";
702
703 // Skip test if A is the empty machine or contains epsilons or
704 // if the semiring is not idempotent (to avoid floating point
705 // errors)
706 VectorFst<Arc> R;
707 Reverse(A, &R);
708 RmEpsilon(&R);
709 DeterminizeFst<Arc> DR(R);
710 VectorFst<Arc> RD;
711 Reverse(DR, &RD);
712 DeterminizeFst<Arc> DRD(RD);
713 VectorFst<Arc> M(DRD);
714 CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
715 // to the initial state
716 }
717 }
718
719 if ((wprops & kSemiring) == kSemiring && tprops & kAcyclic) {
720 VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
721 VectorFst<Arc> R(A), D;
722 RmEpsilon(&R);
723 Disambiguate(R, &D);
724 CHECK(Equiv(R, D));
725 VLOG(1) << "Check disambiguated FSA is unambiguous";
726 CHECK(Unambiguous(D));
727
728 if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
729 VLOG(1) << "Check pruning in disambiguation";
730 VectorFst<Arc> P;
731 Weight threshold = (*weight_generator_)();
732 DisambiguateOptions<Arc> opts;
733 opts.weight_threshold = threshold;
734 Disambiguate(R, &P, opts);
735 CHECK(PruneEquiv(A, P, threshold));
736 CHECK(Unambiguous(P));
737 }
738 }
739
740 if (Arc::Type() == LogArc::Type() || Arc::Type() == StdArc::Type()) {
741 VLOG(1) << "Check reweight(T) equiv T";
742 vector<Weight> potential;
743 VectorFst<Arc> RI(T);
744 VectorFst<Arc> RF(T);
745 while (potential.size() < RI.NumStates())
746 potential.push_back((*weight_generator_)());
747
748 Reweight(&RI, potential, REWEIGHT_TO_INITIAL);
749 CHECK(Equiv(T, RI));
750
751 Reweight(&RF, potential, REWEIGHT_TO_FINAL);
752 CHECK(Equiv(T, RF));
753 }
754
755 if ((wprops & kIdempotent) || (tprops & kAcyclic)) {
756 VLOG(1) << "Check pushed FST is equivalent to input FST.";
757 // Pushing towards the final state.
758 if (wprops & kRightSemiring) {
759 VectorFst<Arc> P1;
760 Push<Arc, REWEIGHT_TO_FINAL>(T, &P1, kPushLabels);
761 CHECK(Equiv(T, P1));
762
763 VectorFst<Arc> P2;
764 Push<Arc, REWEIGHT_TO_FINAL>(T, &P2, kPushWeights);
765 CHECK(Equiv(T, P2));
766
767 VectorFst<Arc> P3;
768 Push<Arc, REWEIGHT_TO_FINAL>(T, &P3, kPushLabels | kPushWeights);
769 CHECK(Equiv(T, P3));
770 }
771
772 // Pushing towards the initial state.
773 if (wprops & kLeftSemiring) {
774 VectorFst<Arc> P1;
775 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P1, kPushLabels);
776 CHECK(Equiv(T, P1));
777
778 VectorFst<Arc> P2;
779 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P2, kPushWeights);
780 CHECK(Equiv(T, P2));
781 VectorFst<Arc> P3;
782 Push<Arc, REWEIGHT_TO_INITIAL>(T, &P3, kPushLabels | kPushWeights);
783 CHECK(Equiv(T, P3));
784 }
785 }
786
787 if ((wprops & (kPath | kCommutative)) == (kPath | kCommutative)) {
788 VLOG(1) << "Check pruning algorithm";
789 {
790 VLOG(1) << "Check equiv. of constructive and destructive algorithms";
791 Weight thresold = (*weight_generator_)();
792 VectorFst<Arc> P1(T);
793 Prune(&P1, thresold);
794 VectorFst<Arc> P2;
795 Prune(T, &P2, thresold);
796 CHECK(Equiv(P1,P2));
797 }
798
799 {
800 VLOG(1) << "Check prune(reverse) equiv reverse(prune)";
801 Weight thresold = (*weight_generator_)();
802 VectorFst< ReverseArc<Arc> > R;
803 VectorFst<Arc> P1(T);
804 VectorFst<Arc> P2;
805 Prune(&P1, thresold);
806 Reverse(T, &R);
807 Prune(&R, thresold.Reverse());
808 Reverse(R, &P2);
809 CHECK(Equiv(P1, P2));
810 }
811 {
812 VLOG(1) << "Check: ShortestDistance(A - prune(A))"
813 << " > ShortestDistance(A) times Threshold";
814 Weight threshold = (*weight_generator_)();
815 VectorFst<Arc> P;
816 Prune(A, &P, threshold);
817 CHECK(PruneEquiv(A, P, threshold));
818 }
819 }
820 if (tprops & kAcyclic) {
821 VLOG(1) << "Check synchronize(T) equiv T";
822 SynchronizeFst<Arc> S(T);
823 CHECK(Equiv(T, S));
824 }
825 }
826
827 // Tests search operations
TestSearch(const Fst<Arc> & T)828 void TestSearch(const Fst<Arc> &T) {
829 uint64 wprops = Weight::Properties();
830
831 VectorFst<Arc> A(T);
832 Project(&A, PROJECT_INPUT);
833
834 if ((wprops & (kPath | kRightSemiring)) == (kPath | kRightSemiring)) {
835 VLOG(1) << "Check 1-best weight.";
836 VectorFst<Arc> path;
837 ShortestPath(T, &path);
838 Weight tsum = ShortestDistance(T);
839 Weight psum = ShortestDistance(path);
840 CHECK(ApproxEqual(tsum, psum, kTestDelta));
841 }
842
843 if ((wprops & (kPath | kSemiring)) == (kPath | kSemiring)) {
844 VLOG(1) << "Check n-best weights";
845 VectorFst<Arc> R(A);
846 RmEpsilon(&R);
847 int nshortest = rand() % kNumRandomShortestPaths + 2;
848 VectorFst<Arc> paths;
849 ShortestPath(R, &paths, nshortest, true, false,
850 Weight::Zero(), kNumShortestStates);
851 vector<Weight> distance;
852 ShortestDistance(paths, &distance, true);
853 StateId pstart = paths.Start();
854 if (pstart != kNoStateId) {
855 ArcIterator< Fst<Arc> > piter(paths, pstart);
856 for (; !piter.Done(); piter.Next()) {
857 StateId s = piter.Value().nextstate;
858 Weight nsum = s < distance.size() ?
859 Times(piter.Value().weight, distance[s]) : Weight::Zero();
860 VectorFst<Arc> path;
861 ShortestPath(R, &path);
862 Weight dsum = ShortestDistance(path);
863 CHECK(ApproxEqual(nsum, dsum, kTestDelta));
864 ArcMap(&path, RmWeightMapper<Arc>());
865 VectorFst<Arc> S;
866 Difference(R, path, &S);
867 R = S;
868 }
869 }
870 }
871 }
872
873 // Tests if two FSTS are equivalent by checking if random
874 // strings from one FST are transduced the same by both FSTs.
875 template <class A>
Equiv(const Fst<A> & fst1,const Fst<A> & fst2)876 bool Equiv(const Fst<A> &fst1, const Fst<A> &fst2) {
877 VLOG(1) << "Check FSTs for sanity (including property bits).";
878 CHECK(Verify(fst1));
879 CHECK(Verify(fst2));
880
881 UniformArcSelector<A> uniform_selector(seed_);
882 RandGenOptions< UniformArcSelector<A> >
883 opts(uniform_selector, kRandomPathLength);
884 return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts);
885 }
886
887 // Tests FSA is unambiguous
Unambiguous(const Fst<Arc> & fst)888 bool Unambiguous(const Fst<Arc> &fst) {
889 VectorFst<StdArc> sfst, dfst;
890 VectorFst<LogArc> lfst1, lfst2;
891 Map(fst, &sfst, RmWeightMapper<Arc, StdArc>());
892 Determinize(sfst, &dfst);
893 Map(fst, &lfst1, RmWeightMapper<Arc, LogArc>());
894 Map(dfst, &lfst2, RmWeightMapper<StdArc, LogArc>());
895 return Equiv(lfst1, lfst2);
896 }
897
898 // Tests ShortestDistance(A - P) >
899 // ShortestDistance(A) times Threshold.
900 template <class A>
PruneEquiv(const Fst<A> & fst,const Fst<A> & pfst,Weight threshold)901 bool PruneEquiv(const Fst<A> &fst, const Fst<A> &pfst,
902 Weight threshold) {
903 VLOG(1) << "Check FSTs for sanity (including property bits).";
904 CHECK(Verify(fst));
905 CHECK(Verify(pfst));
906
907 DifferenceFst<Arc> D(fst, DeterminizeFst<Arc>
908 (RmEpsilonFst<Arc>
909 (ArcMapFst<Arc, Arc,
910 RmWeightMapper<Arc> >
911 (pfst, RmWeightMapper<Arc>()))));
912 Weight sum1 = Times(ShortestDistance(fst), threshold);
913 Weight sum2 = ShortestDistance(D);
914 return Plus(sum1, sum2) == sum1;
915 }
916
917 // Random seed
918 int seed_;
919
920 // FST with no states
921 VectorFst<Arc> zero_fst_;
922
923 // FST with one state that accepts epsilon.
924 VectorFst<Arc> one_fst_;
925
926 // FST with one state that accepts all strings.
927 VectorFst<Arc> univ_fst_;
928
929 // Generates weights used in testing.
930 WeightGenerator *weight_generator_;
931
932 // Maximum random path length.
933 static const int kRandomPathLength;
934
935 // Number of random paths to explore.
936 static const int kNumRandomPaths;
937
938 // Maximum number of nshortest paths.
939 static const int kNumRandomShortestPaths;
940
941 // Maximum number of nshortest states.
942 static const int kNumShortestStates;
943
944 // Delta for equivalence tests.
945 static const float kTestDelta;
946
947 DISALLOW_COPY_AND_ASSIGN(WeightedTester);
948 };
949
950
951 template <class A, class WG>
952 const int WeightedTester<A, WG>::kRandomPathLength = 25;
953
954 template <class A, class WG>
955 const int WeightedTester<A, WG>::kNumRandomPaths = 100;
956
957 template <class A, class WG>
958 const int WeightedTester<A, WG>::kNumRandomShortestPaths = 100;
959
960 template <class A, class WG>
961 const int WeightedTester<A, WG>::kNumShortestStates = 10000;
962
963 template <class A, class WG>
964 const float WeightedTester<A, WG>::kTestDelta = .05;
965
966 // This class tests a variety of identities and properties that must
967 // hold for various algorithms on unweighted FSAs and that are not tested
968 // by WeightedTester. Only the specialization does anything interesting.
969 template <class Arc>
970 class UnweightedTester {
971 public:
UnweightedTester(const Fst<Arc> & zero_fsa,const Fst<Arc> & one_fsa,const Fst<Arc> & univ_fsa)972 UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
973 const Fst<Arc> &univ_fsa) {}
974
Test(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)975 void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {}
976 };
977
978
979 // Specialization for StdArc. This should work for any commutative,
980 // idempotent semiring when restricted to the unweighted case
981 // (being isomorphic to the boolean semiring).
982 template <>
983 class UnweightedTester<StdArc> {
984 public:
985 typedef StdArc Arc;
986 typedef Arc::Label Label;
987 typedef Arc::StateId StateId;
988 typedef Arc::Weight Weight;
989
UnweightedTester(const Fst<Arc> & zero_fsa,const Fst<Arc> & one_fsa,const Fst<Arc> & univ_fsa)990 UnweightedTester(const Fst<Arc> &zero_fsa, const Fst<Arc> &one_fsa,
991 const Fst<Arc> &univ_fsa)
992 : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {}
993
Test(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)994 void Test(const Fst<Arc> &A1, const Fst<Arc> &A2, const Fst<Arc> &A3) {
995 TestRational(A1, A2, A3);
996 TestIntersect(A1, A2, A3);
997 TestOptimize(A1);
998 }
999
1000 private:
1001 // Tests rational operations with identities
TestRational(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)1002 void TestRational(const Fst<Arc> &A1, const Fst<Arc> &A2,
1003 const Fst<Arc> &A3) {
1004
1005 {
1006 VLOG(1) << "Check the union contains its arguments (destructive).";
1007 VectorFst<Arc> U(A1);
1008 Union(&U, A2);
1009
1010 CHECK(Subset(A1, U));
1011 CHECK(Subset(A2, U));
1012 }
1013
1014 {
1015 VLOG(1) << "Check the union contains its arguments (delayed).";
1016 UnionFst<Arc> U(A1, A2);
1017
1018 CHECK(Subset(A1, U));
1019 CHECK(Subset(A2, U));
1020 }
1021
1022 {
1023 VLOG(1) << "Check if A^n c A* (destructive).";
1024 VectorFst<Arc> C(one_fsa_);
1025 int n = rand() % 5;
1026 for (int i = 0; i < n; ++i)
1027 Concat(&C, A1);
1028
1029 VectorFst<Arc> S(A1);
1030 Closure(&S, CLOSURE_STAR);
1031 CHECK(Subset(C, S));
1032 }
1033
1034 {
1035 VLOG(1) << "Check if A^n c A* (delayed).";
1036 int n = rand() % 5;
1037 Fst<Arc> *C = new VectorFst<Arc>(one_fsa_);
1038 for (int i = 0; i < n; ++i) {
1039 ConcatFst<Arc> *F = new ConcatFst<Arc>(*C, A1);
1040 delete C;
1041 C = F;
1042 }
1043 ClosureFst<Arc> S(A1, CLOSURE_STAR);
1044 CHECK(Subset(*C, S));
1045 delete C;
1046 }
1047 }
1048
1049 // Tests intersect-based operations.
TestIntersect(const Fst<Arc> & A1,const Fst<Arc> & A2,const Fst<Arc> & A3)1050 void TestIntersect(const Fst<Arc> &A1, const Fst<Arc> &A2,
1051 const Fst<Arc> &A3) {
1052 VectorFst<Arc> S1(A1);
1053 VectorFst<Arc> S2(A2);
1054 VectorFst<Arc> S3(A3);
1055
1056 ILabelCompare<Arc> comp;
1057
1058 ArcSort(&S1, comp);
1059 ArcSort(&S2, comp);
1060 ArcSort(&S3, comp);
1061
1062 {
1063 VLOG(1) << "Check the intersection is contained in its arguments.";
1064 IntersectFst<Arc> I1(S1, S2);
1065 CHECK(Subset(I1, S1));
1066 CHECK(Subset(I1, S2));
1067 }
1068
1069 {
1070 VLOG(1) << "Check union distributes over intersection.";
1071 IntersectFst<Arc> I1(S1, S2);
1072 UnionFst<Arc> U1(I1, S3);
1073
1074 UnionFst<Arc> U2(S1, S3);
1075 UnionFst<Arc> U3(S2, S3);
1076 ArcSortFst< Arc, ILabelCompare<Arc> > S4(U3, comp);
1077 IntersectFst<Arc> I2(U2, S4);
1078
1079 CHECK(Equiv(U1, I2));
1080 }
1081
1082 VectorFst<Arc> C1;
1083 VectorFst<Arc> C2;
1084 Complement(S1, &C1);
1085 Complement(S2, &C2);
1086 ArcSort(&C1, comp);
1087 ArcSort(&C2, comp);
1088
1089
1090 {
1091 VLOG(1) << "Check S U S' = Sigma*";
1092 UnionFst<Arc> U(S1, C1);
1093 CHECK(Equiv(U, univ_fsa_));
1094 }
1095
1096 {
1097 VLOG(1) << "Check S n S' = {}";
1098 IntersectFst<Arc> I(S1, C1);
1099 CHECK(Equiv(I, zero_fsa_));
1100 }
1101
1102 {
1103 VLOG(1) << "Check (S1' U S2') == (S1 n S2)'";
1104 UnionFst<Arc> U(C1, C2);
1105
1106 IntersectFst<Arc> I(S1, S2);
1107 VectorFst<Arc> C3;
1108 Complement(I, &C3);
1109 CHECK(Equiv(U, C3));
1110 }
1111
1112 {
1113 VLOG(1) << "Check (S1' n S2') == (S1 U S2)'";
1114 IntersectFst<Arc> I(C1, C2);
1115
1116 UnionFst<Arc> U(S1, S2);
1117 VectorFst<Arc> C3;
1118 Complement(U, &C3);
1119 CHECK(Equiv(I, C3));
1120 }
1121 }
1122
1123 // Tests optimization operations
TestOptimize(const Fst<Arc> & A)1124 void TestOptimize(const Fst<Arc> &A) {
1125 {
1126 VLOG(1) << "Check determinized FSA is equivalent to its input.";
1127 DeterminizeFst<Arc> D(A);
1128 CHECK(Equiv(A, D));
1129 }
1130
1131 {
1132 VLOG(1) << "Check disambiguated FSA is equivalent to its input.";
1133 VectorFst<Arc> R(A), D;
1134 RmEpsilon(&R);
1135
1136 Disambiguate(R, &D);
1137 CHECK(Equiv(R, D));
1138 }
1139
1140 {
1141 VLOG(1) << "Check minimized FSA is equivalent to its input.";
1142 int n;
1143 {
1144 RmEpsilonFst<Arc> R(A);
1145 DeterminizeFst<Arc> D(R);
1146 VectorFst<Arc> M(D);
1147 Minimize(&M);
1148 CHECK(Equiv(A, M));
1149 n = M.NumStates();
1150 }
1151
1152 if (n) { // Skip test if A is the empty machine
1153 VLOG(1) << "Check that Hopcroft's and Revuz's algorithms lead to the"
1154 << " same number of states as Brozozowski's algorithm";
1155 VectorFst<Arc> R;
1156 Reverse(A, &R);
1157 RmEpsilon(&R);
1158 DeterminizeFst<Arc> DR(R);
1159 VectorFst<Arc> RD;
1160 Reverse(DR, &RD);
1161 DeterminizeFst<Arc> DRD(RD);
1162 VectorFst<Arc> M(DRD);
1163 CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition
1164 // to the initial state
1165 }
1166 }
1167 }
1168
1169 // Tests if two FSAS are equivalent.
Equiv(const Fst<Arc> & fsa1,const Fst<Arc> & fsa2)1170 bool Equiv(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1171 VLOG(1) << "Check FSAs for sanity (including property bits).";
1172 CHECK(Verify(fsa1));
1173 CHECK(Verify(fsa2));
1174
1175 VectorFst<Arc> vfsa1(fsa1);
1176 VectorFst<Arc> vfsa2(fsa2);
1177 RmEpsilon(&vfsa1);
1178 RmEpsilon(&vfsa2);
1179 DeterminizeFst<Arc> dfa1(vfsa1);
1180 DeterminizeFst<Arc> dfa2(vfsa2);
1181
1182 // Test equivalence using union-find algorithm
1183 bool equiv1 = Equivalent(dfa1, dfa2);
1184
1185 // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty
1186 ILabelCompare<Arc> comp;
1187 VectorFst<Arc> sdfa1(dfa1);
1188 ArcSort(&sdfa1, comp);
1189 VectorFst<Arc> sdfa2(dfa2);
1190 ArcSort(&sdfa2, comp);
1191
1192 DifferenceFst<Arc> dfsa1(sdfa1, sdfa2);
1193 DifferenceFst<Arc> dfsa2(sdfa2, sdfa1);
1194
1195 VectorFst<Arc> ufsa(dfsa1);
1196 Union(&ufsa, dfsa2);
1197 Connect(&ufsa);
1198 bool equiv2 = ufsa.NumStates() == 0;
1199
1200 // Check two equivalence tests match
1201 CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2));
1202
1203 return equiv1;
1204 }
1205
1206 // Tests if FSA1 is a subset of FSA2 (disregarding weights).
Subset(const Fst<Arc> & fsa1,const Fst<Arc> & fsa2)1207 bool Subset(const Fst<Arc> &fsa1, const Fst<Arc> &fsa2) {
1208 VLOG(1) << "Check FSAs (incl. property bits) for sanity";
1209 CHECK(Verify(fsa1));
1210 CHECK(Verify(fsa2));
1211
1212 VectorFst<StdArc> vfsa1;
1213 VectorFst<StdArc> vfsa2;
1214 RmEpsilon(&vfsa1);
1215 RmEpsilon(&vfsa2);
1216 ILabelCompare<StdArc> comp;
1217 ArcSort(&vfsa1, comp);
1218 ArcSort(&vfsa2, comp);
1219 IntersectFst<StdArc> ifsa(vfsa1, vfsa2);
1220 DeterminizeFst<StdArc> dfa1(vfsa1);
1221 DeterminizeFst<StdArc> dfa2(ifsa);
1222 return Equivalent(dfa1, dfa2);
1223 }
1224
1225 // Returns complement Fsa
Complement(const Fst<Arc> & ifsa,MutableFst<Arc> * ofsa)1226 void Complement(const Fst<Arc> &ifsa, MutableFst<Arc> *ofsa) {
1227 RmEpsilonFst<Arc> rfsa(ifsa);
1228 DeterminizeFst<Arc> dfa(rfsa);
1229 DifferenceFst<Arc> cfsa(univ_fsa_, dfa);
1230 *ofsa = cfsa;
1231 }
1232
1233 // FSA with no states
1234 VectorFst<Arc> zero_fsa_;
1235
1236 // FSA with one state that accepts epsilon.
1237 VectorFst<Arc> one_fsa_;
1238
1239 // FSA with one state that accepts all strings.
1240 VectorFst<Arc> univ_fsa_;
1241
1242 DISALLOW_COPY_AND_ASSIGN(UnweightedTester);
1243 };
1244
1245
1246 // This class tests a variety of identities and properties that must
1247 // hold for various FST algorithms. It randomly generates FSTs, using
1248 // function object 'weight_generator' to select weights. 'WeightTester'
1249 // and 'UnweightedTester' are then called.
1250 template <class Arc, class WeightGenerator>
1251 class AlgoTester {
1252 public:
1253 typedef typename Arc::Label Label;
1254 typedef typename Arc::StateId StateId;
1255 typedef typename Arc::Weight Weight;
1256
AlgoTester(WeightGenerator generator,int seed)1257 AlgoTester(WeightGenerator generator, int seed) :
1258 weight_generator_(generator), seed_(seed) {
1259 one_fst_.AddState();
1260 one_fst_.SetStart(0);
1261 one_fst_.SetFinal(0, Weight::One());
1262
1263 univ_fst_.AddState();
1264 univ_fst_.SetStart(0);
1265 univ_fst_.SetFinal(0, Weight::One());
1266 for (int i = 0; i < kNumRandomLabels; ++i)
1267 univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0));
1268 }
1269
Test()1270 void Test() {
1271 VLOG(1) << "weight type = " << Weight::Type();
1272
1273 for (int i = 0; i < FLAGS_repeat; ++i) {
1274 // Random transducers
1275 VectorFst<Arc> T1;
1276 VectorFst<Arc> T2;
1277 VectorFst<Arc> T3;
1278 RandFst(&T1);
1279 RandFst(&T2);
1280 RandFst(&T3);
1281 WeightedTester<Arc, WeightGenerator>
1282 weighted_tester(seed_, zero_fst_, one_fst_,
1283 univ_fst_, &weight_generator_);
1284 weighted_tester.Test(T1, T2, T3);
1285
1286 VectorFst<Arc> A1(T1);
1287 VectorFst<Arc> A2(T2);
1288 VectorFst<Arc> A3(T3);
1289 Project(&A1, PROJECT_OUTPUT);
1290 Project(&A2, PROJECT_INPUT);
1291 Project(&A3, PROJECT_INPUT);
1292 ArcMap(&A1, rm_weight_mapper);
1293 ArcMap(&A2, rm_weight_mapper);
1294 ArcMap(&A3, rm_weight_mapper);
1295 UnweightedTester<Arc> unweighted_tester(zero_fst_, one_fst_, univ_fst_);
1296 unweighted_tester.Test(A1, A2, A3);
1297 }
1298 }
1299
1300 private:
1301 // Generates a random FST.
RandFst(MutableFst<Arc> * fst)1302 void RandFst(MutableFst<Arc> *fst) {
1303 // Determines direction of the arcs wrt state numbering. This way we
1304 // can force acyclicity when desired.
1305 enum ArcDirection { ANY_DIRECTION = 0, FORWARD_DIRECTION = 1,
1306 REVERSE_DIRECTION = 2, NUM_DIRECTIONS = 3 };
1307
1308 ArcDirection arc_direction = ANY_DIRECTION;
1309 if (rand()/(RAND_MAX + 1.0) < kAcyclicProb)
1310 arc_direction = rand() % 2 ? FORWARD_DIRECTION : REVERSE_DIRECTION;
1311
1312 fst->DeleteStates();
1313 StateId ns = rand() % kNumRandomStates;
1314
1315 if (ns == 0)
1316 return;
1317 for (StateId s = 0; s < ns; ++s)
1318 fst->AddState();
1319
1320 StateId start = rand() % ns;
1321 fst->SetStart(start);
1322
1323 size_t na = rand() % kNumRandomArcs;
1324 for (size_t n = 0; n < na; ++n) {
1325 StateId s = rand() % ns;
1326 Arc arc;
1327 arc.ilabel = rand() % kNumRandomLabels;
1328 arc.olabel = rand() % kNumRandomLabels;
1329 arc.weight = weight_generator_();
1330 arc.nextstate = rand() % ns;
1331
1332 if (arc_direction == ANY_DIRECTION ||
1333 (arc_direction == FORWARD_DIRECTION && arc.ilabel > arc.olabel) ||
1334 (arc_direction == REVERSE_DIRECTION && arc.ilabel < arc.olabel))
1335 fst->AddArc(s, arc);
1336 }
1337
1338 StateId nf = rand() % (ns + 1);
1339 for (StateId n = 0; n < nf; ++n) {
1340 StateId s = rand() % ns;
1341 Weight final = weight_generator_();
1342 fst->SetFinal(s, final);
1343 }
1344 VLOG(1) << "Check FST for sanity (including property bits).";
1345 CHECK(Verify(*fst));
1346
1347 // Get/compute all properties.
1348 uint64 props = fst->Properties(kFstProperties, true);
1349
1350 // Select random set of properties to be unknown.
1351 uint64 mask = 0;
1352 for (int n = 0; n < 8; ++n) {
1353 mask |= rand() & 0xff;
1354 mask <<= 8;
1355 }
1356 mask &= ~kTrinaryProperties;
1357 fst->SetProperties(props & ~mask, mask);
1358 }
1359
1360 // Generates weights used in testing.
1361 WeightGenerator weight_generator_;
1362
1363 // Random seed
1364 int seed_;
1365
1366 // FST with no states
1367 VectorFst<Arc> zero_fst_;
1368
1369 // FST with one state that accepts epsilon.
1370 VectorFst<Arc> one_fst_;
1371
1372 // FST with one state that accepts all strings.
1373 VectorFst<Arc> univ_fst_;
1374
1375 // Mapper to remove weights from an Fst
1376 RmWeightMapper<Arc> rm_weight_mapper;
1377
1378 // Maximum number of states in random test Fst.
1379 static const int kNumRandomStates;
1380
1381 // Maximum number of arcs in random test Fst.
1382 static const int kNumRandomArcs;
1383
1384 // Number of alternative random labels.
1385 static const int kNumRandomLabels;
1386
1387 // Probability to force an acyclic Fst
1388 static const float kAcyclicProb;
1389
1390 // Maximum random path length.
1391 static const int kRandomPathLength;
1392
1393 // Number of random paths to explore.
1394 static const int kNumRandomPaths;
1395
1396 DISALLOW_COPY_AND_ASSIGN(AlgoTester);
1397 };
1398
1399 template <class A, class G> const int AlgoTester<A, G>::kNumRandomStates = 10;
1400
1401 template <class A, class G> const int AlgoTester<A, G>::kNumRandomArcs = 25;
1402
1403 template <class A, class G> const int AlgoTester<A, G>::kNumRandomLabels = 5;
1404
1405 template <class A, class G> const float AlgoTester<A, G>::kAcyclicProb = .25;
1406
1407 template <class A, class G> const int AlgoTester<A, G>::kRandomPathLength = 25;
1408
1409 template <class A, class G> const int AlgoTester<A, G>::kNumRandomPaths = 100;
1410
1411 } // namespace fst
1412
1413 #endif // FST_TEST_ALGO_TEST_H__
1414