<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements.  See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership.  The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License.  You may obtain a copy of the License at -->

<!---   http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied.  See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->

# Connectionist Temporal Classification

[Connectionist Temporal Classification](https://www.cs.toronto.edu/~graves/icml_2006.pdf) (CTC) is a cost function used to train Recurrent Neural Networks (RNNs) to label unsegmented input sequences in supervised learning. For example, in a speech recognition application, a typical cross-entropy loss requires the input signal to be segmented into words or sub-words. With CTC loss, however, a single unaligned label sequence per input sequence is sufficient for the network to learn both the alignment and the labeling. Baidu's warp-ctc page contains a more detailed [introduction to CTC loss](https://github.com/baidu-research/warp-ctc#introduction).

## LSTM OCR Example
In this example, we use CTC loss to train a network on Optical Character Recognition (OCR) of CAPTCHA images. The example uses the `captcha` Python package to generate a random dataset for training. Training the network requires a CTC-loss layer, and MXNet provides two options for such a layer. The OCR example is constructed as follows:

1. 80x30 CAPTCHA images containing 3 to 4 random digits are generated using the `captcha` Python library.
2. Each image is used as a data sequence, with a sequence length of 80 and a vector length of 30.
3. The output layer uses CTC loss in training and softmax in inference.
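Step 2 above amounts to treating each of the 80 pixel columns of a CAPTCHA image as one time step of the input sequence. A minimal NumPy sketch, in which a random array stands in for a real generated image:

```python
import numpy as np

# A generated CAPTCHA is an 80x30 image; as an array it is (height, width).
image = np.random.rand(30, 80)

# Treat each pixel column as one time step: 80 steps, each a 30-dim vector.
sequence = image.T
print(sequence.shape)  # (80, 30)
```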

Note: When using CTC loss, one output label is reserved as the blank label. In this example, where the digits 0 through 9 are predicted, the softmax output therefore has 11 labels: label 0 is the blank, and labels 1 through 10 correspond to digits 0 through 9, respectively.
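This label convention, together with the standard greedy CTC decoding step (collapse repeated labels, then drop blanks), can be sketched as follows; `ctc_greedy_decode` is an illustrative name, not a function from this example:

```python
BLANK = 0  # label 0 is reserved for the CTC blank

def ctc_greedy_decode(labels):
    """Collapse repeated labels, drop blanks, and map labels back to digits."""
    decoded = []
    previous = BLANK
    for label in labels:
        # Keep a label only if it is non-blank and differs from its predecessor.
        if label != BLANK and label != previous:
            decoded.append(label - 1)  # label k (1..10) encodes digit k-1
        previous = label
    return decoded

print(ctc_greedy_decode([1, 1, 0, 2, 2, 0, 0, 3]))  # [0, 1, 2]
```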

### Description of the files
The LSTM-OCR example contains the following files:
* `captcha_generator.py`: Module for generating random 3- or 4-digit CAPTCHA images for training. It also contains a script for writing sample CAPTCHA images to an output file for inference testing.
* `ctc_metrics.py`: Module for calculating prediction accuracy during training. Two accuracy measures are implemented: a simple measure, the number of correct predictions divided by the total number of predictions, and a second measure based on the sum of the Longest Common Subsequence (LCS) ratios of all predictions divided by the total number of predictions.
* `hyperparams.py`: Contains all hyperparameters for the network structure and training.
* `lstm.py`: Contains the LSTM network implementations, with options for adding mxnet-ctc or warp-ctc loss for training, as well as softmax for inference.
* `lstm_ocr_infer.py`: Script for running inference after training.
* `lstm_ocr_train.py`: Script for training with ctc or warp-ctc loss.
* `multiproc_data.py`: A module for multiprocess data generation.
* `ocr_iter.py`: A DataIter module for iterating through the training data.
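The LCS-based accuracy measure described above can be sketched in plain Python as follows; the function names are illustrative and need not match `ctc_metrics.py` exactly:

```python
def lcs_length(a, b):
    """Length of the Longest Common Subsequence of two sequences (classic DP)."""
    table = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            if x == y:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[-1][-1]

def lcs_accuracy(predictions, labels):
    """Average LCS ratio (LCS length / label length) over all predictions."""
    ratios = [lcs_length(p, l) / len(l) for p, l in zip(predictions, labels)]
    return sum(ratios) / len(ratios)

print(lcs_accuracy([[1, 2, 3]], [[1, 2, 4]]))  # ~0.667
```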

## CTC-loss in MXNet
MXNet supports two CTC-loss layers in the Symbol API:

* `mxnet.symbol.contrib.ctc_loss` is implemented in MXNet and included as part of the standard package.
* `mxnet.symbol.WarpCTC` uses Baidu's warp-ctc library and requires building both the warp-ctc library and MXNet from source.

### Building MXNet with warp-ctc
In order to use the `mxnet.symbol.WarpCTC` layer, you first need to build Baidu's [warp-ctc](https://github.com/baidu-research/warp-ctc) library from source and then build MXNet from source with the warp-ctc config flags enabled.

#### Building warp-ctc
Follow the [instructions here](https://github.com/baidu-research/warp-ctc#compilation) to build warp-ctc from source. Once compiled, install the library by running the following command from the `warp-ctc/build` directory:
```
$ sudo make install
```

#### Building MXNet from source with warp-ctc integration
To build MXNet from source, follow the [instructions here](https://mxnet.apache.org/install/index.html). After choosing your system configuration, Python environment, and the "Build from Source" option, and before running `make` in step 4, enable warp-ctc integration by uncommenting the following lines in `make/config.mk` in the `incubator-mxnet` directory:
```
WARPCTC_PATH = $(HOME)/warp-ctc
MXNET_PLUGINS += plugin/warpctc/warpctc.mk
```

## Run LSTM OCR Example
Running this example requires the following prerequisites:
* The `captcha` and `opencv` Python packages are installed:
```
$ pip install captcha
$ pip install opencv-python
```
* You have access to one (or more) `ttf` font files. You can download a collection of font files from [Ubuntu's website](https://design.ubuntu.com/font/). The instructions in this section assume that a `./font/Ubuntu-M.ttf` file exists under the `example/ctc/` directory.

### Training
The training script demonstrates how to construct a network with both CTC-loss options and train it using the `mxnet.Module` API. Training is done by generating random CAPTCHA images using the font(s) provided. This example uses 80x30 CAPTCHA images that contain 3 to 4 digits each.

When training on a GPU, data generation becomes the bottleneck. To remedy this, the example generates data in multiple processes. The number of image-generation processes, as well as whether to train on CPU or GPU, can be configured using command line arguments.

To see the list of all arguments:
```
$ python lstm_ocr_train.py --help
```
From the command line, you can also select between the ctc and warp-ctc loss options. For example, the following command starts a training session on a single GPU with 4 CAPTCHA-generating processes, using ctc loss and the `font/Ubuntu-M.ttf` font file:
```
$ python lstm_ocr_train.py --gpu 1 --num_proc 4 --loss ctc font/Ubuntu-M.ttf
```

You can train with multiple fonts by instead specifying a folder that contains multiple `ttf` font files. Training saves a checkpoint after each epoch. The checkpoint prefix is 'ocr' by default, but it can be changed with the `--prefix` argument.

When testing this example, the following system configuration was used:
* p2.xlarge AWS EC2 instance (4 x CPU and 1 x K80 GPU)
* Deep Learning Amazon Machine Image (with mxnet 1.0.0)

This training example finishes after 100 epochs with ~87% accuracy. If you continue training further, the network achieves over 95% accuracy. Similar accuracy is achieved with both the ctc (`--loss ctc`) and warp-ctc (`--loss warpctc`) options. Logs of the last training epoch:

```
05:58:36,128 Epoch[99] Batch [50]	Speed: 1067.63 samples/sec	accuracy=0.877757
05:58:42,119 Epoch[99] Batch [100]	Speed: 1068.14 samples/sec	accuracy=0.859688
05:58:48,114 Epoch[99] Batch [150]	Speed: 1067.73 samples/sec	accuracy=0.870469
05:58:54,107 Epoch[99] Batch [200]	Speed: 1067.91 samples/sec	accuracy=0.864219
05:58:58,004 Epoch[99] Train-accuracy=0.877367
05:58:58,005 Epoch[99] Time cost=28.068
05:58:58,047 Saved checkpoint to "ocr-0100.params"
05:59:00,721 Epoch[99] Validation-accuracy=0.868886
```

### Inference
The inference script demonstrates how to load a network from a checkpoint, modify its final layer, and predict the label of a CAPTCHA image using the `mxnet.Module` API. You can choose the prefix as well as the epoch number of the checkpoint using command line arguments. To see the full list of arguments:
```
$ python lstm_ocr_infer.py --help
```
For example, to predict the label of the 'sample.jpg' file using the 'ocr' prefix and the checkpoint from epoch 100:
```
$ python lstm_ocr_infer.py --prefix ocr --epoch 100 sample.jpg

Digits: [0, 0, 8, 9]
```

Note: The above command expects the following files, generated by the training script, to exist in the current directory:
* ocr-symbol.json
* ocr-0100.params
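These names follow MXNet's `prefix-symbol.json` / `prefix-NNNN.params` checkpoint convention, which can be sketched as:

```python
def checkpoint_files(prefix, epoch):
    """File names produced when a checkpoint is saved for a given prefix/epoch."""
    symbol_file = "%s-symbol.json" % prefix            # network definition
    params_file = "%s-%04d.params" % (prefix, epoch)   # weights after that epoch
    return symbol_file, params_file

print(checkpoint_files("ocr", 100))  # ('ocr-symbol.json', 'ocr-0100.params')
```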

#### Generate CAPTCHA samples
CAPTCHA images can be generated using the `captcha_generator.py` script. To see the list of all arguments:
```
$ python captcha_generator.py --help
```
For example, to generate a CAPTCHA image with random digits from 'font/Ubuntu-M.ttf' and save it to the 'sample.jpg' file:
```
$ python captcha_generator.py font/Ubuntu-M.ttf sample.jpg
```