# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=C0111, too-many-statements, too-many-locals
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
"""
architecture file for deep speech 2 model
"""
import json
import math
import argparse
import mxnet as mx

from stt_layer_batchnorm import batchnorm
from stt_layer_conv import conv
from stt_layer_fc import sequence_fc
from stt_layer_gru import bi_gru_unroll, gru_unroll
from stt_layer_lstm import bi_lstm_unroll
from stt_layer_slice import slice_symbol_to_seq_symobls
from stt_layer_warpctc import warpctc_layer


def prepare_data(args):
    """
    Build the initial-state descriptors for the configured RNN type.

    Reads ``arch.rnn_type``, ``arch.num_rnn_layer``,
    ``arch.num_hidden_rnn_list`` and ``common.batch_size`` from
    ``args.config`` and returns a list of ``(name, shape)`` pairs — one per
    recurrent state variable — suitable for use as ``init_states`` with a
    bucketing data iterator.

    :param args: namespace carrying a ConfigParser-like object at
        ``args.config``.
    :return: list of ``(state_name, (batch_size, num_hidden))`` tuples.
        LSTM variants contribute both cell (``init_c``) and hidden
        (``init_h``) states; GRU variants contribute only hidden states.
        Bidirectional variants get ``forward_``/``backward_`` prefixed names.
    :raises ValueError: if ``rnn_type`` is not one of
        lstm / bilstm / gru / bigru.
    """
    rnn_type = args.config.get("arch", "rnn_type")
    num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
    num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))

    batch_size = args.config.getint("common", "batch_size")

    if rnn_type == 'lstm':
        init_c = [('l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
                  for l in range(num_rnn_layer)]
        init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
                  for l in range(num_rnn_layer)]
    elif rnn_type == 'bilstm':
        forward_init_c = [('forward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
                          for l in range(num_rnn_layer)]
        backward_init_c = [('backward_l%d_init_c' % l, (batch_size, num_hidden_rnn_list[l]))
                           for l in range(num_rnn_layer)]
        init_c = forward_init_c + backward_init_c
        forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
                          for l in range(num_rnn_layer)]
        backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
                           for l in range(num_rnn_layer)]
        init_h = forward_init_h + backward_init_h
    elif rnn_type == 'gru':
        init_h = [('l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
                  for l in range(num_rnn_layer)]
    elif rnn_type == 'bigru':
        forward_init_h = [('forward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
                          for l in range(num_rnn_layer)]
        backward_init_h = [('backward_l%d_init_h' % l, (batch_size, num_hidden_rnn_list[l]))
                           for l in range(num_rnn_layer)]
        init_h = forward_init_h + backward_init_h
    else:
        # ValueError is a subclass of Exception, so existing broad handlers
        # still catch this.
        raise ValueError('network type should be one of the lstm,bilstm,gru,bigru')

    # LSTM variants carry a cell state in addition to the hidden state.
    if rnn_type in ('lstm', 'bilstm'):
        init_states = init_c + init_h
    else:  # 'gru' or 'bigru' — the only remaining possibilities
        init_states = init_h
    return init_states


def arch(args, seq_len=None):
    """
    Define the Deep Speech 2 network symbol.

    In ``train`` mode (or when bucketing is enabled) builds the full
    graph: two strided convolutions (optionally batch-normalized),
    time-slicing, the configured RNN stack, rear fully-connected layers and
    a WarpCTC output layer.  In ``load``/``predict`` mode only the
    post-convolution sequence length is computed.  In both cases the
    effective sequence length after the convolutions is written back to
    ``arch.max_t_count`` in ``args.config`` (side effect).

    :param args: argparse.Namespace carrying ``args.config``; anything else
        is silently ignored (the function then returns None).
    :param seq_len: input time-step count; defaults to ``arch.max_t_count``
        from the config when None.
    :return: an ``mx.sym.Symbol`` in train/bucketing mode, otherwise None.
    :raises ValueError: on an unsupported ``rnn_type`` or ``mode``.
    """
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        is_bucketing = args.config.getboolean("arch", "is_bucketing")
        if mode == "train" or is_bucketing:
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')

            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('label')

            # -4,-1,1,0,0: split batch dim, insert a singleton channel axis,
            # keep the remaining (time, frequency) dims unchanged.
            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,  # bias is redundant under BN
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            # (batch, channel, time, freq) -> (batch, time, channel, freq),
            # then fold channel and freq into one feature axis for the RNNs.
            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            # Standard "valid" convolution output-length formula
            # (no padding): floor((in - filter) / stride) + 1.  All operands
            # are non-negative ints, so // matches int(math.floor(a / b)).
            seq_len_after_conv_layer1 = \
                (seq_len - conv_layer1_filter_dim[0]) // conv_layer1_stride[0] + 1
            seq_len_after_conv_layer2 = \
                (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) \
                // conv_layer2_stride[0] + 1
            net = slice_symbol_to_seq_symobls(net=net,
                                              seq_len=seq_len_after_conv_layer2,
                                              axis=1)
            if rnn_type == "bilstm":
                net = bi_lstm_unroll(net=net,
                                     seq_len=seq_len_after_conv_layer2,
                                     num_hidden_lstm_list=num_hidden_rnn_list,
                                     num_lstm_layer=num_rnn_layer,
                                     dropout=0.,
                                     is_batchnorm=is_batchnorm,
                                     is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise ValueError('rnn_type should be one of the followings, bilstm,gru,bigru')

            # rear fc layers
            net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers, prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)
            # warpctc layer (+1 class for the CTC blank label)
            net = warpctc_layer(net=net,
                                seq_len=seq_len_after_conv_layer2,
                                label=label,
                                num_label=num_label,
                                character_classes_count=
                                (args.config.getint('arch', 'n_classes') + 1))
            # Record the effective post-convolution length for downstream use.
            args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            # Only the post-convolution sequence length is needed here;
            # the symbol itself is restored from a checkpoint.
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = \
                (seq_len - conv_layer1_filter_dim[0]) // conv_layer1_stride[0] + 1
            seq_len_after_conv_layer2 = \
                (seq_len_after_conv_layer1 - conv_layer2_filter_dim[0]) \
                // conv_layer2_stride[0] + 1

            args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        else:
            raise ValueError('mode must be the one of the followings - train,predict,load')


class BucketingArch(object):
    """Adapter exposing :func:`arch` as a bucketing ``sym_gen`` callable."""

    def __init__(self, args):
        # args: namespace with a ConfigParser-like ``config`` attribute,
        # forwarded verbatim to arch()/prepare_data().
        self.args = args

    def sym_gen(self, seq_len):
        """
        Build the network symbol for one bucket.

        :param seq_len: time-step count for this bucket.
        :return: ``(symbol, data_names, label_names)`` as expected by
            ``mx.mod.BucketingModule`` — data names are 'data' followed by
            the RNN initial-state names, label names are ``('label',)``.
        """
        args = self.args
        net = arch(args, seq_len)
        init_states = prepare_data(args)
        init_state_names = [x[0] for x in init_states]
        init_state_names.insert(0, 'data')
        return net, init_state_names, ('label',)

    def get_sym_gen(self):
        """Return the bound ``sym_gen`` method."""
        return self.sym_gen