import chainer
import chainer.links.rnn as rnn
import chainermn.functions


class _MultiNodeNStepRNN(chainer.Chain):
    """Wrap a Chainer stacked (NStep) RNN link for model-parallel use.

    Hidden states are optionally received from process ``rank_in`` before
    the wrapped link runs, and optionally sent to process ``rank_out``
    afterwards. Instances are meant to be built through
    :func:`create_multi_node_n_step_rnn`.
    """

    def __init__(self, link, communicator, rank_in, rank_out):
        super(_MultiNodeNStepRNN, self).__init__(actual_rnn=link)

        self.communicator = communicator
        self.rank_in = rank_in
        self.rank_out = rank_out

        # Only NStepRNNBase subclasses are supported. They expose
        # ``n_cells``: the number of hidden-state arguments the link takes
        # before its other inputs (and the number of state outputs that
        # precede its final output).
        is_n_step_rnn = isinstance(link, rnn.n_step_rnn.NStepRNNBase)
        if not is_n_step_rnn:
            raise ValueError('link must be NStepRNN and its inherited link')
        self.n_cells = link.n_cells

    def __call__(self, *inputs):
        """Run the wrapped RNN, exchanging hidden states with peer processes.

        Args:
            *inputs: Positional arguments of the wrapped link that follow
                its hidden-state arguments; forwarded unchanged.

        Returns:
            tuple: The wrapped link's outputs with ``delegate_variable``
            appended. ``delegate_variable`` is ``None`` when ``rank_out``
            is ``None``.
        """
        if self.rank_in is None:
            # No upstream process: pass None for every state argument.
            cells = (None,) * self.n_cells
        else:
            # Receive the initial states one by one, in order.
            cells = tuple(
                chainermn.functions.recv(self.communicator, rank=self.rank_in)
                for _ in range(self.n_cells))

        outputs = self.actual_rnn(*(cells + inputs))
        # Every output except the last is a hidden state to forward.
        cells = outputs[:-1]

        delegate_variable = None
        if self.rank_out is not None:
            cell = cells[0]
            for i in range(self.n_cells):
                delegate_variable = chainermn.functions.send(
                    cell, self.communicator, rank=self.rank_out)
                if i < self.n_cells - 1:
                    # Tie this send's delegate into the graph of the next
                    # state, so that backward started from the returned
                    # (last) delegate_variable also reaches the earlier
                    # sends (see chainermn.functions.pseudo_connect).
                    cell, = chainermn.functions.pseudo_connect(
                        delegate_variable, cells[i + 1])

        return outputs + (delegate_variable,)


def create_multi_node_n_step_rnn(
        actual_link, communicator, rank_in=None, rank_out=None):
    """Create a multi node stacked RNN link from a Chainer stacked RNN link.

    Multi node stacked RNN link is used for model-parallel.
    The created link will receive initial hidden states from the process
    specified by ``rank_in`` (or will not receive them if ``None``), execute
    the original RNN computation, and then send the resulting hidden states
    to the process specified by ``rank_out`` (or will not send them if
    ``None``).

    Compared with Chainer stacked RNN link, multi node stacked RNN link
    returns an extra object called ``delegate_variable``.
    If ``rank_out`` is not ``None``, backward computation is expected
    to begin from ``delegate_variable``.
    For details, please refer to ``chainermn.functions.pseudo_connect``.

    The following RNN links can be passed to this function:

    - ``chainer.links.NStepBiGRU``
    - ``chainer.links.NStepBiLSTM``
    - ``chainer.links.NStepBiRNNReLU``
    - ``chainer.links.NStepBiRNNTanh``
    - ``chainer.links.NStepGRU``
    - ``chainer.links.NStepLSTM``
    - ``chainer.links.NStepRNNReLU``
    - ``chainer.links.NStepRNNTanh``

    Args:
        actual_link (chainer.Link): Chainer stacked RNN link
        communicator: ChainerMN communicator
        rank_in (int, or None):
            Rank of the process which sends hidden RNN states to this process.
        rank_out (int, or None):
            Rank of the process to which this process sends hidden RNN states.

    Returns:
        The multi node stacked RNN link based on ``actual_link``.

    Raises:
        ValueError: If ``actual_link`` is not derived from
            ``chainer.links.rnn.n_step_rnn.NStepRNNBase``.
    """
    chainer.utils.experimental('chainermn.links.create_multi_node_n_step_rnn')
    return _MultiNodeNStepRNN(actual_link, communicator, rank_in, rank_out)