1import chainer
2import chainer.links.rnn as rnn
3import chainermn.functions
4
5
class _MultiNodeNStepRNN(chainer.Chain):
    """Model-parallel wrapper around a Chainer stacked RNN link.

    Optionally receives initial hidden states from the process ``rank_in``
    before running the wrapped RNN, and forwards the resulting hidden
    states to the process ``rank_out`` afterwards.
    """

    def __init__(self, link, communicator, rank_in, rank_out):
        super(_MultiNodeNStepRNN, self).__init__(actual_rnn=link)

        self.communicator = communicator
        self.rank_in = rank_in
        self.rank_out = rank_out

        # Only NStepRNNBase subclasses expose ``n_cells`` (the number of
        # state arrays: 1 for GRU/RNN, 2 for LSTM), which we rely on below.
        if not isinstance(link, rnn.n_step_rnn.NStepRNNBase):
            raise ValueError('link must be NStepRNN and its inherited link')
        self.n_cells = link.n_cells

    def __call__(self, *inputs):
        # Initial states: either received from the upstream process or
        # left as ``None`` so the wrapped RNN initializes them itself.
        if self.rank_in is None:
            cells = [None] * self.n_cells
        else:
            cells = [
                chainermn.functions.recv(self.communicator, rank=self.rank_in)
                for _ in range(self.n_cells)]

        outputs = self.actual_rnn(*(tuple(cells) + inputs))
        # All but the last output are the final hidden/cell states.
        cells = outputs[:-1]

        delegate_variable = None
        if self.rank_out is not None:
            # Send each state downstream; chain the sends through
            # ``pseudo_connect`` so backward propagates across all of them
            # from the single returned ``delegate_variable``.
            cell = cells[0]
            for idx in range(self.n_cells):
                delegate_variable = chainermn.functions.send(
                    cell, self.communicator, rank=self.rank_out)
                if idx < self.n_cells - 1:
                    cell, = chainermn.functions.pseudo_connect(
                        delegate_variable, cells[idx + 1])

        return outputs + (delegate_variable,)
44
45
def create_multi_node_n_step_rnn(
        actual_link, communicator, rank_in=None, rank_out=None):
    """Create a multi node stacked RNN link from a Chainer stacked RNN link.

    Multi node stacked RNN link is used for model-parallel.
    The created link will receive initial hidden states from the process
    specified by ``rank_in`` (or do not receive if ``None``), execute
    the original RNN computation, and then send resulting hidden states
    to the process specified by ``rank_out``.

    Compared with Chainer stacked RNN link, multi node stacked RNN link
    returns an extra object called ``delegate_variable``.
    If ``rank_out`` is not ``None``, backward computation is expected
    to be begun from ``delegate_variable``.
    For detail, please refer to ``chainermn.functions.pseudo_connect``.

    The following RNN links can be passed to this function:

    - ``chainer.links.NStepBiGRU``
    - ``chainer.links.NStepBiLSTM``
    - ``chainer.links.NStepBiRNNReLU``
    - ``chainer.links.NStepBiRNNTanh``
    - ``chainer.links.NStepGRU``
    - ``chainer.links.NStepLSTM``
    - ``chainer.links.NStepRNNReLU``
    - ``chainer.links.NStepRNNTanh``

    Args:
        actual_link (chainer.Link): Chainer stacked RNN link
        communicator: ChainerMN communicator
        rank_in (int, or None):
            Rank of the process which sends hidden RNN states to this process.
        rank_out (int, or None):
            Rank of the process to which this process sends hidden RNN states.

    Returns:
        The multi node stacked RNN link based on ``actual_link``.
    """
    chainer.utils.experimental('chainermn.links.create_multi_node_n_step_rnn')
    return _MultiNodeNStepRNN(actual_link, communicator, rank_in, rank_out)
86