#!/usr/bin/env python3

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
CTC beam-search decoding with an optional character-bigram language model.

Adapted from
https://github.com/ThomasDelteil/HandwrittenTextRecognition_MXNet/blob/master/utils/CTCDecoder/BeamSearch.py
"""

from __future__ import division
from __future__ import print_function
import numpy as np


class BeamEntry:
    """
    Information about one single beam (candidate labeling) at a specific time-step.
    """

    def __init__(self):
        self.prTotal = 0  # probability of paths ending in a blank or a non-blank
        self.prNonBlank = 0  # probability of paths ending in a non-blank
        self.prBlank = 0  # probability of paths ending in a blank
        self.prText = 1  # LM score
        self.lmApplied = False  # flag if LM was already applied to this beam
        self.labeling = ()  # beam-labeling: tuple of class indices


class BeamState:
    """
    Information about all beams at a specific time-step, keyed by labeling.
    """

    def __init__(self):
        self.entries = {}

    def norm(self):
        """
        Length-normalise the LM score of every beam, in place.

        Each ``prText`` is replaced by its n-th root, where n is the length of
        the beam's labeling (1 for the empty labeling), i.e. the geometric
        mean of the per-character factors accumulated by ``applyLM``.
        """
        for entry in self.entries.values():
            labelingLen = len(entry.labeling)
            entry.prText = entry.prText ** (1.0 / max(labelingLen, 1))

    def sort(self):
        """
        Return the beam-labelings sorted by ``prTotal * prText``, best first.
        """
        sortedBeams = sorted(self.entries.values(), reverse=True,
                             key=lambda x: x.prTotal * x.prText)
        return [x.labeling for x in sortedBeams]


def applyLM(parentBeam, childBeam, classes, lm, lmFactor=0.01):
    """
    Compute the LM score of ``childBeam`` from ``parentBeam``.

    The child's score is the parent's score multiplied by the bigram
    probability of (last char of parent, last char of child) raised to
    ``lmFactor``. An empty parent labeling is treated as ending in a space,
    so ``classes`` must contain ' ' whenever ``lm`` is given. Does nothing
    when ``lm`` is falsy or the child already had the LM applied.

    :param parentBeam: BeamEntry the child was extended from.
    :param childBeam: BeamEntry whose labeling is the parent's labeling plus
        one char (as called from ``ctcBeamSearch``).
    :param classes: indexable collection mapping class index to character.
    :param lm: object exposing ``getCharBigram(c1, c2)``, or a falsy value
        to disable the language model.
    :param lmFactor: exponent controlling the influence of the language model.
    """
    if not lm or childBeam.lmApplied:
        return
    # first char: last char of the parent, or a space for the empty parent
    c1 = classes[parentBeam.labeling[-1] if parentBeam.labeling else classes.index(' ')]
    c2 = classes[childBeam.labeling[-1]]  # second char
    # probability of seeing first and second char next to each other
    bigramProb = lm.getCharBigram(c1, c2) ** lmFactor
    childBeam.prText = parentBeam.prText * bigramProb  # probability of char sequence
    childBeam.lmApplied = True  # only apply LM once per beam entry


def addBeam(beamState, labeling):
    """
    Add an empty BeamEntry for ``labeling`` to ``beamState`` if it does not yet exist.
    """
    if labeling not in beamState.entries:
        beamState.entries[labeling] = BeamEntry()


def ctcBeamSearch(mat, classes, lm, k, beamWidth, lmFactor=0.01):
    """
    Decode a CTC output matrix with beam search, as described by the papers
    of Hwang et al. and Graves et al.

    :param mat: 2-D matrix of shape (T, len(classes) + 1); row t holds the
        per-class scores for time-step t and the last column is the CTC
        blank. Scores are multiplied together, so presumably they are
        probabilities rather than log-probabilities -- confirm with callers.
    :param classes: indexable collection mapping class index to character;
        must contain ' ' when ``lm`` is given.
    :param lm: object exposing ``getCharBigram(c1, c2)``, or a falsy value to
        decode without a language model.
    :param k: number of decoded strings to return.
    :param beamWidth: number of beams kept after each time-step.
    :param lmFactor: exponent controlling the influence of the language model.
    :return: list of up to ``k`` decoded strings, most probable first.
    :raises ValueError: if ``mat`` is not 2-D or does not have exactly
        ``len(classes) + 1`` columns.
    """
    blankIdx = len(classes)
    maxT, maxC = mat.shape
    # The blank is column len(classes) and chars are columns 0..maxC-2; both
    # only agree when there is exactly one extra (blank) column.
    if maxC != blankIdx + 1:
        raise ValueError(
            'mat must have len(classes) + 1 = {} columns (last one is the blank), got {}'
            .format(blankIdx + 1, maxC))

    # initialise beam state with the empty labeling, which has only seen blanks
    last = BeamState()
    labeling = ()
    last.entries[labeling] = BeamEntry()
    last.entries[labeling].prBlank = 1
    last.entries[labeling].prTotal = 1

    # go over all time-steps
    for t in range(maxT):
        curr = BeamState()

        # only the beamWidth best beams of the previous step are extended
        bestLabelings = last.sort()[:beamWidth]

        for labeling in bestLabelings:
            lastEntry = last.entries[labeling]

            # probability of paths ending with a non-blank: only possible for
            # a non-empty beam, by repeating its last char
            prNonBlank = 0
            if labeling:
                try:
                    prNonBlank = lastEntry.prNonBlank * mat[t, labeling[-1]]
                except FloatingPointError:
                    # e.g. underflow when numpy's error handling is set to 'raise'
                    prNonBlank = 0

            # probability of paths ending with a blank
            prBlank = lastEntry.prTotal * mat[t, blankIdx]

            # keep the unchanged labeling in the current beam state
            addBeam(curr, labeling)
            currEntry = curr.entries[labeling]
            currEntry.labeling = labeling
            currEntry.prNonBlank += prNonBlank
            currEntry.prBlank += prBlank
            currEntry.prTotal += prBlank + prNonBlank
            # labeling unchanged, therefore LM score unchanged from previous step
            currEntry.prText = lastEntry.prText
            currEntry.lmApplied = True  # LM already applied for this beam-labeling

            # extend the current beam-labeling by every non-blank class
            for c in range(maxC - 1):
                newLabeling = labeling + (c,)

                # a repeated char must be separated by a blank, so only paths
                # ending with a blank can be extended by it
                if labeling and labeling[-1] == c:
                    prNonBlank = mat[t, c] * lastEntry.prBlank
                else:
                    prNonBlank = mat[t, c] * lastEntry.prTotal

                addBeam(curr, newLabeling)
                newEntry = curr.entries[newLabeling]
                newEntry.labeling = newLabeling
                newEntry.prNonBlank += prNonBlank
                newEntry.prTotal += prNonBlank

                applyLM(currEntry, newEntry, classes, lm, lmFactor)

        # set new beam state
        last = curr

    # normalise LM scores according to beam-labeling length
    last.norm()

    # map the k most probable labelings to strings
    return [''.join(classes[idx] for idx in bestLabeling)
            for bestLabeling in last.sort()[:k]]