1#!/usr/bin/env python3
2
3# Licensed to the Apache Software Foundation (ASF) under one
4# or more contributor license agreements.  See the NOTICE file
5# distributed with this work for additional information
6# regarding copyright ownership.  The ASF licenses this file
7# to you under the Apache License, Version 2.0 (the
8# "License"); you may not use this file except in compliance
9# with the License.  You may obtain a copy of the License at
10#
11#   http://www.apache.org/licenses/LICENSE-2.0
12#
13# Unless required by applicable law or agreed to in writing,
14# software distributed under the License is distributed on an
15# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16# KIND, either express or implied.  See the License for the
17# specific language governing permissions and limitations
18# under the License.
19
20"""
21Module : this module to decode using beam search
22https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py
23"""
24
25from __future__ import division
26from __future__ import print_function
27import numpy as np
28
class BeamEntry:
    """State of one beam (labeling hypothesis) at a specific time-step."""

    def __init__(self):
        # total probability over all paths: blank-ending plus non-blank-ending
        self.prTotal = 0
        # probability of paths ending with a non-blank label
        self.prNonBlank = 0
        # probability of paths ending with the blank label
        self.prBlank = 0
        # language-model score for this labeling
        self.prText = 1
        # True once the LM has been folded into prText (applied at most once)
        self.lmApplied = False
        # the label-index sequence this beam represents
        self.labeling = ()
40
class BeamState:
    """Collection of all beams (labeling hypotheses) at a specific time-step."""

    def __init__(self):
        # maps beam-labeling tuple -> BeamEntry
        self.entries = {}

    def norm(self):
        """
        Length-normalise the LM score of every beam.

        Raises each entry's prText to the power 1/len(labeling) so longer
        labelings are not penalised by accumulating more bigram factors;
        an empty labeling uses exponent 1 and is left unchanged.
        """
        # iterate values directly instead of items()+repeated key lookups
        for entry in self.entries.values():
            labelingLen = len(entry.labeling)
            entry.prText = entry.prText ** (1.0 / (labelingLen if labelingLen else 1.0))

    def sort(self):
        """
        Return all beam-labelings sorted by descending probability
        (total path probability weighted by the LM score).
        """
        sortedBeams = sorted(self.entries.values(), reverse=True,
                             key=lambda x: x.prTotal * x.prText)
        return [x.labeling for x in sortedBeams]
63
def applyLM(parentBeam, childBeam, classes, lm):
    """
    Fold the language model into the child beam's score.

    Multiplies the parent's LM score by the (dampened) bigram probability of
    the last two characters of the child labeling. Does nothing when no LM is
    given or the child beam was already scored (applied at most once per beam).
    """
    if not lm or childBeam.lmApplied:
        return

    # char preceding the newly added one; an empty parent falls back to ' '
    if parentBeam.labeling:
        firstChar = classes[parentBeam.labeling[-1]]
    else:
        firstChar = classes[classes.index(' ')]
    # the newly appended char
    secondChar = classes[childBeam.labeling[-1]]

    lmFactor = 0.01  # dampens the influence of the language model
    bigramProb = lm.getCharBigram(firstChar, secondChar) ** lmFactor
    childBeam.prText = parentBeam.prText * bigramProb  # extended sequence score
    childBeam.lmApplied = True  # mark so the LM is applied only once
75
def addBeam(beamState, labeling):
    """
    Ensure an entry for `labeling` exists in the beam state, creating a
    fresh BeamEntry when it is missing; existing entries are untouched.
    """
    if labeling in beamState.entries:
        return
    beamState.entries[labeling] = BeamEntry()
82
def ctcBeamSearch(mat, classes, lm, k, beamWidth):
    """
    CTC beam-search decoding, as described by Hwang et al. and Graves et al.

    mat       : 2-D probability matrix of shape (T, C); column C-1 is the
                CTC blank, the first C-1 columns correspond to `classes`.
    classes   : characters for the non-blank label indices.
    lm        : optional language model exposing getCharBigram(c1, c2).
    k         : number of decoded strings to return.
    beamWidth : number of beams kept per time-step.

    Returns a list of up to k decoded strings, best first.
    """
    blankIdx = len(classes)
    maxT, maxC = mat.shape

    # start with a single empty beam: probability 1, all-blank so far
    last = BeamState()
    emptyLabeling = ()
    last.entries[emptyLabeling] = BeamEntry()
    last.entries[emptyLabeling].prBlank = 1
    last.entries[emptyLabeling].prTotal = 1

    for t in range(maxT):
        curr = BeamState()

        # keep only the most probable beams from the previous time-step
        for labeling in last.sort()[:beamWidth]:

            # --- case 1: carry the beam over unchanged -------------------

            # paths that repeat the last non-blank char at this step
            prNonBlank = 0
            if labeling:
                try:
                    prNonBlank = last.entries[labeling].prNonBlank * mat[t, labeling[-1]]
                except FloatingPointError:
                    prNonBlank = 0

            # paths that emit a blank at this step
            prBlank = last.entries[labeling].prTotal * mat[t, blankIdx]

            addBeam(curr, labeling)
            entry = curr.entries[labeling]
            entry.labeling = labeling
            entry.prNonBlank += prNonBlank
            entry.prBlank += prBlank
            entry.prTotal += prBlank + prNonBlank
            # labeling unchanged, so the LM score carries over as-is
            entry.prText = last.entries[labeling].prText
            # LM was already applied to this labeling at an earlier step
            entry.lmApplied = True

            # --- case 2: extend the beam by every non-blank char ---------
            for c in range(maxC - 1):
                newLabeling = labeling + (c,)

                # a repeated char only survives via paths ending in a blank
                if labeling and labeling[-1] == c:
                    prNonBlank = mat[t, c] * last.entries[labeling].prBlank
                else:
                    prNonBlank = mat[t, c] * last.entries[labeling].prTotal

                addBeam(curr, newLabeling)
                newEntry = curr.entries[newLabeling]
                newEntry.labeling = newLabeling
                newEntry.prNonBlank += prNonBlank
                newEntry.prTotal += prNonBlank

                # score the extended labeling with the language model
                applyLM(curr.entries[labeling], newEntry, classes, lm)

        # current beams become the starting point of the next time-step
        last = curr

    # length-normalise LM scores, then keep the k most probable labelings
    last.norm()
    bestLabelings = last.sort()[:k]

    # map label indices back to characters
    return [''.join(classes[label] for label in labeling) for labeling in bestLabelings]