1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18from __future__ import print_function 19from six.moves import range 20 21import argparse 22import subprocess 23from itertools import product 24from time import time 25 26import mxnet as mx 27import numpy as np 28from mxnet import gluon 29 30 31_parser = argparse.ArgumentParser(description='Benchmark foreach and while_loop on RNN tasks.') 32_parser.add_argument('--benchmark', choices=["foreach", "while_loop"], required=True) 33_parser.add_argument('--warmup_rounds', type=int, default=20) 34_parser.add_argument('--test_rounds', type=int, default=100) 35_parser.add_argument('--gpu', type=bool, default=False) 36args = _parser.parse_args() 37 38 39class ForeachRNN(gluon.HybridBlock): 40 def __init__(self, cell, length, prefix=None, params=None): 41 super(ForeachRNN, self).__init__(prefix=prefix, params=params) 42 self.length = length 43 self.cell = cell 44 45 def hybrid_forward(self, F, inputs, states): 46 out, states = F.contrib.foreach(self.cell, inputs, states) 47 return out 48 49 50class WhileRNN(gluon.HybridBlock): 51 def __init__(self, cell, length, prefix=None, params=None): 52 super(WhileRNN, self).__init__(prefix=prefix, params=params) 53 self.length = length 54 self.cell = cell 55 56 def hybrid_forward(self, F, inputs, states): 57 def _func(*states): 58 i = states[0] 59 s = states[1: ] 60 data = inputs.take(i).squeeze(axis=0) 61 out, new_s = self.cell(data, s) 62 new_s = [i + 1] + new_s 63 return out, new_s 64 out, states = F.contrib.while_loop( 65 cond=lambda i, *_: i < self.length, 66 func=_func, 67 loop_vars=states, 68 max_iterations=self.length, 69 ) 70 return out 71 72 73def _zeros(shape, ctx): 74 return mx.nd.zeros(shape=shape, ctx=ctx) 75 76 77def _array(shape, ctx): 78 return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=ctx) 79 80 81def _get_gpus(): 82 return range(mx.util.get_gpu_count()) 83 84def run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim): 85 obj = {"foreach": ForeachRNN, "while_loop": WhileRNN}[args.benchmark] 86 inputs = _array((seq_len, batch_size, hidden_dim), ctx) 87 states = [_array((batch_size, hidden_dim), ctx) for _ in cell_type(0).state_info()] 88 if args.benchmark == "while_loop": 89 states.insert(0, _zeros((1, ), ctx)) 90 91 for is_train, is_hyb_cell, is_hyb_layer in product([True, False], [False, True], [False, True]): 92 cell = cell_type(hidden_dim) 93 if is_hyb_cell: 94 cell.hybridize(static_alloc=True) 95 layer = obj(cell, seq_len) 96 layer.initialize(ctx=ctx) 97 if is_hyb_layer: 98 layer.hybridize(static_alloc=True) 99 print("is_train = %r, hybridize_cell = %r, hybridize_layer = %r" % (is_train, is_hyb_cell, is_hyb_layer)) 100 times = [] 101 for _ in range(args.warmup_rounds + args.test_rounds): 102 tick = time() 103 if not is_train: 104 res = layer(inputs, states) 105 else: 106 with mx.autograd.record(): 107 res = layer(inputs, states) 108 if is_train: 109 res.backward() 110 mx.nd.waitall() 111 tock = time() 112 times.append((tock - tick) * 1000.0) 113 times = times[args.warmup_rounds: ] 114 print("Time used: mean = %.3f ms, std = %.3f ms" % (np.mean(times), np.std(times))) 115 116 117def main(): 118 # testing configurations 119 cell_types = [gluon.rnn.RNNCell, 120 gluon.rnn.GRUCell, 121 gluon.rnn.LSTMCell] 122 ctxs = [mx.cpu(0)] 123 if args.gpu: 124 ctxs = ctxs + [mx.gpu(i) for i in _get_gpus()] 125 seq_lens = [100] 126 batch_sizes = [1, 32] 127 hidden_dims = [512] 128 print("--------------------------------------") 129 print("Benchmarking", args.benchmark) 130 for cell_type, ctx, seq_len, batch_size, hidden_dim in product( \ 131 cell_types, ctxs, seq_lens, batch_sizes, hidden_dims): 132 print("--------------------------------------") 133 print("cell: %s ctx: %s length: %d batch size: %d dim: %d" % \ 134 (cell_type.__name__, str(ctx), seq_len, batch_size, hidden_dim)) 135 run_benchmark(cell_type, ctx, seq_len, batch_size, hidden_dim) 136 137 138if __name__ == "__main__": 139 main() 140