1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# pylint: disable=C0111, too-many-statements, too-many-locals
# pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
20# pylint: disable=superfluous-parens, no-member, invalid-name
21"""
22architecture file for deep speech 2 model
23"""
24import json
25import math
26import argparse
27import mxnet as mx
28
29from stt_layer_batchnorm import batchnorm
30from stt_layer_conv import conv
31from stt_layer_fc import sequence_fc
32from stt_layer_gru import bi_gru_unroll, gru_unroll
33from stt_layer_lstm import bi_lstm_unroll
34from stt_layer_slice import slice_symbol_to_seq_symobls
35from stt_layer_warpctc import warpctc_layer
36
37
def prepare_data(args):
    """
    Build the list of initial-state (name, shape) pairs for the configured RNN.

    Reads ``rnn_type``, ``num_rnn_layer`` and ``num_hidden_rnn_list`` from the
    ``arch`` section and ``batch_size`` from the ``common`` section of
    ``args.config``.  LSTM variants get both cell (``init_c``) and hidden
    (``init_h``) states; GRU variants get hidden states only.  Bidirectional
    variants ('bilstm', 'bigru') duplicate each state with ``forward_`` and
    ``backward_`` name prefixes.

    :param args: namespace carrying a ConfigParser-like object as ``args.config``
    :return: list of ``(state_name, (batch_size, num_hidden))`` tuples
    :raises Exception: if ``rnn_type`` is not one of lstm/bilstm/gru/bigru
    """
    rnn_type = args.config.get("arch", "rnn_type")
    num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
    num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))

    batch_size = args.config.getint("common", "batch_size")

    def make_states(state, prefix=''):
        # One (name, shape) pair per layer, e.g. ('forward_l0_init_h', (B, H0)).
        return [('%sl%d_init_%s' % (prefix, l, state),
                 (batch_size, num_hidden_rnn_list[l]))
                for l in range(num_rnn_layer)]

    if rnn_type == 'lstm':
        init_states = make_states('c') + make_states('h')
    elif rnn_type == 'bilstm':
        # All cell states first (forward then backward), then all hidden states.
        init_states = (make_states('c', 'forward_') + make_states('c', 'backward_') +
                       make_states('h', 'forward_') + make_states('h', 'backward_'))
    elif rnn_type == 'gru':
        init_states = make_states('h')
    elif rnn_type == 'bigru':
        init_states = make_states('h', 'forward_') + make_states('h', 'backward_')
    else:
        raise Exception('network type should be one of the lstm,bilstm,gru,bigru')
    return init_states
81
82
def _conv_output_len(input_len, filter_dim, stride):
    """
    Length of the time axis after a 'valid' convolution.

    ``filter_dim`` and ``stride`` are (time, frequency) tuples; only the time
    components (index 0) matter for the sequence length.
    """
    return int(math.floor((input_len - filter_dim[0]) / stride[0])) + 1


def arch(args, seq_len=None):
    """
    Define the Deep Speech 2 network.

    In 'train' mode (or whenever bucketing is enabled) builds and returns the
    full symbol: two convolution layers, a stack of (bi)directional RNN layers,
    optional rear fully-connected layers and a warp-CTC loss.  In 'load' or
    'predict' mode it only recomputes the post-convolution sequence length and
    returns None.

    :param args: argparse.Namespace carrying a ConfigParser-like ``args.config``
    :param seq_len: input sequence length; defaults to the configured
        'arch/max_t_count' when None
    :return: mx.sym.Symbol in train/bucketing mode, otherwise None
    :raises Exception: on an unknown ``rnn_type`` or ``mode``

    Side effect: writes the post-convolution sequence length back into the
    config as 'arch/max_t_count' in every valid mode.
    """
    if isinstance(args, argparse.Namespace):
        mode = args.config.get("common", "mode")
        is_bucketing = args.config.getboolean("arch", "is_bucketing")
        if mode == "train" or is_bucketing:
            channel_num = args.config.getint("arch", "channel_num")
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))

            rnn_type = args.config.get("arch", "rnn_type")
            num_rnn_layer = args.config.getint("arch", "num_rnn_layer")
            num_hidden_rnn_list = json.loads(args.config.get("arch", "num_hidden_rnn_list"))

            is_batchnorm = args.config.getboolean("arch", "is_batchnorm")

            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')

            num_label = args.config.getint('arch', 'max_label_length')

            num_rear_fc_layers = args.config.getint("arch", "num_rear_fc_layers")
            num_hidden_rear_fc_list = json.loads(args.config.get("arch", "num_hidden_rear_fc_list"))
            act_type_rear_fc_list = json.loads(args.config.get("arch", "act_type_rear_fc_list"))
            # model symbol generation
            # input preparation
            data = mx.sym.Variable('data')
            label = mx.sym.Variable('label')

            # Add a singleton channel axis: (batch, t, f) -> (batch, 1, t, f).
            net = mx.sym.Reshape(data=data, shape=(-4, -1, 1, 0, 0))
            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer1_filter_dim,
                       stride=conv_layer1_stride,
                       no_bias=is_batchnorm,
                       name='conv1')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv1_batchnorm")

            net = conv(net=net,
                       channels=channel_num,
                       filter_dimension=conv_layer2_filter_dim,
                       stride=conv_layer2_stride,
                       no_bias=is_batchnorm,
                       name='conv2')
            if is_batchnorm:
                # batch norm normalizes axis 1
                net = batchnorm(net, name="conv2_batchnorm")

            # Move time in front of channels, then merge (channel, freq) into
            # one feature axis: (batch, c, t, f) -> (batch, t, c*f).
            net = mx.sym.transpose(data=net, axes=(0, 2, 1, 3))
            net = mx.sym.Reshape(data=net, shape=(0, 0, -3))
            seq_len_after_conv_layer1 = _conv_output_len(seq_len,
                                                         conv_layer1_filter_dim,
                                                         conv_layer1_stride)
            seq_len_after_conv_layer2 = _conv_output_len(seq_len_after_conv_layer1,
                                                         conv_layer2_filter_dim,
                                                         conv_layer2_stride)
            net = slice_symbol_to_seq_symobls(net=net, seq_len=seq_len_after_conv_layer2, axis=1)
            if rnn_type == "bilstm":
                net = bi_lstm_unroll(net=net,
                                     seq_len=seq_len_after_conv_layer2,
                                     num_hidden_lstm_list=num_hidden_rnn_list,
                                     num_lstm_layer=num_rnn_layer,
                                     dropout=0.,
                                     is_batchnorm=is_batchnorm,
                                     is_bucketing=is_bucketing)
            elif rnn_type == "gru":
                net = gru_unroll(net=net,
                                 seq_len=seq_len_after_conv_layer2,
                                 num_hidden_gru_list=num_hidden_rnn_list,
                                 num_gru_layer=num_rnn_layer,
                                 dropout=0.,
                                 is_batchnorm=is_batchnorm,
                                 is_bucketing=is_bucketing)
            elif rnn_type == "bigru":
                net = bi_gru_unroll(net=net,
                                    seq_len=seq_len_after_conv_layer2,
                                    num_hidden_gru_list=num_hidden_rnn_list,
                                    num_gru_layer=num_rnn_layer,
                                    dropout=0.,
                                    is_batchnorm=is_batchnorm,
                                    is_bucketing=is_bucketing)
            else:
                raise Exception('rnn_type should be one of the followings, bilstm,gru,bigru')

            # rear fc layers
            net = sequence_fc(net=net, seq_len=seq_len_after_conv_layer2,
                              num_layer=num_rear_fc_layers, prefix="rear",
                              num_hidden_list=num_hidden_rear_fc_list,
                              act_type_list=act_type_rear_fc_list,
                              is_batchnorm=is_batchnorm)
            # warpctc layer
            net = warpctc_layer(net=net,
                                seq_len=seq_len_after_conv_layer2,
                                label=label,
                                num_label=num_label,
                                character_classes_count=
                                (args.config.getint('arch', 'n_classes') + 1))
            args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
            return net
        elif mode == 'load' or mode == 'predict':
            # No symbol is built here; only the effective sequence length after
            # the two convolutions is recomputed and stored in the config.
            conv_layer1_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer1_filter_dim")))
            conv_layer1_stride = tuple(json.loads(args.config.get("arch", "conv_layer1_stride")))
            conv_layer2_filter_dim = \
                tuple(json.loads(args.config.get("arch", "conv_layer2_filter_dim")))
            conv_layer2_stride = tuple(json.loads(args.config.get("arch", "conv_layer2_stride")))
            if seq_len is None:
                seq_len = args.config.getint('arch', 'max_t_count')
            seq_len_after_conv_layer1 = _conv_output_len(seq_len,
                                                         conv_layer1_filter_dim,
                                                         conv_layer1_stride)
            seq_len_after_conv_layer2 = _conv_output_len(seq_len_after_conv_layer1,
                                                         conv_layer2_filter_dim,
                                                         conv_layer2_stride)

            args.config.set('arch', 'max_t_count', str(seq_len_after_conv_layer2))
        else:
            raise Exception('mode must be the one of the followings - train,predict,load')
207
208
class BucketingArch(object):
    """Factory producing the symbol-generation callback used for bucketing."""

    def __init__(self, args):
        # Keep the parsed-config namespace; every generated symbol reads from it.
        self.args = args

    def sym_gen(self, seq_len):
        """Build the network for one bucket and name its inputs/outputs.

        Returns (symbol, data_names, label_names), where data_names is
        'data' followed by the RNN initial-state names.
        """
        net = arch(self.args, seq_len)
        data_names = ['data'] + [name for name, _ in prepare_data(self.args)]
        return net, data_names, ('label',)

    def get_sym_gen(self):
        """Return the bound sym_gen method for the bucketing module."""
        return self.sym_gen