1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17""" An example of using WarpCTC loss for an OCR problem using LSTM and CAPTCHA image data""" 18 19from __future__ import print_function 20 21import argparse 22import logging 23import os 24 25from captcha_generator import MPDigitCaptcha 26from hyperparams import Hyperparams 27from ctc_metrics import CtcMetrics 28import lstm 29import mxnet as mx 30from ocr_iter import OCRIter 31 32 33def get_fonts(path): 34 fonts = list() 35 if os.path.isdir(path): 36 for filename in os.listdir(path): 37 if filename.endswith('.ttf'): 38 fonts.append(os.path.join(path, filename)) 39 else: 40 fonts.append(path) 41 return fonts 42 43 44def parse_args(): 45 """Parse command line arguments""" 46 parser = argparse.ArgumentParser() 47 parser.add_argument("font_path", help="Path to ttf font file or directory containing ttf files") 48 parser.add_argument("--loss", help="'ctc' or 'warpctc' loss [Default 'ctc']", default='ctc') 49 parser.add_argument("--cpu", 50 help="Number of CPUs for training [Default 8]. Ignored if --gpu is specified.", 51 type=int, default=8) 52 parser.add_argument("--gpu", help="Number of GPUs for training [Default 0]", type=int) 53 parser.add_argument("--num_proc", help="Number CAPTCHA generating processes [Default 4]", type=int, default=4) 54 parser.add_argument("--prefix", help="Checkpoint prefix [Default 'ocr']", default='ocr') 55 return parser.parse_args() 56 57 58def main(): 59 """Program entry point""" 60 args = parse_args() 61 if not any(args.loss == s for s in ['ctc', 'warpctc']): 62 raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss)) 63 64 hp = Hyperparams() 65 66 # Start a multiprocessor captcha image generator 67 mp_captcha = MPDigitCaptcha( 68 font_paths=get_fonts(args.font_path), h=hp.seq_length, w=30, 69 num_digit_min=3, num_digit_max=4, num_processes=args.num_proc, max_queue_size=hp.batch_size * 2) 70 try: 71 # Must call start() before any call to mxnet module (https://github.com/apache/incubator-mxnet/issues/9213) 72 mp_captcha.start() 73 74 if args.gpu: 75 contexts = [mx.context.gpu(i) for i in range(args.gpu)] 76 else: 77 contexts = [mx.context.cpu(i) for i in range(args.cpu)] 78 79 init_states = lstm.init_states(hp.batch_size, hp.num_lstm_layer, hp.num_hidden) 80 81 data_train = OCRIter( 82 hp.train_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='train') 83 data_val = OCRIter( 84 hp.eval_epoch_size // hp.batch_size, hp.batch_size, init_states, captcha=mp_captcha, name='val') 85 86 symbol = lstm.lstm_unroll( 87 num_lstm_layer=hp.num_lstm_layer, 88 seq_len=hp.seq_length, 89 num_hidden=hp.num_hidden, 90 num_label=hp.num_label, 91 loss_type=args.loss) 92 93 head = '%(asctime)-15s %(message)s' 94 logging.basicConfig(level=logging.DEBUG, format=head) 95 96 module = mx.mod.Module( 97 symbol, 98 data_names=['data', 'l0_init_c', 'l0_init_h', 'l1_init_c', 'l1_init_h'], 99 label_names=['label'], 100 context=contexts) 101 102 metrics = CtcMetrics(hp.seq_length) 103 module.fit(train_data=data_train, 104 eval_data=data_val, 105 # use metrics.accuracy or metrics.accuracy_lcs 106 eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True), 107 optimizer='sgd', 108 optimizer_params={'learning_rate': hp.learning_rate, 109 'momentum': hp.momentum, 110 'wd': 0.00001, 111 }, 112 initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), 113 num_epoch=hp.num_epoch, 114 batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50), 115 epoch_end_callback=mx.callback.do_checkpoint(args.prefix), 116 ) 117 except KeyboardInterrupt: 118 print("W: interrupt received, stopping...") 119 finally: 120 # Reset multiprocessing captcha generator to stop processes 121 mp_captcha.reset() 122 123 124if __name__ == '__main__': 125 main() 126