1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17""" An example of using WarpCTC loss for an OCR problem using LSTM and CAPTCHA image data"""
18
19from __future__ import print_function
20
21import argparse
22import logging
23import os
24
25from captcha_generator import MPDigitCaptcha
26from hyperparams import Hyperparams
27from ctc_metrics import CtcMetrics
28import lstm
29import mxnet as mx
30from ocr_iter import OCRIter
31
32
def get_fonts(path):
    """Collect TrueType font file paths from a directory or a single path.

    Parameters
    ----------
    path : str
        A directory to scan for ``.ttf`` files, or any other path, which is
        used as-is.

    Returns
    -------
    list of str
        If `path` is a directory: `path` joined with each entry whose name
        ends with ``.ttf`` (compared case-insensitively), ordered by sorted
        entry name. The list is empty when no entry matches.
        Otherwise: a one-element list holding `path` itself; the path is not
        checked for existence or extension.
    """
    if not os.path.isdir(path):
        return [path]
    # os.listdir() yields entries in arbitrary order, so sort to keep the font
    # list reproducible between runs. Lower-casing the name lets upper- or
    # mixed-case extensions such as "FONT.TTF" match too.
    return [os.path.join(path, name)
            for name in sorted(os.listdir(path))
            if name.lower().endswith('.ttf')]
42
43
def parse_args(argv=None):
    """Parse command line arguments

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When None (the default), argparse falls
        back to the process command line (``sys.argv[1:]``).

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``font_path``, ``loss``, ``cpu``,
        ``gpu``, ``num_proc`` and ``prefix``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("font_path", help="Path to ttf font file or directory containing ttf files")
    parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc')
    parser.add_argument("--cpu",
                        help="Number of CPUs for training [Default 8]. Ignored if --gpu is specified.",
                        type=int, default=8)
    # Explicit default so the parsed value matches the documented "[Default 0]"
    # rather than None; both are falsy, so truthiness checks behave the same.
    parser.add_argument("--gpu", help="Number of GPUs for training [Default 0]", type=int, default=0)
    parser.add_argument("--num_proc", help="Number CAPTCHA generating processes [Default 4]", type=int, default=4)
    parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr')
    return parser.parse_args(argv)
56
57
def main():
    """Program entry point

    Parses the command line, starts the multi-process CAPTCHA generator,
    builds the unrolled LSTM symbol with the requested loss and fits an MXNet
    module on the generated data, using ``args.prefix`` for the epoch-end
    checkpoint callback.

    A KeyboardInterrupt during setup or training is caught and reported on
    stdout. The CAPTCHA generator is reset on every exit path.

    Raises
    ------
    ValueError
        If ``--loss`` is neither 'ctc' nor 'warpctc'.
    """
    args = parse_args()
    if args.loss not in ('ctc', 'warpctc'):
        raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss))

    hp = Hyperparams()

    # Multi-process CAPTCHA image generator; it feeds both data iterators below
    font_files = get_fonts(args.font_path)
    captcha_source = MPDigitCaptcha(
        font_paths=font_files,
        h=hp.seq_length,
        w=30,
        num_digit_min=3,
        num_digit_max=4,
        num_processes=args.num_proc,
        max_queue_size=hp.batch_size * 2)
    try:
        # start() has to run before the first call into the mxnet module
        # (https://github.com/apache/incubator-mxnet/issues/9213)
        captcha_source.start()

        # --gpu wins over --cpu whenever it is a non-zero count
        if args.gpu:
            make_device, device_count = mx.context.gpu, args.gpu
        else:
            make_device, device_count = mx.context.cpu, args.cpu
        devices = [make_device(idx) for idx in range(device_count)]

        states = lstm.init_states(hp.batch_size, hp.num_lstm_layer, hp.num_hidden)

        train_batches = hp.train_epoch_size // hp.batch_size
        train_iter = OCRIter(train_batches, hp.batch_size, states, captcha=captcha_source, name='train')
        val_batches = hp.eval_epoch_size // hp.batch_size
        val_iter = OCRIter(val_batches, hp.batch_size, states, captcha=captcha_source, name='val')

        network = lstm.lstm_unroll(
            num_lstm_layer=hp.num_lstm_layer,
            seq_len=hp.seq_length,
            num_hidden=hp.num_hidden,
            num_label=hp.num_label,
            loss_type=args.loss)

        logging.basicConfig(level=logging.DEBUG, format='%(asctime)-15s %(message)s')

        # NOTE(review): the state names below are hard-coded for layers 0 and 1;
        # presumably they must agree with hp.num_lstm_layer -- confirm against lstm.init_states.
        input_names = ['data', 'l0_init_c', 'l0_init_h', 'l1_init_c', 'l1_init_h']
        module = mx.mod.Module(
            network,
            data_names=input_names,
            label_names=['label'],
            context=devices)

        metrics = CtcMetrics(hp.seq_length)
        # use metrics.accuracy or metrics.accuracy_lcs
        accuracy_metric = mx.metric.np(metrics.accuracy, allow_extra_outputs=True)
        sgd_params = {
            'learning_rate': hp.learning_rate,
            'momentum': hp.momentum,
            'wd': 0.00001,
        }
        weight_init = mx.init.Xavier(factor_type="in", magnitude=2.34)
        epochs = hp.num_epoch
        on_batch_end = mx.callback.Speedometer(hp.batch_size, 50)
        on_epoch_end = mx.callback.do_checkpoint(args.prefix)

        module.fit(
            train_data=train_iter,
            eval_data=val_iter,
            eval_metric=accuracy_metric,
            optimizer='sgd',
            optimizer_params=sgd_params,
            initializer=weight_init,
            num_epoch=epochs,
            batch_end_callback=on_batch_end,
            epoch_end_callback=on_epoch_end)
    except KeyboardInterrupt:
        print("W: interrupt received, stopping...")
    finally:
        # Resetting the generator stops its worker processes
        captcha_source.reset()
122
123
# Run the training entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()
126