# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx

from stt_layer_batchnorm import batchnorm


def fc(net,
       num_hidden,
       act_type,
       weight=None,
       bias=None,
       no_bias=False,
       name=None
       ):
    """Apply a FullyConnected layer, then an optional activation.

    Parameters
    ----------
    net : mx.symbol.Symbol
        Input symbol.
    num_hidden : int
        Number of hidden units of the FullyConnected layer.
    act_type : str or None
        Activation type for mx.sym.Activation; None skips the activation.
    weight : mx.symbol.Symbol, optional
        Explicit weight parameter symbol (for weight sharing). If None,
        MXNet creates one automatically.
    bias : mx.symbol.Symbol, optional
        Explicit bias parameter symbol. Ignored when ``no_bias`` is True.
    no_bias : bool
        Disable the bias term of the FullyConnected layer.
    name : str, optional
        Name of the FullyConnected symbol; also prefixes the activation name.

    Returns
    -------
    mx.symbol.Symbol
    """
    # Pass only the parameter symbols that were explicitly provided; MXNet
    # auto-creates any that are missing. This collapses the original
    # four-way if/elif over (weight, bias, no_bias) into one call.
    params = {}
    if weight is not None:
        params['weight'] = weight
    # A bias symbol is meaningless when no_bias is set, so it is dropped
    # (same behavior as the original branching).
    if bias is not None and not no_bias:
        params['bias'] = bias
    net = mx.sym.FullyConnected(data=net, num_hidden=num_hidden,
                                no_bias=no_bias, name=name, **params)
    # activation
    if act_type is not None:
        # NOTE(review): when name is None this produces the literal symbol
        # name "None_activation"; repeated calls without a name would
        # collide — confirm intent before changing the naming scheme.
        net = mx.sym.Activation(data=net, act_type=act_type,
                                name="%s_activation" % name)
    return net


def sequence_fc(net,
                seq_len,
                num_layer,
                prefix,
                num_hidden_list=None,
                act_type_list=None,
                is_batchnorm=False,
                dropout_rate=0,
                ):
    """Apply ``num_layer`` fully connected layers to every time step,
    sharing the layer parameters across time steps.

    Parameters
    ----------
    net : mx.symbol.Symbol or list of mx.symbol.Symbol
        Either one symbol with a time axis (axis=1), which is sliced into
        ``seq_len`` per-step symbols, or an already-sliced list of symbols.
    seq_len : int
        Number of time steps.
    num_layer : int
        Number of fc layers; must equal the lengths of both lists below.
    prefix : str
        Prefix used to build all parameter/symbol names.
    num_hidden_list : list of int, optional
        Hidden sizes, one per layer (defaults to an empty list).
    act_type_list : list of str, optional
        Activation types, one per layer (defaults to an empty list).
    is_batchnorm : bool
        If True, each layer is fc (no bias) -> batchnorm -> activation.
    dropout_rate : float
        If > 0, dropout is applied to the input of every layer.

    Returns
    -------
    list of mx.symbol.Symbol
        One output symbol per time step (the input ``net`` unchanged when
        ``num_layer`` is 0).

    Raises
    ------
    Exception
        If the list lengths do not match ``num_layer``, or ``net`` has an
        unsupported type.
    """
    # Use None defaults instead of the original mutable-default-argument
    # lists ([] shared across calls).
    if num_hidden_list is None:
        num_hidden_list = []
    if act_type_list is None:
        act_type_list = []

    # Guard clause replaces the original bottom-of-function else branch;
    # the exception payload is kept byte-identical.
    if not (num_layer == len(num_hidden_list) == len(act_type_list)):
        raise Exception("length doesn't met - num_layer:",
                        num_layer, ",len(num_hidden_list):",
                        len(num_hidden_list),
                        ",len(act_type_list):",
                        len(act_type_list)
                        )

    if num_layer > 0:
        # One set of parameter symbols per layer, reused for every time
        # step below (weight sharing across the sequence).
        weight_list = []
        bias_list = []
        for layer_index in range(num_layer):
            weight_list.append(
                mx.sym.Variable(name='%s_sequence_fc%d_weight' % (prefix, layer_index)))
            # if you use batchnorm, the fc bias does not have any effect
            if not is_batchnorm:
                bias_list.append(
                    mx.sym.Variable(name='%s_sequence_fc%d_bias' % (prefix, layer_index)))
        # batch normalization parameters (also shared across time steps)
        gamma_list = []
        beta_list = []
        if is_batchnorm:
            for layer_index in range(num_layer):
                gamma_list.append(
                    mx.sym.Variable(name='%s_sequence_fc%d_gamma' % (prefix, layer_index)))
                beta_list.append(
                    mx.sym.Variable(name='%s_sequence_fc%d_beta' % (prefix, layer_index)))

        # Normalize the input: a single symbol is sliced along the time
        # axis; a list must already contain one symbol per time step.
        if isinstance(net, mx.symbol.Symbol):
            net = mx.sym.SliceChannel(data=net, num_outputs=seq_len,
                                      axis=1, squeeze_axis=1)
        elif isinstance(net, list):
            for net_index, one_net in enumerate(net):
                if not isinstance(one_net, mx.symbol.Symbol):
                    raise Exception('%d th elements of the net should be mx.symbol.Symbol' % net_index)
        else:
            raise Exception('type of net should be whether mx.symbol.Symbol or list of mx.symbol.Symbol')

        hidden_all = []
        for seq_index in range(seq_len):
            hidden = net[seq_index]
            for layer_index in range(num_layer):
                # dropout is applied to the input of every layer
                if dropout_rate > 0:
                    hidden = mx.sym.Dropout(data=hidden, p=dropout_rate)

                if is_batchnorm:
                    # fc (bias disabled) -> batchnorm -> activation.
                    # NOTE(review): an original comment claimed the last
                    # layer skips batchnorm, but the code applies it to
                    # every layer — behavior preserved here; confirm which
                    # is intended.
                    hidden = fc(net=hidden,
                                num_hidden=num_hidden_list[layer_index],
                                act_type=None,
                                weight=weight_list[layer_index],
                                no_bias=is_batchnorm,
                                name="%s_t%d_l%d_fc" % (prefix, seq_index, layer_index)
                                )
                    hidden = batchnorm(net=hidden,
                                       gamma=gamma_list[layer_index],
                                       beta=beta_list[layer_index],
                                       name="%s_t%d_l%d_batchnorm" % (prefix, seq_index, layer_index))
                    hidden = mx.sym.Activation(data=hidden, act_type=act_type_list[layer_index],
                                               name="%s_t%d_l%d_activation" % (prefix, seq_index, layer_index))
                else:
                    # fc with its own bias; fc() applies the activation.
                    hidden = fc(net=hidden,
                                num_hidden=num_hidden_list[layer_index],
                                act_type=act_type_list[layer_index],
                                weight=weight_list[layer_index],
                                bias=bias_list[layer_index]
                                )
            hidden_all.append(hidden)
        net = hidden_all
    return net