# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
sys.path.insert(0, "../../python/")
import mxnet as mx
import numpy as np
from collections import namedtuple
import time
import math

RNNState = namedtuple("RNNState", ["h"])
RNNParam = namedtuple("RNNParam", ["i2h_weight", "i2h_bias",
                                   "h2h_weight", "h2h_bias"])
RNNModel = namedtuple("RNNModel", ["rnn_exec", "symbol",
                                   "init_states", "last_states",
                                   "seq_data", "seq_labels", "seq_outputs",
                                   "param_blocks"])


def rnn(num_hidden, in_data, prev_state, param, seqidx, layeridx, dropout=0., batch_norm=False):
    """One vanilla RNN step: h_t = tanh(i2h(x_t) + h2h(h_{t-1}))."""
    if dropout > 0.:
        in_data = mx.sym.Dropout(data=in_data, p=dropout)
    i2h = mx.sym.FullyConnected(data=in_data,
                                weight=param.i2h_weight,
                                bias=param.i2h_bias,
                                num_hidden=num_hidden,
                                name="t%d_l%d_i2h" % (seqidx, layeridx))
    h2h = mx.sym.FullyConnected(data=prev_state.h,
                                weight=param.h2h_weight,
                                bias=param.h2h_bias,
                                num_hidden=num_hidden,
                                name="t%d_l%d_h2h" % (seqidx, layeridx))
    hidden = i2h + h2h

    hidden = mx.sym.Activation(data=hidden, act_type="tanh")
    if batch_norm:
        hidden = mx.sym.BatchNorm(data=hidden)
    return RNNState(h=hidden)


def rnn_unroll(num_rnn_layer, seq_len, input_size,
               num_hidden, num_embed, num_label, dropout=0., batch_norm=False):
    """Explicitly unroll a stacked RNN over seq_len time steps, attach a
    softmax decoder at every step, and return the grouped per-step losses."""
    embed_weight = mx.sym.Variable("embed_weight")
    cls_weight = mx.sym.Variable("cls_weight")
    cls_bias = mx.sym.Variable("cls_bias")
    param_cells = []
    last_states = []
    for i in range(num_rnn_layer):
        param_cells.append(RNNParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i),
                                    i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i),
                                    h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i),
                                    h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i)))
        state = RNNState(h=mx.sym.Variable("l%d_init_h" % i))
        last_states.append(state)
    assert len(last_states) == num_rnn_layer

    loss_all = []
    for seqidx in range(seq_len):
        # embedding layer
        data = mx.sym.Variable("data/%d" % seqidx)

        hidden = mx.sym.Embedding(data=data, weight=embed_weight,
                                  input_dim=input_size,
                                  output_dim=num_embed,
                                  name="t%d_embed" % seqidx)
        # stack RNN layers; no dropout on the input of the first layer
        for i in range(num_rnn_layer):
            if i == 0:
                dp = 0.
            else:
                dp = dropout
            next_state = rnn(num_hidden, in_data=hidden,
                             prev_state=last_states[i],
                             param=param_cells[i],
                             seqidx=seqidx, layeridx=i, dropout=dp, batch_norm=batch_norm)
            hidden = next_state.h
            last_states[i] = next_state
        # decoder: project the top hidden state to the label space
        if dropout > 0.:
            hidden = mx.sym.Dropout(data=hidden, p=dropout)
        fc = mx.sym.FullyConnected(data=hidden, weight=cls_weight, bias=cls_bias,
                                   num_hidden=num_label)
        sm = mx.sym.SoftmaxOutput(data=fc, label=mx.sym.Variable('label/%d' % seqidx),
                                  name='t%d_sm' % seqidx)
        loss_all.append(sm)
    return mx.sym.Group(loss_all)
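

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the original
    # example; every size below is a placeholder assumption): unroll a
    # 2-layer RNN over 5 time steps and print the argument names the
    # resulting symbol expects, i.e. the per-step data/label variables,
    # the per-layer weights and biases, and the initial hidden states.
    sym = rnn_unroll(num_rnn_layer=2, seq_len=5, input_size=1000,
                     num_hidden=200, num_embed=100, num_label=1000,
                     dropout=0.5)
    print(sym.list_arguments())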